"""Scheduler - triggers ingestion cycles for tracked symbols and sources.

Polls the symbol registry for active companies and their configured sources,
respects per-source polling cadences and backoff windows, coordinates rate
limits across source types, and enqueues ingestion jobs for downstream workers.

Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""

import asyncio
import json
import logging
import secrets
from datetime import datetime, timezone
from typing import Any, Optional

import asyncpg
import redis.asyncio as aioredis

from services.shared.config import load_config
from services.shared.db import get_pg_pool, get_redis
from services.shared.logging import setup_logging
from services.shared.redis_keys import (
    QUEUE_INGESTION,
    lock_key,
    queue_key,
    rate_limit_key,
)

logger = logging.getLogger("scheduler")


def _ensure_dict(val: Any) -> Optional[dict]:
    """Coerce a JSONB value (dict or JSON string) to a Python dict.

    Returns ``None`` for ``None``, for strings that are not valid JSON, for
    JSON that does not decode to an object, and for any other input type.
    """
    if val is None:
        return None
    if isinstance(val, dict):
        return val
    if isinstance(val, str):
        try:
            parsed = json.loads(val)
        except (json.JSONDecodeError, TypeError):
            return None
        return parsed if isinstance(parsed, dict) else None
    return None


# Default polling cadences by source class (seconds).
# Individual sources can override via config.polling_interval_seconds.
DEFAULT_CADENCES: dict[str, int] = {
    "market_api": 900,
    "news_api": 300,
    "filings_api": 3600,
    "web_scrape": 1800,
    "broker": 30,
    "macro_news": 600,
}

# Default rate limits per source type (requests per minute)
DEFAULT_RATE_LIMITS: dict[str, int] = {
    "market_api": 5,
    "news_api": 20,
    "filings_api": 10,
    "web_scrape": 10,
    "broker": 60,
    "macro_news": 10,
}

# How long to wait before retrying a failed source (seconds)
DEFAULT_BACKOFF_BASE: int = 60
MAX_BACKOFF: int = 3600
# Failed sources whose retry_count reaches this value are no longer scheduled
# by is_source_due (they need a manual reset).
MAX_RETRY_COUNT: int = 10

# Main loop interval (seconds)
SCHEDULER_TICK: int = 15

# Compare-and-delete: remove the lock key only if it still holds our token, so
# a cycle that outlived its TTL cannot delete a lock another instance now owns.
_RELEASE_LOCK_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
end
return 0
"""


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Same value as the deprecated ``datetime.utcnow()``. It is kept naive so
    comparisons and the ``scheduled_at`` payload format are unchanged.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _to_naive_utc(value: datetime) -> datetime:
    """Return *value* as a naive datetime expressed in UTC.

    Aware datetimes are converted to UTC before their tzinfo is dropped, so
    they compare correctly with the naive-UTC ``now`` used in this module.
    Naive datetimes are assumed to already be UTC and are returned unchanged.
    """
    if value.tzinfo is None:
        return value
    return value.astimezone(timezone.utc).replace(tzinfo=None)


def get_cadence_for_source(source_type: str, config: Optional[dict[str, Any]]) -> int:
    """Return the polling interval for a source, in seconds.

    Uses the source's ``config["polling_interval_seconds"]`` if it is set and
    integer-convertible, clamped to at least 10 seconds. Otherwise it falls
    back to the default cadence for the source type (600 for unknown types).
    """
    if config and "polling_interval_seconds" in config:
        try:
            return max(10, int(config["polling_interval_seconds"]))
        except (ValueError, TypeError):
            pass  # unusable override: fall through to the type default
    return DEFAULT_CADENCES.get(source_type, 600)


def compute_backoff(retry_count: int) -> int:
    """Exponential backoff with a cap. Returns seconds to wait.

    The delay is ``DEFAULT_BACKOFF_BASE * 2**retry_count``, with the exponent
    limited to [0, 8]. The result is capped at ``MAX_BACKOFF``.
    """
    # Clamp the exponent at 0 so a negative count cannot yield a float delay.
    exponent = min(max(retry_count, 0), 8)
    return min(DEFAULT_BACKOFF_BASE * (2 ** exponent), MAX_BACKOFF)


def is_source_due(
    source_type: str,
    source_config: Optional[dict[str, Any]],
    last_completed_at: Optional[datetime],
    last_status: Optional[str],
    retry_count: int,
    next_retry_at: Optional[datetime],
    now: datetime,
) -> bool:
    """Determine whether a source is due for its next polling cycle.

    Checks:
    - If the source has never run, it is due.
    - If the last run failed and retry_count has reached MAX_RETRY_COUNT,
      skip (needs manual reset).
    - If the last run failed and next_retry_at is in the future, skip.
      Otherwise a failed source is retried.
    - If the last run is still running, skip (no double-scheduling).
    - Otherwise, check whether the cadence has elapsed since the last run.

    Naive datetimes are treated as UTC. Aware datetimes, including *now*, are
    converted to UTC before comparison.
    """
    now = _to_naive_utc(now)

    # Never run before — always due
    if last_completed_at is None and last_status is None:
        return True

    # If last run failed, respect backoff
    if last_status == "failed":
        if retry_count >= MAX_RETRY_COUNT:
            return False
        if next_retry_at is not None and now < _to_naive_utc(next_retry_at):
            return False
        # Backoff elapsed or no next_retry_at set — allow retry
        return True

    # If currently running, don't double-schedule.
    # NOTE(review): this does not look at the run's age, so a run that never
    # leaves "running" (e.g. a crashed worker) blocks the source indefinitely
    # unless something outside this module marks it finished — confirm.
    if last_status == "running":
        return False

    # Normal cadence check
    if last_completed_at is None:
        return True
    cadence = get_cadence_for_source(source_type, source_config)
    elapsed = (now - _to_naive_utc(last_completed_at)).total_seconds()
    return elapsed >= cadence


def build_job_payload(
    source: Any,
    aliases: list[str],
    now: datetime,
) -> dict[str, Any]:
    """Build the JSON-serializable ingestion job payload for a source row.

    Missing ticker/legal_name become "". A missing or non-object config
    becomes {}. A NULL credibility_score defaults to 0.5, while an explicit 0
    is preserved.
    """
    company_id = source.get("company_id")
    credibility = source["credibility_score"]
    return {
        "source_id": str(source["source_id"]),
        "company_id": str(company_id) if company_id else None,
        "ticker": source.get("ticker") or "",
        "legal_name": source.get("legal_name") or "",
        "aliases": aliases,
        "source_type": source["source_type"],
        "source_name": source["source_name"],
        "config": _ensure_dict(source["config"]) or {},
        "credibility_score": float(credibility) if credibility is not None else 0.5,
        "scheduled_at": now.isoformat(),
    }


async def acquire_lock(
    rds: aioredis.Redis, name: str, ttl: int = 60, token: str = "1"
) -> bool:
    """Acquire a distributed lock (SET NX with a TTL in seconds).

    Pass a unique *token* and give the same token to ``release_lock``, so the
    release cannot remove a lock held by someone else. Returns True if
    acquired.
    """
    return bool(await rds.set(lock_key(name), token, nx=True, ex=ttl))


async def release_lock(
    rds: aioredis.Redis, name: str, token: Optional[str] = None
) -> None:
    """Release a distributed lock.

    With a *token*, the key is deleted only if it still holds that token
    (atomic compare-and-delete via Lua). Without one, the key is deleted
    unconditionally, which was the original behaviour.
    """
    if token is None:
        await rds.delete(lock_key(name))
        return
    await rds.eval(_RELEASE_LOCK_SCRIPT, 1, lock_key(name), token)


async def check_rate_limit(
    rds: aioredis.Redis,
    source_type: str,
    now: datetime,
    max_per_minute: Optional[int] = None,
) -> bool:
    """Check whether the source type is within its rate limit window.

    Counts calls per source type in a fixed per-minute window keyed by
    ``now`` (YYYYmmddHHMM). The counter is incremented on every call,
    including calls that end up rate-limited.

    Returns True if the request is allowed, False if rate-limited.
    """
    if max_per_minute is not None:
        limit = max_per_minute
    else:
        limit = DEFAULT_RATE_LIMITS.get(source_type, 30)
    window = now.strftime("%Y%m%d%H%M")
    key = rate_limit_key(source_type, window)
    count = await rds.incr(key)
    if count == 1:
        # First hit in this window: let the counter expire after the window.
        await rds.expire(key, 120)
    return count <= limit


async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Fetch all active company-specific sources joined with their active companies."""
    return await pool.fetch(
        """SELECT s.id AS source_id, s.company_id, s.source_type,
                  s.source_name, s.config, s.credibility_score,
                  c.ticker, c.legal_name
           FROM sources s
           JOIN companies c ON s.company_id = c.id
           WHERE s.active = TRUE AND c.active = TRUE
             AND s.source_type != 'macro_news'
           ORDER BY s.source_type, c.ticker"""
    )


async def fetch_macro_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Fetch all active macro news sources.

    Macro sources are not company-specific: they have
    source_type='macro_news' and may have company_id NULL. They are
    scheduled independently from company-specific sources.

    Requirements: 1.1
    """
    return await pool.fetch(
        """SELECT s.id AS source_id, s.company_id, s.source_type,
                  s.source_name, s.config, s.credibility_score
           FROM sources s
           WHERE s.active = TRUE
             AND s.source_type = 'macro_news'
           ORDER BY s.source_name"""
    )


async def fetch_global_market_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Fetch active market sources that are not company-specific.

    These are sources such as the grouped daily endpoint, which fetch data for
    all tickers in a single API call. They have company_id IS NULL and
    source_type = 'market_api'.
    """
    return await pool.fetch(
        """SELECT s.id AS source_id, s.company_id, s.source_type,
                  s.source_name, s.config, s.credibility_score
           FROM sources s
           WHERE s.active = TRUE
             AND s.source_type = 'market_api'
             AND s.company_id IS NULL
           ORDER BY s.source_name"""
    )


async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
    """Fetch all aliases for a company."""
    rows = await pool.fetch(
        "SELECT alias FROM company_aliases WHERE company_id = $1",
        company_id,
    )
    return [r["alias"] for r in rows]


async def fetch_last_run(
    pool: asyncpg.Pool, source_id: str
) -> Optional[asyncpg.Record]:
    """Fetch the most recent ingestion run (by started_at) for a source."""
    return await pool.fetchrow(
        """SELECT status, started_at, completed_at, retry_count, next_retry_at
           FROM ingestion_runs
           WHERE source_id = $1
           ORDER BY started_at DESC
           LIMIT 1""",
        source_id,
    )


async def _source_due(
    pool: asyncpg.Pool,
    source_id: Any,
    source_type: str,
    source_config: Optional[dict[str, Any]],
    now: datetime,
) -> bool:
    """Look up a source's latest run and decide via is_source_due whether it is due."""
    last_run = await fetch_last_run(pool, source_id)
    last_completed_at = None
    last_status = None
    retry_count = 0
    next_retry_at = None
    if last_run:
        last_status = last_run["status"]
        # A run without completed_at is timed from when it started.
        last_completed_at = last_run["completed_at"] or last_run["started_at"]
        retry_count = last_run["retry_count"] or 0
        next_retry_at = last_run["next_retry_at"]
    return is_source_due(
        source_type=source_type,
        source_config=source_config,
        last_completed_at=last_completed_at,
        last_status=last_status,
        retry_count=retry_count,
        next_retry_at=next_retry_at,
        now=now,
    )


async def _enqueue_job(rds: aioredis.Redis, job: dict[str, Any]) -> None:
    """Append a JSON-serialized job to the tail of the ingestion queue."""
    await rds.rpush(queue_key(QUEUE_INGESTION), json.dumps(job))  # type: ignore[misc]


async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """One scheduling pass: find due sources and enqueue ingestion jobs.

    Processes company-specific sources, then macro news sources, then global
    (company_id IS NULL) market sources. Each due source passes a per-type
    rate-limit check before its job is pushed.

    NOTE(review): due-ness is decided only from ingestion_runs, and nothing
    here records that a job was enqueued. If workers do not record a run
    before the next tick, a source may be enqueued again — confirm.

    Returns the number of jobs enqueued.
    """
    now = _utcnow()
    enqueued = 0
    skipped_rate_limit = 0
    skipped_not_due = 0
    # Several sources can share a company; fetch each company's aliases once.
    alias_cache: dict[Any, list[str]] = {}

    # --- Company-specific sources ---
    sources = await fetch_active_sources(pool)
    for src in sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])

        if not await _source_due(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for %s, skipping %s/%s",
                source_type,
                src["ticker"],
                src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        # Company aliases are used downstream for entity matching.
        company_id = src["company_id"]
        if company_id not in alias_cache:
            alias_cache[company_id] = await fetch_aliases_for_company(pool, company_id)

        job = build_job_payload(src, alias_cache[company_id], now)
        await _enqueue_job(rds, job)
        enqueued += 1
        logger.debug(
            "Enqueued %s job for %s (%s)",
            source_type,
            src["ticker"],
            src["source_name"],
        )

    # --- Macro news sources (Requirement 1.1) ---
    macro_sources = await fetch_macro_sources(pool)
    for src in macro_sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])

        if not await _source_due(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for macro_news, skipping %s",
                src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        job = build_job_payload(src, [], now)
        await _enqueue_job(rds, job)
        enqueued += 1
        logger.debug(
            "Enqueued macro_news job for %s",
            src["source_name"],
        )

    # --- Global market sources (grouped daily, etc.) ---
    global_market_sources = await fetch_global_market_sources(pool)
    for src in global_market_sources:
        source_type = src["source_type"]
        # Copy so the cadence override below never mutates the record's value
        # (which build_job_payload reads again). The `or {}` prevents a NULL or
        # non-object config from crashing the cycle on `.get`.
        source_config = dict(_ensure_dict(src["config"]) or {})

        # Grouped daily returns a whole day at once: default to one poll/day.
        if source_config.get("endpoint", "") == "grouped_daily":
            source_config.setdefault("polling_interval_seconds", 86400)

        if not await _source_due(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for %s, skipping global source %s",
                source_type,
                src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        # Global sources are not tied to a company; mark them with "_MARKET".
        job = build_job_payload(src, [], now)
        job["ticker"] = "_MARKET"
        await _enqueue_job(rds, job)
        enqueued += 1
        logger.info("Enqueued global market data job for %s", src["source_name"])

    logger.info(
        "Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
        enqueued,
        skipped_not_due,
        skipped_rate_limit,
        len(sources) + len(macro_sources) + len(global_market_sources),
    )
    return enqueued


async def main() -> None:
    """Run the scheduler loop: every tick, run one cycle under a Redis lock."""
    config = load_config()
    setup_logging("scheduler", level=config.log_level, json_output=config.json_logs)

    pool = await get_pg_pool(config)
    rds = get_redis(config)
    # Unique per process so release_lock only deletes a lock this instance holds.
    lock_token = secrets.token_hex(16)

    logger.info("Scheduler started (tick=%ds)", SCHEDULER_TICK)

    try:
        while True:
            try:
                if await acquire_lock(rds, "scheduler_cycle", ttl=30, token=lock_token):
                    try:
                        await schedule_cycle(pool, rds)
                    finally:
                        await release_lock(rds, "scheduler_cycle", token=lock_token)
            except Exception:
                # Top-level boundary: log and keep the loop alive.
                logger.exception("Scheduler cycle error")
            await asyncio.sleep(SCHEDULER_TICK)
    finally:
        await pool.close()
        await rds.close()


if __name__ == "__main__":
    asyncio.run(main())