"""Scheduler - triggers ingestion cycles for tracked symbols and sources. Polls the symbol registry for active companies and their configured sources, respects per-source polling cadences and backoff windows, coordinates rate limits across source types, and enqueues ingestion jobs for downstream workers. Requirements: 2.1, 2.2, 2.3, 2.4, 2.5 """ import asyncio import json import logging from datetime import datetime from typing import Any, Optional import asyncpg import redis.asyncio as aioredis from services.shared.config import load_config from services.shared.db import get_pg_pool, get_redis from services.shared.logging import setup_logging from services.shared.redis_keys import ( QUEUE_EXTRACTION, QUEUE_INGESTION, QUEUE_MACRO_CLASSIFICATION, lock_key, queue_key, rate_limit_key, ) logger = logging.getLogger("scheduler") def _ensure_dict(val: Any) -> Optional[dict]: """Coerce a JSONB value (dict or JSON string) to a Python dict.""" if val is None: return None if isinstance(val, dict): return val if isinstance(val, str): try: parsed = json.loads(val) return parsed if isinstance(parsed, dict) else None except (json.JSONDecodeError, TypeError): return None return None # Default polling cadences by source class (seconds). # Individual sources can override via config.polling_interval_seconds. DEFAULT_CADENCES: dict[str, int] = { "market_api": 300, "news_api": 300, "filings_api": 3600, "web_scrape": 1800, "broker": 30, "macro_news": 600, } # Default rate limits per source type (requests per minute) DEFAULT_RATE_LIMITS: dict[str, int] = { "market_api": 20, "news_api": 20, "filings_api": 10, "web_scrape": 10, "broker": 60, "macro_news": 10, } # Global rate limit across all Polygon-backed source types (requests per minute). # market_api + news_api share a single Polygon API key, so we cap the combined # throughput to stay safely under the plan limit. 
POLYGON_SOURCE_TYPES: set[str] = {"market_api", "news_api"}
POLYGON_GLOBAL_RATE_LIMIT: int = 45

# How long to wait before retrying a failed source (seconds)
DEFAULT_BACKOFF_BASE: int = 60
MAX_BACKOFF: int = 3600
MAX_RETRY_COUNT: int = 10

# Main loop interval (seconds)
SCHEDULER_TICK: int = 15


def get_cadence_for_source(source_type: str, config: Optional[dict[str, Any]]) -> int:
    """Return the polling interval (seconds) for a source.

    Uses the source's config.polling_interval_seconds if set (floored at
    10 seconds), otherwise falls back to the default cadence for the source
    type, or 600 seconds for an unknown type.
    """
    if config and "polling_interval_seconds" in config:
        try:
            return max(10, int(config["polling_interval_seconds"]))
        except (ValueError, TypeError):
            # Unparseable override — fall through to the type default.
            pass
    return DEFAULT_CADENCES.get(source_type, 600)


def compute_backoff(retry_count: int) -> int:
    """Exponential backoff with a cap. Returns seconds to wait.

    The exponent is clamped to [0, 8] so a negative retry_count cannot
    produce a fractional (float) delay below the base.
    """
    exponent = min(max(retry_count, 0), 8)
    return min(DEFAULT_BACKOFF_BASE * (2 ** exponent), MAX_BACKOFF)


def _as_naive_utc(moment: datetime) -> datetime:
    """Return *moment* as a naive datetime expressed in UTC.

    Naive inputs are returned unchanged (they are taken to already be UTC).
    Aware inputs are shifted by their UTC offset before the tzinfo is
    dropped, so a non-UTC aware timestamp is not misread as UTC wall time.
    """
    offset = moment.utcoffset()
    if offset is None:
        return moment
    return (moment - offset).replace(tzinfo=None)


def is_source_due(
    source_type: str,
    source_config: Optional[dict[str, Any]],
    last_completed_at: Optional[datetime],
    last_status: Optional[str],
    retry_count: int,
    next_retry_at: Optional[datetime],
    now: datetime,
) -> bool:
    """Determine whether a source is due for its next polling cycle.

    Checks:
    - If the source has never run, it is due.
    - If the last run failed and we have a next_retry_at in the future, skip.
    - If the last run failed and retry_count has reached the max, skip
      (needs manual reset).
    - If the last run is still marked running, skip.
    - Otherwise, check if enough time has elapsed since the last completed run.

    ``now`` is expected to be a naive UTC datetime; stored timestamps are
    normalised to naive UTC before comparison.
    """
    # Never run before — always due
    if last_completed_at is None and last_status is None:
        return True

    # If last run failed, respect backoff
    if last_status == "failed":
        if retry_count >= MAX_RETRY_COUNT:
            return False
        if next_retry_at and now < _as_naive_utc(next_retry_at):
            return False
        # Backoff elapsed or no next_retry_at set — allow retry
        return True

    # If currently running, don't double-schedule
    if last_status == "running":
        return False

    # Normal cadence check
    if last_completed_at is None:
        return True
    cadence = get_cadence_for_source(source_type, source_config)
    elapsed = (now - _as_naive_utc(last_completed_at)).total_seconds()
    return elapsed >= cadence


def build_job_payload(
    source: Any,
    aliases: list[str],
    now: datetime,
) -> dict[str, Any]:
    """Build the ingestion job payload for a source.

    ``source`` is a mapping-like row supporting ``[]`` and ``.get``.
    Columns that may be absent (ticker, legal_name) default to "".
    A missing credibility score defaults to 0.5; an explicit score of 0 is
    preserved rather than being replaced by the default.
    """
    credibility = source["credibility_score"]
    return {
        "source_id": str(source["source_id"]),
        "company_id": str(source["company_id"]) if source.get("company_id") else None,
        "ticker": source.get("ticker") or "",
        "legal_name": source.get("legal_name") or "",
        "aliases": aliases,
        "source_type": source["source_type"],
        "source_name": source["source_name"],
        "config": _ensure_dict(source["config"]) or {},
        "credibility_score": float(credibility) if credibility is not None else 0.5,
        "scheduled_at": now.isoformat(),
    }


async def acquire_lock(rds: aioredis.Redis, name: str, ttl: int = 60) -> bool:
    """Acquire a distributed lock. Returns True if acquired, False otherwise."""
    # SET NX returns None when the key already exists; coerce to a real bool.
    return bool(await rds.set(lock_key(name), "1", nx=True, ex=ttl))


async def release_lock(rds: aioredis.Redis, name: str) -> None:
    """Release a distributed lock by deleting its key."""
    await rds.delete(lock_key(name))


async def check_rate_limit(
    rds: aioredis.Redis,
    source_type: str,
    now: datetime,
    max_per_minute: Optional[int] = None,
) -> bool:
    """Check whether the source type is within its rate limit window.

    Enforces two limits:
    1. Per-source-type limit (e.g. market_api: 20/min)
    2. Global Polygon limit across all Polygon-backed types (45/min combined)

    Counters are keyed by the minute of ``now`` and are incremented as part
    of the check, so a call that returns True has consumed one slot.

    Returns True if the request is allowed, False if rate-limited.
    """
    limit = max_per_minute or DEFAULT_RATE_LIMITS.get(source_type, 30)
    window = now.strftime("%Y%m%d%H%M")

    # Per-source-type check
    key = rate_limit_key(source_type, window)
    count = await rds.incr(key)
    if count == 1:
        # First hit in this window — let the counter expire on its own.
        await rds.expire(key, 120)
    if count > limit:
        return False

    # Global Polygon check for source types that share the Polygon API key
    if source_type in POLYGON_SOURCE_TYPES:
        global_key = rate_limit_key("_polygon_global", window)
        global_count = await rds.incr(global_key)
        if global_count == 1:
            await rds.expire(global_key, 120)
        if global_count > POLYGON_GLOBAL_RATE_LIMIT:
            # Roll back the per-type counter since we won't actually make the call
            await rds.decr(key)
            return False

    return True


async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Fetch all active company-specific sources joined with their active companies."""
    return await pool.fetch(
        """SELECT s.id AS source_id, s.company_id, s.source_type, s.source_name,
                  s.config, s.credibility_score, c.ticker, c.legal_name
           FROM sources s
           JOIN companies c ON s.company_id = c.id
           WHERE s.active = TRUE AND c.active = TRUE
             AND s.source_type != 'macro_news'
           ORDER BY s.source_type, c.ticker"""
    )


async def fetch_macro_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Fetch all active macro news sources.

    Macro sources are not company-specific — they have source_type='macro_news'
    and may have company_id NULL. They are scheduled independently from
    company-specific sources.

    Requirements: 1.1
    """
    return await pool.fetch(
        """SELECT s.id AS source_id, s.company_id, s.source_type, s.source_name,
                  s.config, s.credibility_score
           FROM sources s
           WHERE s.active = TRUE AND s.source_type = 'macro_news'
           ORDER BY s.source_name"""
    )


async def fetch_global_market_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Fetch active market sources that are not company-specific.

    These are sources like the grouped daily endpoint that fetch data for
    all tickers in a single API call. They have company_id IS NULL and
    source_type = 'market_api'.
    """
    return await pool.fetch(
        """SELECT s.id AS source_id, s.company_id, s.source_type, s.source_name,
                  s.config, s.credibility_score
           FROM sources s
           WHERE s.active = TRUE
             AND s.source_type = 'market_api'
             AND s.company_id IS NULL
           ORDER BY s.source_name"""
    )


async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
    """Fetch all aliases for a company."""
    rows = await pool.fetch(
        "SELECT alias FROM company_aliases WHERE company_id = $1",
        company_id,
    )
    return [r["alias"] for r in rows]


async def fetch_last_run(
    pool: asyncpg.Pool, source_id: str
) -> Optional[asyncpg.Record]:
    """Fetch the most recent ingestion run for a source (by started_at)."""
    return await pool.fetchrow(
        """SELECT status, started_at, completed_at, retry_count, next_retry_at
           FROM ingestion_runs
           WHERE source_id = $1
           ORDER BY started_at DESC
           LIMIT 1""",
        source_id,
    )


async def _due_for_run(
    pool: asyncpg.Pool,
    source_id: Any,
    source_type: str,
    source_config: Optional[dict[str, Any]],
    now: datetime,
) -> bool:
    """Look up the latest ingestion run for a source and decide if it is due."""
    last_completed_at = None
    last_status = None
    retry_count = 0
    next_retry_at = None

    last_run = await fetch_last_run(pool, source_id)
    if last_run:
        last_status = last_run["status"]
        # A run without completed_at falls back to when it started.
        last_completed_at = last_run["completed_at"] or last_run["started_at"]
        retry_count = last_run["retry_count"] or 0
        next_retry_at = last_run["next_retry_at"]

    return is_source_due(
        source_type=source_type,
        source_config=source_config,
        last_completed_at=last_completed_at,
        last_status=last_status,
        retry_count=retry_count,
        next_retry_at=next_retry_at,
        now=now,
    )


async def _enqueue_ingestion(rds: aioredis.Redis, job: dict[str, Any]) -> None:
    """Append a JSON-encoded job to the tail of the ingestion queue."""
    await rds.rpush(queue_key(QUEUE_INGESTION), json.dumps(job))  # type: ignore[misc]


async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """One scheduling pass: find due sources and enqueue ingestion jobs.

    Processes, in order: company-specific sources, macro news sources, and
    global (company-less) market sources. Each source is skipped if it is
    not due or if its rate-limit window is exhausted.

    Returns the number of jobs enqueued.
    """
    # Naive UTC timestamp; is_source_due compares against naive UTC values.
    now = datetime.utcnow()
    enqueued = 0
    skipped_rate_limit = 0
    skipped_not_due = 0

    # --- Schedule company-specific sources ---
    sources = await fetch_active_sources(pool)
    for src in sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])

        if not await _due_for_run(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        # Rate limit check
        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for %s, skipping %s/%s",
                source_type,
                src["ticker"],
                src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        # Fetch company aliases for downstream entity matching
        aliases = await fetch_aliases_for_company(pool, src["company_id"])
        await _enqueue_ingestion(rds, build_job_payload(src, aliases, now))
        enqueued += 1
        logger.debug(
            "Enqueued %s job for %s (%s)",
            source_type,
            src["ticker"],
            src["source_name"],
        )

    # --- Schedule macro news sources (Requirement 1.1) ---
    macro_sources = await fetch_macro_sources(pool)
    for src in macro_sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])

        if not await _due_for_run(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for macro_news, skipping %s",
                src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        await _enqueue_ingestion(rds, build_job_payload(src, [], now))
        enqueued += 1
        logger.debug(
            "Enqueued macro_news job for %s",
            src["source_name"],
        )

    # --- Schedule global market sources (grouped daily, etc.) ---
    global_market_sources = await fetch_global_market_sources(pool)
    for src in global_market_sources:
        source_type = src["source_type"]
        # A NULL or unparseable config must not abort the whole cycle.
        source_config = _ensure_dict(src["config"]) or {}

        # Use a longer cadence for grouped daily (once per day = 86400s)
        if source_config.get("endpoint", "") == "grouped_daily":
            source_config.setdefault("polling_interval_seconds", 86400)

        if not await _due_for_run(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            skipped_rate_limit += 1
            continue

        # Build job with ticker="_MARKET" for global sources
        job = build_job_payload(src, [], now)
        job["ticker"] = "_MARKET"
        await _enqueue_ingestion(rds, job)
        enqueued += 1
        logger.info("Enqueued grouped daily market data job")

    logger.info(
        "Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
        enqueued,
        skipped_not_due,
        skipped_rate_limit,
        len(sources) + len(macro_sources) + len(global_market_sources),
    )
    return enqueued


async def main() -> None:
    """Run the scheduler loop until cancelled.

    Every SCHEDULER_TICK seconds, tries to take the "scheduler_cycle" lock
    and, if acquired, runs one scheduling pass plus the periodic recovery
    and cleanup sweeps. Errors inside a tick are logged and the loop
    continues. The Postgres pool and Redis client are closed on exit.
    """
    config = load_config()
    setup_logging("scheduler", level=config.log_level, json_output=config.json_logs)

    pool = await get_pg_pool(config)
    rds = get_redis(config)

    logger.info("Scheduler started (tick=%ds)", SCHEDULER_TICK)

    recovery_counter = 0
    cleanup_counter = 0

    try:
        while True:
            try:
                if await acquire_lock(rds, "scheduler_cycle", ttl=30):
                    try:
                        await schedule_cycle(pool, rds)

                        # Run stale document recovery every ~20 cycles (~5 minutes)
                        recovery_counter += 1
                        if recovery_counter >= 20:
                            recovery_counter = 0
                            await recover_stale_documents(pool, rds)

                        # Run signal cleanup periodically (~25 minutes)
                        cleanup_counter += 1
                        if cleanup_counter >= CLEANUP_CYCLE_INTERVAL:
                            cleanup_counter = 0
                            await cleanup_all_tables(pool)
                    finally:
                        await release_lock(rds, "scheduler_cycle")
            except Exception:
                # Top-level boundary: log and keep the scheduler alive.
                logger.exception("Scheduler cycle error")

            await asyncio.sleep(SCHEDULER_TICK)
    finally:
        await pool.close()
        await rds.close()


# How long a document can sit in "parsed" before we consider it orphaned
STALE_PARSED_THRESHOLD_MINUTES: int = 30


async def recover_stale_documents(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Re-enqueue documents stuck in 'parsed' status for extraction.

    Documents can get orphaned when Redis loses queue entries (pod restart,
    OOM, etc.). This sweep catches any document that has been in 'parsed'
    status for longer than STALE_PARSED_THRESHOLD_MINUTES and re-enqueues it.
    Documents of type 'macro_event' go to the macro classification queue;
    everything else goes to the extraction queue. At most 100 rows are
    handled per sweep.

    Returns the number of documents re-enqueued.
    """
    rows = await pool.fetch(
        """SELECT d.id, d.document_type, dcm.ticker
           FROM documents d
           LEFT JOIN document_company_mentions dcm ON d.id = dcm.document_id
           WHERE d.status = 'parsed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
             AND NOT EXISTS (
                 SELECT 1 FROM global_events ge
                 WHERE ge.source_document_id = d.id
             )
           ORDER BY d.created_at ASC
           LIMIT 100""",
        STALE_PARSED_THRESHOLD_MINUTES,
    )

    if not rows:
        return 0

    enqueued = 0
    for row in rows:
        if row["document_type"] == "macro_event":
            target = queue_key(QUEUE_MACRO_CLASSIFICATION)
        else:
            target = queue_key(QUEUE_EXTRACTION)

        await rds.rpush(target, json.dumps({
            "document_id": str(row["id"]),
            "ticker": row["ticker"] or "",
        }))
        enqueued += 1

    logger.info("Recovered %d stale parsed documents for extraction", enqueued)
    return enqueued


# How often to run competitive signal cleanup (every ~100 cycles = ~25 minutes)
CLEANUP_CYCLE_INTERVAL: int = 100

# Keep competitive signals for this many days
COMPETITIVE_SIGNAL_RETENTION_DAYS: int = 30


async def cleanup_old_signals(pool: asyncpg.Pool) -> int:
    """Delete competitive signal records older than the retention window.

    Prevents the competitive_signal_records table from growing unbounded.
    Returns the number of rows deleted.
    """
    result = await pool.execute(
        "DELETE FROM competitive_signal_records WHERE computed_at < NOW() - INTERVAL '1 day' * $1",
        COMPETITIVE_SIGNAL_RETENTION_DAYS,
    )
    # result is like "DELETE 1234"
    count = int(result.split()[-1]) if result else 0
    if count > 0:
        logger.info("Cleaned up %d old competitive signal records", count)
    return count


async def cleanup_all_tables(pool: asyncpg.Pool) -> None:
    """Run retention cleanup across all tables that grow unbounded.

    Called periodically by the scheduler (~every 25 minutes). Each table has
    its own retention window. A failure on one table is logged and does not
    stop the remaining tables from being cleaned.
    """
    cleanups = [
        # (table, time_column, retention_days, description)
        ("competitive_signal_records", "computed_at", COMPETITIVE_SIGNAL_RETENTION_DAYS, "competitive signals"),
        ("ingestion_runs", "started_at", 14, "ingestion runs"),
        ("trading_decisions", "created_at", 90, "trading decisions"),
        ("risk_evaluations", "evaluated_at", 30, "risk evaluations"),
        ("audit_events", "created_at", 90, "audit events"),
        ("macro_impact_records", "computed_at", 60, "macro impact records"),
        ("recommendation_evidence", "created_at", 60, "recommendation evidence"),
        ("recommendations", "generated_at", 30, "old recommendations"),
        ("order_events", "created_at", 90, "order events"),
        ("model_performance_metrics", "recorded_at", 30, "model metrics"),
    ]
    total = 0
    for table, col, days, desc in cleanups:
        try:
            # Table and column names come from the fixed list above, never
            # from external input, so interpolating them is safe here.
            result = await pool.execute(
                f"DELETE FROM {table} WHERE {col} < NOW() - INTERVAL '1 day' * $1",  # noqa: S608
                days,
            )
            count = int(result.split()[-1]) if result else 0
            if count > 0:
                total += count
                logger.info("Cleaned %d %s (>%dd old)", count, desc, days)
        except Exception as exc:
            # Table might not exist or column name might differ — keep going,
            # but leave a trace so real failures are not invisible.
            logger.warning("Cleanup skipped for %s (%s): %s", table, desc, exc)

    if total > 0:
        logger.info("Total cleanup: %d rows deleted across all tables", total)


if __name__ == "__main__":
    asyncio.run(main())