"""Tests for source accuracy tracker — SourceAccuracy dataclass and database functions."""

from __future__ import annotations

from datetime import datetime, timezone
from unittest.mock import AsyncMock

import pytest

from services.aggregation.source_accuracy import (
    SourceAccuracy,
    fetch_source_accuracy,
    update_source_accuracy,
)


def _make_accuracy(accuracy_ratio: float, sample_count: int) -> SourceAccuracy:
    """Build a SourceAccuracy for ``src-1`` stamped with the current UTC time."""
    return SourceAccuracy(
        source_id="src-1",
        accuracy_ratio=accuracy_ratio,
        sample_count=sample_count,
        last_updated=datetime.now(timezone.utc),
    )


# ---------------------------------------------------------------------------
# SourceAccuracy.accuracy_factor property
# ---------------------------------------------------------------------------


def test_accuracy_factor_low_sample_count():
    """When sample_count < 10, accuracy_factor returns neutral 1.0."""
    sa = _make_accuracy(accuracy_ratio=0.9, sample_count=5)
    assert sa.accuracy_factor == 1.0


def test_accuracy_factor_exactly_ten_samples():
    """When sample_count == 10, accuracy_factor uses the formula."""
    sa = _make_accuracy(accuracy_ratio=0.8, sample_count=10)
    # factor = 0.5 + 0.8
    assert sa.accuracy_factor == pytest.approx(1.3)


def test_accuracy_factor_zero_accuracy():
    """0% accuracy with enough samples gives factor 0.5."""
    sa = _make_accuracy(accuracy_ratio=0.0, sample_count=100)
    assert sa.accuracy_factor == pytest.approx(0.5)


def test_accuracy_factor_full_accuracy():
    """100% accuracy with enough samples gives factor 1.5."""
    sa = _make_accuracy(accuracy_ratio=1.0, sample_count=100)
    assert sa.accuracy_factor == pytest.approx(1.5)


def test_accuracy_factor_clamps_corrupted_high():
    """Corrupted accuracy_ratio > 1.0 is clamped to 1.0 in the factor."""
    sa = _make_accuracy(accuracy_ratio=2.5, sample_count=50)
    # clamped to 1.0 → factor = 0.5 + 1.0 = 1.5
    assert sa.accuracy_factor == pytest.approx(1.5)


def test_accuracy_factor_clamps_corrupted_negative():
    """Corrupted accuracy_ratio < 0.0 is clamped to 0.0 in the factor."""
    sa = _make_accuracy(accuracy_ratio=-0.3, sample_count=50)
    # clamped to 0.0 → factor = 0.5 + 0.0 = 0.5
    assert sa.accuracy_factor == pytest.approx(0.5)


def test_accuracy_factor_nine_samples_neutral():
    """sample_count=9 is still below threshold, returns 1.0."""
    sa = _make_accuracy(accuracy_ratio=0.0, sample_count=9)
    assert sa.accuracy_factor == 1.0


# ---------------------------------------------------------------------------
# fetch_source_accuracy
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_fetch_source_accuracy_empty_ids():
    """Empty source_ids list returns empty dict without querying."""
    pool = AsyncMock()
    result = await fetch_source_accuracy(pool, [])
    assert result == {}
    pool.fetch.assert_not_called()


@pytest.mark.asyncio
async def test_fetch_source_accuracy_returns_records():
    """Successful fetch returns SourceAccuracy records keyed by source_id."""
    now = datetime.now(timezone.utc)
    pool = AsyncMock()
    pool.fetch = AsyncMock(return_value=[
        {
            "source_id": "src-a",
            "accuracy_ratio": 0.75,
            "sample_count": 20,
            "last_updated": now,
        },
        {
            "source_id": "src-b",
            "accuracy_ratio": 0.4,
            "sample_count": 15,
            "last_updated": now,
        },
    ])

    result = await fetch_source_accuracy(pool, ["src-a", "src-b"])

    # Checking the key set (not just the length) verifies records are keyed
    # by source_id.
    assert set(result) == {"src-a", "src-b"}
    assert result["src-a"].accuracy_ratio == 0.75
    assert result["src-a"].sample_count == 20
    assert result["src-b"].accuracy_ratio == 0.4


@pytest.mark.asyncio
async def test_fetch_source_accuracy_clamps_corrupted():
    """Corrupted accuracy_ratio values are clamped to [0.0, 1.0]."""
    now = datetime.now(timezone.utc)
    pool = AsyncMock()
    pool.fetch = AsyncMock(return_value=[
        {
            "source_id": "src-bad",
            "accuracy_ratio": 1.5,
            "sample_count": 30,
            "last_updated": now,
        },
    ])

    result = await fetch_source_accuracy(pool, ["src-bad"])

    assert result["src-bad"].accuracy_ratio == 1.0


@pytest.mark.asyncio
async def test_fetch_source_accuracy_db_error_returns_empty():
    """When the database is unreachable, returns empty dict."""
    pool = AsyncMock()
    pool.fetch = AsyncMock(side_effect=Exception("connection refused"))

    result = await fetch_source_accuracy(pool, ["src-a"])

    assert result == {}
    # Without this, an implementation that never queried would pass
    # vacuously and the error-handling path would go untested.
    pool.fetch.assert_awaited()


# ---------------------------------------------------------------------------
# update_source_accuracy
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_update_source_accuracy_empty_outcomes():
    """Empty outcomes list does nothing."""
    pool = AsyncMock()
    await update_source_accuracy(pool, "src-1", [])
    pool.execute.assert_not_called()


@pytest.mark.asyncio
async def test_update_source_accuracy_counts_correctly():
    """Correct and incorrect predictions are counted properly."""
    pool = AsyncMock()
    pool.execute = AsyncMock()
    outcomes = [
        ("bullish", 0.05),   # correct
        ("bullish", -0.02),  # incorrect
        ("bearish", -0.03),  # correct
        ("bearish", 0.01),   # incorrect
    ]

    await update_source_accuracy(pool, "src-1", outcomes)

    # assert_awaited_once (not assert_called_once) also catches a missing
    # `await`, which would otherwise create the coroutine and never run it.
    pool.execute.assert_awaited_once()
    args = pool.execute.call_args.args
    # accuracy_ratio = 2/4 = 0.5, total = 4
    assert args[2] == pytest.approx(0.5)  # accuracy_ratio
    assert args[3] == 4  # total


@pytest.mark.asyncio
async def test_update_source_accuracy_skips_neutral():
    """Neutral predictions and zero returns are excluded."""
    pool = AsyncMock()
    pool.execute = AsyncMock()
    outcomes = [
        ("neutral", 0.05),   # skipped — neutral direction
        ("bullish", 0.0),    # skipped — zero return
        ("bullish", 0.03),   # counted — correct
    ]

    await update_source_accuracy(pool, "src-1", outcomes)

    pool.execute.assert_awaited_once()
    args = pool.execute.call_args.args
    # accuracy_ratio = 1/1 = 1.0, total = 1
    assert args[2] == pytest.approx(1.0)
    assert args[3] == 1


@pytest.mark.asyncio
async def test_update_source_accuracy_all_neutral_skips():
    """When all outcomes are neutral/zero, no DB call is made."""
    pool = AsyncMock()
    # Cover both exclusion rules: a neutral direction and a zero return.
    await update_source_accuracy(
        pool, "src-1", [("neutral", 0.05), ("bullish", 0.0)]
    )
    pool.execute.assert_not_called()


@pytest.mark.asyncio
async def test_update_source_accuracy_db_error_logs_and_continues():
    """DB errors are logged but do not raise."""
    pool = AsyncMock()
    pool.execute = AsyncMock(side_effect=Exception("connection refused"))

    # Should not raise
    await update_source_accuracy(pool, "src-1", [("bullish", 0.05)])

    # Confirm the failing write was actually attempted; otherwise the test
    # would pass without ever exercising the error-handling path.
    pool.execute.assert_awaited()