"""Unit tests for scheduler pure functions and orchestration. Covers: get_cadence_for_source, compute_backoff, is_source_due, build_job_payload, schedule_cycle (mocked DB/Redis), check_rate_limit, recover_stale_documents, retry_failed_extractions, and error handling for DB/Redis connection failures. Requirements: 1.1, 1.2, 1.3, 1.4 """ from __future__ import annotations import json import uuid from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock import pytest from services.scheduler.app import ( DEFAULT_BACKOFF_BASE, DEFAULT_CADENCES, DEFAULT_RATE_LIMITS, MAX_BACKOFF, MAX_RETRY_COUNT, POLYGON_GLOBAL_RATE_LIMIT, build_job_payload, check_rate_limit, compute_backoff, get_cadence_for_source, is_source_due, recover_stale_documents, retry_failed_extractions, schedule_cycle, ) from services.shared.redis_keys import ( QUEUE_EXTRACTION, QUEUE_MACRO_CLASSIFICATION, queue_key, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _now() -> datetime: return datetime(2026, 6, 15, 12, 0, 0, tzinfo=timezone.utc) def _make_source( source_id: str = "src-1", company_id: str = "cid-1", ticker: str = "AAPL", source_type: str = "news_api", source_name: str = "NewsAPI", config: dict | None = None, credibility_score: float = 0.8, legal_name: str = "Apple Inc.", ) -> dict: """Build a dict that mimics an asyncpg.Record for a source row.""" return { "source_id": source_id, "company_id": company_id, "ticker": ticker, "legal_name": legal_name, "source_type": source_type, "source_name": source_name, "config": config, "credibility_score": credibility_score, } def _make_last_run( status: str = "completed", started_at: datetime | None = None, completed_at: datetime | None = None, retry_count: int = 0, next_retry_at: datetime | None = None, ) -> dict: """Build a dict that mimics an asyncpg.Record for an ingestion_runs row.""" return { "status": 
status, "started_at": started_at or _now() - timedelta(seconds=600), "completed_at": completed_at or _now() - timedelta(seconds=600), "retry_count": retry_count, "next_retry_at": next_retry_at, } def _mock_pool() -> AsyncMock: """Create a mock asyncpg.Pool with standard async methods.""" pool = AsyncMock() pool.fetch = AsyncMock(return_value=[]) pool.fetchrow = AsyncMock(return_value=None) pool.fetchval = AsyncMock(return_value=None) pool.execute = AsyncMock(return_value="UPDATE 0") return pool def _mock_redis() -> AsyncMock: """Create a mock redis.asyncio.Redis with standard async methods.""" rds = AsyncMock() rds.rpush = AsyncMock(return_value=1) rds.set = AsyncMock(return_value=True) rds.get = AsyncMock(return_value=None) rds.incr = AsyncMock(return_value=1) rds.expire = AsyncMock(return_value=True) rds.decr = AsyncMock(return_value=0) rds.delete = AsyncMock(return_value=1) return rds # --------------------------------------------------------------------------- # get_cadence_for_source # --------------------------------------------------------------------------- class TestGetCadenceForSource: def test_returns_default_for_known_type(self): assert get_cadence_for_source("market_api", None) == DEFAULT_CADENCES["market_api"] def test_returns_fallback_for_unknown_type(self): assert get_cadence_for_source("unknown_type", None) == 600 def test_config_override(self): assert get_cadence_for_source("market_api", {"polling_interval_seconds": 120}) == 120 def test_config_override_clamped_to_minimum(self): assert get_cadence_for_source("market_api", {"polling_interval_seconds": 3}) == 10 def test_invalid_config_value_falls_back(self): assert get_cadence_for_source("news_api", {"polling_interval_seconds": "bad"}) == DEFAULT_CADENCES["news_api"] # --------------------------------------------------------------------------- # compute_backoff # --------------------------------------------------------------------------- class TestComputeBackoff: def test_zero_retries(self): assert 
compute_backoff(0) == DEFAULT_BACKOFF_BASE def test_exponential_growth(self): assert compute_backoff(1) == DEFAULT_BACKOFF_BASE * 2 assert compute_backoff(2) == DEFAULT_BACKOFF_BASE * 4 def test_capped_at_max(self): assert compute_backoff(20) == MAX_BACKOFF def test_exponent_capped_at_8(self): # 2^8 = 256, so 60 * 256 = 15360 > MAX_BACKOFF (3600) assert compute_backoff(8) == MAX_BACKOFF # --------------------------------------------------------------------------- # is_source_due # --------------------------------------------------------------------------- class TestIsSourceDue: def test_never_run_is_due(self): assert is_source_due("market_api", None, None, None, 0, None, _now()) def test_completed_within_cadence_not_due(self): last = _now() - timedelta(seconds=100) assert not is_source_due("market_api", None, last, "completed", 0, None, _now()) def test_completed_past_cadence_is_due(self): last = _now() - timedelta(seconds=400) assert is_source_due("market_api", None, last, "completed", 0, None, _now()) def test_running_not_due(self): last = _now() - timedelta(seconds=5) assert not is_source_due("market_api", None, last, "running", 0, None, _now()) def test_failed_within_backoff_not_due(self): last = _now() - timedelta(seconds=30) next_retry = _now() + timedelta(seconds=30) assert not is_source_due("market_api", None, last, "failed", 1, next_retry, _now()) def test_failed_past_backoff_is_due(self): last = _now() - timedelta(seconds=120) next_retry = _now() - timedelta(seconds=10) assert is_source_due("market_api", None, last, "failed", 1, next_retry, _now()) def test_failed_max_retries_not_due(self): last = _now() - timedelta(seconds=120) assert not is_source_due( "market_api", None, last, "failed", MAX_RETRY_COUNT, None, _now() ) def test_failed_no_next_retry_at_is_due(self): """Failed with retries remaining and no next_retry_at → allow retry.""" last = _now() - timedelta(seconds=120) assert is_source_due("market_api", None, last, "failed", 2, None, _now()) # 
# ---------------------------------------------------------------------------
# build_job_payload
# ---------------------------------------------------------------------------


class TestBuildJobPayload:
    """build_job_payload turns a source row plus aliases into a job dict."""

    def test_complete_payload(self):
        src = _make_source()
        now = _now()

        job = build_job_payload(src, ["Apple", "AAPL Inc"], now)

        assert job["source_id"] == "src-1"
        assert job["company_id"] == "cid-1"
        assert job["ticker"] == "AAPL"
        assert job["legal_name"] == "Apple Inc."
        assert job["aliases"] == ["Apple", "AAPL Inc"]
        assert job["source_type"] == "news_api"
        assert job["source_name"] == "NewsAPI"
        # A None config column comes out as an empty dict.
        assert job["config"] == {}
        assert job["credibility_score"] == 0.8
        assert job["scheduled_at"] == now.isoformat()

    def test_null_company_id(self):
        src = _make_source(company_id=None)

        job = build_job_payload(src, [], _now())

        assert job["company_id"] is None

    def test_null_credibility_defaults_to_half(self):
        src = _make_source(credibility_score=None)

        job = build_job_payload(src, [], _now())

        assert job["credibility_score"] == 0.5


# ---------------------------------------------------------------------------
# check_rate_limit (async)
# ---------------------------------------------------------------------------


class TestCheckRateLimit:
    """Per-source-type and Polygon-wide per-minute rate limiting."""

    @pytest.mark.asyncio
    async def test_allowed_when_under_limit(self):
        rds = _mock_redis()
        rds.incr = AsyncMock(return_value=1)

        result = await check_rate_limit(rds, "news_api", _now())

        assert result is True

    @pytest.mark.asyncio
    async def test_blocked_when_over_per_type_limit(self):
        rds = _mock_redis()
        limit = DEFAULT_RATE_LIMITS["news_api"]
        rds.incr = AsyncMock(return_value=limit + 1)

        result = await check_rate_limit(rds, "news_api", _now())

        assert result is False

    @pytest.mark.asyncio
    async def test_polygon_global_limit_blocks(self):
        """market_api is a Polygon type — global limit should block even if per-type is OK."""
        rds = _mock_redis()
        # First INCR is the per-type counter (fine), second is the global
        # Polygon counter (exceeded).
        rds.incr = AsyncMock(side_effect=[1, POLYGON_GLOBAL_RATE_LIMIT + 1])

        result = await check_rate_limit(rds, "market_api", _now())

        assert result is False
        # Should have decremented the per-type counter
        rds.decr.assert_called_once()

    @pytest.mark.asyncio
    async def test_non_polygon_type_skips_global_check(self):
        """filings_api is not a Polygon type — no global limit check."""
        rds = _mock_redis()
        rds.incr = AsyncMock(return_value=1)

        result = await check_rate_limit(rds, "filings_api", _now())

        assert result is True
        # incr should be called only once (per-type), not twice (no global)
        assert rds.incr.call_count == 1

    @pytest.mark.asyncio
    async def test_expire_set_on_first_increment(self):
        rds = _mock_redis()
        rds.incr = AsyncMock(return_value=1)

        await check_rate_limit(rds, "news_api", _now())

        rds.expire.assert_called()

    @pytest.mark.asyncio
    async def test_custom_max_per_minute(self):
        rds = _mock_redis()
        rds.incr = AsyncMock(return_value=6)

        result = await check_rate_limit(rds, "news_api", _now(), max_per_minute=5)

        assert result is False


# ---------------------------------------------------------------------------
# schedule_cycle (mocked DB/Redis)
# ---------------------------------------------------------------------------


class TestScheduleCycle:
    """End-to-end scheduling pass over mocked DB rows and Redis."""

    @pytest.mark.asyncio
    async def test_enqueues_due_sources(self):
        pool = _mock_pool()
        rds = _mock_redis()
        src = _make_source()
        pool.fetch = AsyncMock(side_effect=[
            [src],  # fetch_active_sources
            [],     # fetch_macro_sources
            [],     # fetch_global_market_sources
            [],     # fetch_aliases_for_company returns rows
        ])
        # fetch_last_run returns None (never run → due)
        pool.fetchrow = AsyncMock(return_value=None)
        # Rate limit OK
        rds.incr = AsyncMock(return_value=1)

        enqueued = await schedule_cycle(pool, rds)

        assert enqueued == 1
        rds.rpush.assert_called_once()
        # Verify the enqueued payload (second positional arg of RPUSH)
        payload = json.loads(rds.rpush.call_args.args[1])
        assert payload["source_id"] == "src-1"
        assert payload["ticker"] == "AAPL"

    @pytest.mark.asyncio
    async def test_skips_not_due_sources(self):
        pool = _mock_pool()
        rds = _mock_redis()
        src = _make_source()
        pool.fetch = AsyncMock(side_effect=[
            [src],  # fetch_active_sources
            [],     # fetch_macro_sources
            [],     # fetch_global_market_sources
        ])
        # Last run was recent → not due. schedule_cycle is not handed a clock
        # here, so the "recent" timestamp is taken from the real wall clock
        # rather than the fixed _now().
        pool.fetchrow = AsyncMock(return_value=_make_last_run(
            status="completed",
            completed_at=datetime.now(tz=timezone.utc) - timedelta(seconds=10),
        ))

        enqueued = await schedule_cycle(pool, rds)

        assert enqueued == 0
        rds.rpush.assert_not_called()

    @pytest.mark.asyncio
    async def test_skips_rate_limited_sources(self):
        pool = _mock_pool()
        rds = _mock_redis()
        src = _make_source()
        pool.fetch = AsyncMock(side_effect=[
            [src],  # fetch_active_sources
            [],     # fetch_macro_sources
            [],     # fetch_global_market_sources
        ])
        pool.fetchrow = AsyncMock(return_value=None)  # never run → due
        # Rate limit exceeded
        rds.incr = AsyncMock(return_value=DEFAULT_RATE_LIMITS["news_api"] + 1)

        enqueued = await schedule_cycle(pool, rds)

        assert enqueued == 0
        rds.rpush.assert_not_called()

    @pytest.mark.asyncio
    async def test_enqueues_macro_sources(self):
        pool = _mock_pool()
        rds = _mock_redis()
        macro_src = _make_source(
            source_id="macro-1",
            company_id=None,
            ticker="",
            source_type="macro_news",
            source_name="MacroNewsSource",
        )
        pool.fetch = AsyncMock(side_effect=[
            [],           # fetch_active_sources (empty)
            [macro_src],  # fetch_macro_sources
            [],           # fetch_global_market_sources
        ])
        pool.fetchrow = AsyncMock(return_value=None)  # never run → due
        rds.incr = AsyncMock(return_value=1)

        enqueued = await schedule_cycle(pool, rds)

        assert enqueued == 1


# ---------------------------------------------------------------------------
# recover_stale_documents (mocked DB/Redis)
# ---------------------------------------------------------------------------


class TestRecoverStaleDocuments:
    """Re-enqueueing of documents returned by the stale-document query."""

    @pytest.mark.asyncio
    async def test_recovers_stale_parsed_docs(self):
        pool = _mock_pool()
        rds = _mock_redis()
        doc_id = uuid.uuid4()
        pool.fetch = AsyncMock(return_value=[
            {"id": doc_id, "document_type": "news", "ticker": "AAPL"},
        ])
        # _enqueue_if_new: rds.set returns True (new marker)
        rds.set = AsyncMock(return_value=True)

        count = await recover_stale_documents(pool, rds)

        assert count == 1
        # Should push to extraction queue
        rds.rpush.assert_called_once()
        assert queue_key(QUEUE_EXTRACTION) in rds.rpush.call_args.args[0]
        # Should update documents
        pool.execute.assert_called_once()

    @pytest.mark.asyncio
    async def test_routes_macro_event_to_classification_queue(self):
        pool = _mock_pool()
        rds = _mock_redis()
        doc_id = uuid.uuid4()
        pool.fetch = AsyncMock(return_value=[
            {"id": doc_id, "document_type": "macro_event", "ticker": ""},
        ])
        rds.set = AsyncMock(return_value=True)

        count = await recover_stale_documents(pool, rds)

        assert count == 1
        assert queue_key(QUEUE_MACRO_CLASSIFICATION) in rds.rpush.call_args.args[0]

    @pytest.mark.asyncio
    async def test_skips_already_enqueued_docs(self):
        pool = _mock_pool()
        rds = _mock_redis()
        doc_id = uuid.uuid4()
        pool.fetch = AsyncMock(return_value=[
            {"id": doc_id, "document_type": "news", "ticker": "AAPL"},
        ])
        # _enqueue_if_new: rds.set returns False (already tracked)
        rds.set = AsyncMock(return_value=False)

        count = await recover_stale_documents(pool, rds)

        assert count == 0
        rds.rpush.assert_not_called()

    @pytest.mark.asyncio
    async def test_no_stale_docs_returns_zero(self):
        pool = _mock_pool()
        rds = _mock_redis()
        pool.fetch = AsyncMock(return_value=[])

        count = await recover_stale_documents(pool, rds)

        assert count == 0


# ---------------------------------------------------------------------------
# retry_failed_extractions (mocked DB/Redis)
# ---------------------------------------------------------------------------


class TestRetryFailedExtractions:
    """Re-enqueueing of documents returned by the failed-extraction query."""

    @pytest.mark.asyncio
    async def test_retries_failed_docs(self):
        pool = _mock_pool()
        rds = _mock_redis()
        doc_id = uuid.uuid4()
        pool.fetch = AsyncMock(return_value=[
            {"id": doc_id, "document_type": "filing", "ticker": "MSFT"},
        ])
        rds.set = AsyncMock(return_value=True)

        count = await retry_failed_extractions(pool, rds)

        assert count == 1
        # Should push to extraction queue
        rds.rpush.assert_called_once()
        # Should delete failed intelligence rows and reset status
        assert pool.execute.call_count == 2

    @pytest.mark.asyncio
    async def test_routes_macro_event_to_classification(self):
        pool = _mock_pool()
        rds = _mock_redis()
        doc_id = uuid.uuid4()
        pool.fetch = AsyncMock(return_value=[
            {"id": doc_id, "document_type": "macro_event", "ticker": ""},
        ])
        rds.set = AsyncMock(return_value=True)

        count = await retry_failed_extractions(pool, rds)

        assert count == 1
        assert queue_key(QUEUE_MACRO_CLASSIFICATION) in rds.rpush.call_args.args[0]

    @pytest.mark.asyncio
    async def test_no_failed_docs_returns_zero(self):
        pool = _mock_pool()
        rds = _mock_redis()
        pool.fetch = AsyncMock(return_value=[])

        count = await retry_failed_extractions(pool, rds)

        assert count == 0


# ---------------------------------------------------------------------------
# Error handling: DB/Redis connection failures
# ---------------------------------------------------------------------------


class TestErrorHandling:
    """Connection failures propagate to the caller instead of being swallowed."""

    @pytest.mark.asyncio
    async def test_schedule_cycle_handles_db_failure(self):
        """A DB failure in fetch_active_sources propagates out of schedule_cycle.

        Surviving the failure, if required, is left to the caller; this test
        only checks that the error is not silently swallowed.
        """
        pool = _mock_pool()
        rds = _mock_redis()
        pool.fetch = AsyncMock(side_effect=Exception("connection refused"))

        with pytest.raises(Exception, match="connection refused"):
            await schedule_cycle(pool, rds)

    @pytest.mark.asyncio
    async def test_recover_stale_handles_db_failure(self):
        """DB failure in recover_stale_documents should propagate."""
        pool = _mock_pool()
        rds = _mock_redis()
        pool.fetch = AsyncMock(side_effect=ConnectionError("pg pool exhausted"))

        with pytest.raises(ConnectionError, match="pg pool exhausted"):
            await recover_stale_documents(pool, rds)

    @pytest.mark.asyncio
    async def test_check_rate_limit_handles_redis_failure(self):
        """Redis failure in check_rate_limit should propagate."""
        rds = _mock_redis()
        rds.incr = AsyncMock(side_effect=ConnectionError("redis unavailable"))

        with pytest.raises(ConnectionError, match="redis unavailable"):
            await check_rate_limit(rds, "news_api", _now())

    @pytest.mark.asyncio
    async def test_retry_failed_handles_redis_failure(self):
        """Redis failure during enqueue should propagate."""
        pool = _mock_pool()
        rds = _mock_redis()
        doc_id = uuid.uuid4()
        pool.fetch = AsyncMock(return_value=[
            {"id": doc_id, "document_type": "news", "ticker": "AAPL"},
        ])
        rds.set = AsyncMock(return_value=True)
        rds.rpush = AsyncMock(side_effect=ConnectionError("redis down"))

        with pytest.raises(ConnectionError, match="redis down"):
            await retry_failed_extractions(pool, rds)


# ===========================================================================
# Edge-case unit tests — boundary conditions and rate limiting
# Requirements: 1.3, 1.4
# ===========================================================================


# ---------------------------------------------------------------------------
# get_cadence_for_source — boundary conditions
# ---------------------------------------------------------------------------


class TestGetCadenceEdgeCases:
    """Boundary handling of polling_interval_seconds overrides."""

    def test_zero_polling_interval_clamped_to_minimum(self):
        """Config with polling_interval_seconds=0 should be clamped to 10."""
        assert get_cadence_for_source("news_api", {"polling_interval_seconds": 0}) == 10

    def test_negative_polling_interval_clamped_to_minimum(self):
        """Negative interval should be clamped to 10."""
        assert get_cadence_for_source("market_api", {"polling_interval_seconds": -50}) == 10

    def test_exactly_minimum_polling_interval(self):
        """Interval of exactly 10 should be accepted as-is."""
        assert get_cadence_for_source("market_api", {"polling_interval_seconds": 10}) == 10

    def test_none_config_value_uses_default(self):
        """Config with polling_interval_seconds=None should fall back to default."""
        assert (
            get_cadence_for_source("news_api", {"polling_interval_seconds": None})
            == DEFAULT_CADENCES["news_api"]
        )

    def test_empty_config_dict_uses_default(self):
        """Empty config dict (no polling_interval_seconds key) uses default."""
        assert get_cadence_for_source("filings_api", {}) == DEFAULT_CADENCES["filings_api"]

    def test_float_polling_interval_truncated(self):
        """Float value should be truncated to int via int()."""
        assert get_cadence_for_source("news_api", {"polling_interval_seconds": 120.9}) == 120


# ---------------------------------------------------------------------------
# compute_backoff — boundary conditions
# ---------------------------------------------------------------------------


class TestComputeBackoffEdgeCases:
    """Backoff values at and around the MAX_BACKOFF cap."""

    def test_negative_retry_count(self):
        """A negative retry count yields a fractional multiplier, truncated to int."""
        result = compute_backoff(-1)
        # 2^min(-1, 8) = 2^-1 = 0.5, so 60 * 0.5 = 30
        assert result == 30

    def test_exactly_at_cap_boundary(self):
        """retry_count=6 is the first value whose backoff hits MAX_BACKOFF."""
        # DEFAULT_BACKOFF_BASE=60, MAX_BACKOFF=3600
        # 60 * 2^6 = 3840 > 3600, so retry_count=6 should hit the cap
        assert compute_backoff(6) == MAX_BACKOFF

    def test_just_below_cap(self):
        """retry_count=5: 60 * 2^5 = 1920, below MAX_BACKOFF."""
        assert compute_backoff(5) == DEFAULT_BACKOFF_BASE * 32  # 1920

    def test_very_large_retry_count(self):
        """Very large retry count should still be capped at MAX_BACKOFF."""
        assert compute_backoff(1000) == MAX_BACKOFF


# ---------------------------------------------------------------------------
# is_source_due — boundary conditions
# ---------------------------------------------------------------------------


class TestIsSourceDueEdgeCases:
    """Off-by-one boundaries for retries, cadence and next_retry_at."""

    def test_exactly_at_max_retry_count_not_due(self):
        """Exactly at MAX_RETRY_COUNT should NOT be due."""
        last = _now() - timedelta(seconds=9999)
        assert not is_source_due(
            "market_api", None, last, "failed", MAX_RETRY_COUNT, None, _now()
        )

    def test_one_below_max_retry_count_is_due(self):
        """One below MAX_RETRY_COUNT with no next_retry_at should be due."""
        last = _now() - timedelta(seconds=9999)
        assert is_source_due(
            "market_api", None, last, "failed", MAX_RETRY_COUNT - 1, None, _now()
        )

    def test_completed_exactly_at_cadence_boundary(self):
        """Completed exactly at cadence seconds ago should be due (elapsed >= cadence)."""
        cadence = DEFAULT_CADENCES["market_api"]  # 300
        last = _now() - timedelta(seconds=cadence)
        assert is_source_due("market_api", None, last, "completed", 0, None, _now())

    def test_completed_one_second_before_cadence_not_due(self):
        """Completed one second less than cadence ago should NOT be due."""
        cadence = DEFAULT_CADENCES["market_api"]  # 300
        last = _now() - timedelta(seconds=cadence - 1)
        assert not is_source_due("market_api", None, last, "completed", 0, None, _now())

    def test_status_completed_but_completed_at_none_is_due(self):
        """A status with no completed_at timestamp is treated as due."""
        assert is_source_due("market_api", None, None, "completed", 0, None, _now())

    def test_next_retry_at_exactly_now_is_due(self):
        """next_retry_at exactly equal to now → now < nra is False, so should be due."""
        last = _now() - timedelta(seconds=120)
        assert is_source_due("market_api", None, last, "failed", 1, _now(), _now())


# ---------------------------------------------------------------------------
# build_job_payload — edge cases
# ---------------------------------------------------------------------------


class TestBuildJobPayloadEdgeCases:
    """Config decoding, empty fields and falsy credibility scores."""

    def test_config_as_json_string(self):
        """Config stored as a JSON string should be coerced to a dict."""
        src = _make_source(config='{"polling_interval_seconds": 120}')

        job = build_job_payload(src, [], _now())

        assert job["config"] == {"polling_interval_seconds": 120}

    def test_config_as_invalid_json_string(self):
        """Invalid JSON string config should fall back to empty dict."""
        src = _make_source(config="not-json")

        job = build_job_payload(src, [], _now())

        assert job["config"] == {}

    def test_empty_ticker(self):
        """Source with empty ticker should produce empty string in payload."""
        src = _make_source(ticker="")

        job = build_job_payload(src, [], _now())

        assert job["ticker"] == ""

    def test_zero_credibility_score(self):
        """Pin current behavior: a credibility score of 0.0 is replaced by 0.5.

        0.0 is falsy, and the implementation uses a truthiness check
        (``if source["credibility_score"]``) rather than an ``is None`` check,
        so an explicit zero is not preserved.

        NOTE(review): this looks like a latent bug in build_job_payload; if it
        is fixed, the expected value here should become 0.0.
        """
        src = _make_source(credibility_score=0.0)

        job = build_job_payload(src, [], _now())

        assert job["credibility_score"] == 0.5

    def test_many_aliases(self):
        """Multiple aliases should all be included in the payload."""
        src = _make_source()
        aliases = ["Apple Inc.", "Apple Computer", "AAPL", "Apple"]

        job = build_job_payload(src, aliases, _now())

        assert job["aliases"] == aliases


# ---------------------------------------------------------------------------
# check_rate_limit — boundary and edge cases
# ---------------------------------------------------------------------------


class TestCheckRateLimitEdgeCases:
    """Exact-limit boundaries for per-type and Polygon-wide counters."""

    @pytest.mark.asyncio
    async def test_exactly_at_per_type_limit_allowed(self):
        """Count exactly equal to the limit should be allowed (only > limit blocks)."""
        rds = _mock_redis()
        limit = DEFAULT_RATE_LIMITS["news_api"]
        rds.incr = AsyncMock(return_value=limit)

        result = await check_rate_limit(rds, "news_api", _now())

        assert result is True

    @pytest.mark.asyncio
    async def test_one_over_per_type_limit_blocked(self):
        """Count one over the limit should be blocked."""
        rds = _mock_redis()
        limit = DEFAULT_RATE_LIMITS["news_api"]
        rds.incr = AsyncMock(return_value=limit + 1)

        result = await check_rate_limit(rds, "news_api", _now())

        assert result is False

    @pytest.mark.asyncio
    async def test_polygon_exactly_at_global_limit_allowed(self):
        """Polygon global count exactly at limit should be allowed."""
        rds = _mock_redis()
        # per-type OK, then exactly at global limit
        rds.incr = AsyncMock(side_effect=[1, POLYGON_GLOBAL_RATE_LIMIT])

        result = await check_rate_limit(rds, "market_api", _now())

        assert result is True
        rds.decr.assert_not_called()

    @pytest.mark.asyncio
    async def test_polygon_one_over_global_limit_blocked_and_decrements(self):
        """Polygon global count one over limit should block and decrement per-type counter."""
        rds = _mock_redis()
        # per-type OK, then one over global limit
        rds.incr = AsyncMock(side_effect=[1, POLYGON_GLOBAL_RATE_LIMIT + 1])

        result = await check_rate_limit(rds, "market_api", _now())

        assert result is False
        rds.decr.assert_called_once()

    @pytest.mark.asyncio
    async def test_unknown_source_type_uses_default_limit(self):
        """Unknown source type should use the fallback limit of 30."""
        rds = _mock_redis()
        rds.incr = AsyncMock(return_value=30)  # exactly at default limit

        result = await check_rate_limit(rds, "unknown_type", _now())

        assert result is True

    @pytest.mark.asyncio
    async def test_unknown_source_type_over_default_blocked(self):
        """Unknown source type at 31 should be blocked (default limit is 30)."""
        rds = _mock_redis()
        rds.incr = AsyncMock(return_value=31)

        result = await check_rate_limit(rds, "unknown_type", _now())

        assert result is False

    @pytest.mark.asyncio
    async def test_expire_not_called_on_subsequent_increments(self):
        """Expire should only be called when incr returns 1 (first increment)."""
        rds = _mock_redis()
        rds.incr = AsyncMock(return_value=5)  # not the first increment

        await check_rate_limit(rds, "filings_api", _now())

        rds.expire.assert_not_called()

    @pytest.mark.asyncio
    async def test_news_api_polygon_global_check(self):
        """news_api is a Polygon type — should also check global limit."""
        rds = _mock_redis()
        rds.incr = AsyncMock(return_value=1)  # both counters OK

        result = await check_rate_limit(rds, "news_api", _now())

        assert result is True
        # Should have called incr twice: per-type + global
        assert rds.incr.call_count == 2


# ---------------------------------------------------------------------------
# schedule_cycle — empty source list edge case
# ---------------------------------------------------------------------------


class TestScheduleCycleEdgeCases:
    """schedule_cycle with no sources and with a due/not-due mix."""

    @pytest.mark.asyncio
    async def test_empty_source_lists_returns_zero(self):
        """When all source queries return empty lists, enqueued count should be 0."""
        pool = _mock_pool()
        rds = _mock_redis()
        pool.fetch = AsyncMock(side_effect=[
            [],  # fetch_active_sources
            [],  # fetch_macro_sources
            [],  # fetch_global_market_sources
        ])

        enqueued = await schedule_cycle(pool, rds)

        assert enqueued == 0
        rds.rpush.assert_not_called()

    @pytest.mark.asyncio
    async def test_multiple_sources_mixed_due_and_not_due(self):
        """Mix of due and not-due sources: only due ones get enqueued."""
        pool = _mock_pool()
        rds = _mock_redis()
        src1 = _make_source(source_id="src-1", ticker="AAPL")
        src2 = _make_source(source_id="src-2", ticker="MSFT")
        pool.fetch = AsyncMock(side_effect=[
            [src1, src2],  # fetch_active_sources
            [],            # fetch_macro_sources
            [],            # fetch_global_market_sources
            [],            # aliases for src-1
        ])
        # src-1 never run (due), src-2 recently completed (not due)
        recent_run = _make_last_run(
            status="completed",
            completed_at=datetime.now(tz=timezone.utc) - timedelta(seconds=10),
        )
        fetchrow_calls = 0

        async def _fetchrow(*args, **kwargs):
            # First lookup (src-1) finds no prior run; every later lookup
            # returns the recent completed run.
            nonlocal fetchrow_calls
            fetchrow_calls += 1
            return None if fetchrow_calls == 1 else recent_run

        pool.fetchrow = AsyncMock(side_effect=_fetchrow)
        rds.incr = AsyncMock(return_value=1)

        enqueued = await schedule_cycle(pool, rds)

        assert enqueued == 1
        rds.rpush.assert_called_once()