"""Scheduler - triggers ingestion cycles for tracked symbols and sources.

Polls the symbol registry for active companies and their configured sources,
respects per-source polling cadences and backoff windows, coordinates rate
limits across source types, and enqueues ingestion jobs for downstream
workers.

Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""

import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Optional

import asyncpg
import redis.asyncio as aioredis

from services.shared.config import load_config
from services.shared.db import get_pg_pool, get_redis
from services.shared.logging import setup_logging
from services.shared.redis_keys import (
    PIPELINE_ENABLED_KEY,
    QUEUE_AGGREGATION,
    QUEUE_EXTRACTION,
    QUEUE_INGESTION,
    QUEUE_MACRO_CLASSIFICATION,
    QUEUE_PREFIX,
    lock_key,
    queue_key,
    rate_limit_key,
)

logger = logging.getLogger("scheduler")


def _ensure_dict(val: Any) -> Optional[dict]:
    """Normalize a JSONB column value into a ``dict``.

    JSONB values may arrive either already decoded (a dict) or as JSON text.

    Returns:
        ``val`` itself when it is already a dict; the decoded object when
        ``val`` is a JSON string that encodes an object; ``None`` for
        everything else (``None``, malformed JSON, JSON arrays or scalars,
        and any other type).
    """
    if isinstance(val, dict):
        return val
    if not isinstance(val, str):
        return None
    try:
        decoded = json.loads(val)
    except (json.JSONDecodeError, TypeError):
        return None
    return decoded if isinstance(decoded, dict) else None


# Default polling cadences by source class (seconds).
# Individual sources can override via config.polling_interval_seconds.
DEFAULT_CADENCES: dict[str, int] = {
    "market_api": 300,
    "news_api": 300,
    "filings_api": 3600,
    "web_scrape": 1800,
    "broker": 30,
    "macro_news": 600,
}

# Default rate limits per source type (requests per minute)
DEFAULT_RATE_LIMITS: dict[str, int] = {
    "market_api": 20,
    "news_api": 20,
    "filings_api": 10,
    "web_scrape": 10,
    "broker": 60,
    "macro_news": 10,
}

# Global rate limit across all Polygon-backed source types (requests per minute).
# market_api + news_api share a single Polygon API key, so we cap the combined
# throughput to stay safely under the plan limit.
POLYGON_SOURCE_TYPES: set[str] = {"market_api", "news_api"}
POLYGON_GLOBAL_RATE_LIMIT: int = 45

# How long to wait before retrying a failed source (seconds)
DEFAULT_BACKOFF_BASE: int = 60
MAX_BACKOFF: int = 3600
MAX_RETRY_COUNT: int = 10

# Main loop interval (seconds)
SCHEDULER_TICK: int = 15

# Periodic aggregation: re-aggregate all tickers every N scheduler cycles.
# 15s tick x 60 cycles = ~15 minutes (somewhat longer in practice, because each
# cycle's own runtime is added on top of the sleep).
AGGREGATION_CYCLE_INTERVAL: int = 60


def get_cadence_for_source(source_type: str, config: Optional[dict[str, Any]]) -> int:
    """Return the polling interval (seconds) for a source.

    Uses ``config["polling_interval_seconds"]`` when present and convertible
    to ``int`` (clamped to a 10-second floor). Otherwise falls back to
    ``DEFAULT_CADENCES`` for the source type, or 600s for unknown types.
    """
    if config and "polling_interval_seconds" in config:
        try:
            return max(10, int(config["polling_interval_seconds"]))
        except (ValueError, TypeError):
            pass  # Malformed override: fall through to the default cadence.
    return DEFAULT_CADENCES.get(source_type, 600)


def compute_backoff(retry_count: int) -> int:
    """Exponential backoff with a cap. Returns seconds to wait.

    ``DEFAULT_BACKOFF_BASE * 2**retry_count``, with the exponent capped at 8
    and the result capped at ``MAX_BACKOFF``.
    """
    delay = DEFAULT_BACKOFF_BASE * (2 ** min(retry_count, 8))
    return min(delay, MAX_BACKOFF)


def _match_tz(value: datetime, reference: datetime) -> datetime:
    """Return *value* with the same tz-awareness as *reference*.

    A naive *value* is assumed to be UTC when *reference* is aware. An aware
    *value* has its tzinfo stripped (no conversion) when *reference* is naive.
    """
    if reference.tzinfo is not None and value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    if reference.tzinfo is None and value.tzinfo is not None:
        return value.replace(tzinfo=None)
    return value


def is_source_due(
    source_type: str,
    source_config: Optional[dict[str, Any]],
    last_completed_at: Optional[datetime],
    last_status: Optional[str],
    retry_count: int,
    next_retry_at: Optional[datetime],
    now: datetime,
) -> bool:
    """Determine whether a source is due for its next polling cycle.

    Checks:
    - If the source has never run, it is due.
    - If the last run failed and retry_count reached MAX_RETRY_COUNT, skip
      (needs manual reset).
    - If the last run failed and next_retry_at is in the future, skip;
      otherwise the failed source is retried.
    - If the last run is still "running", skip (no double-scheduling).
    - Otherwise, check whether the cadence has elapsed since the last run.
    """
    # Never run before: always due.
    if last_completed_at is None and last_status is None:
        return True

    if last_status == "failed":
        if retry_count >= MAX_RETRY_COUNT:
            return False
        if next_retry_at is not None and now < _match_tz(next_retry_at, now):
            return False
        # Backoff elapsed or no next_retry_at set: allow retry.
        return True

    # NOTE(review): a run stuck in "running" (e.g. a crashed worker) blocks
    # this source until that row disappears; nothing in this function times
    # it out.
    if last_status == "running":
        return False

    if last_completed_at is None:
        return True

    cadence = get_cadence_for_source(source_type, source_config)
    elapsed = (now - _match_tz(last_completed_at, now)).total_seconds()
    return elapsed >= cadence


def build_job_payload(
    source: Any,
    aliases: list[str],
    now: datetime,
) -> dict[str, Any]:
    """Build the JSON-serializable ingestion job payload for a source row.

    Missing ticker/legal_name become "", a missing company_id becomes None,
    and a missing (NULL) credibility score defaults to 0.5.
    """
    credibility = source["credibility_score"]
    return {
        "source_id": str(source["source_id"]),
        "company_id": str(source["company_id"]) if source.get("company_id") else None,
        "ticker": source.get("ticker") or "",
        "legal_name": source.get("legal_name") or "",
        "aliases": aliases,
        "source_type": source["source_type"],
        "source_name": source["source_name"],
        "config": _ensure_dict(source["config"]) or {},
        # An explicit score of 0 is meaningful; only a missing score defaults.
        "credibility_score": float(credibility) if credibility is not None else 0.5,
        "scheduled_at": now.isoformat(),
    }


# Random tokens for locks currently held by this process, keyed by lock name,
# so release_lock only deletes a lock this process still owns.
_LOCK_TOKENS: dict[str, str] = {}

# Atomically delete KEYS[1] only if its value still equals ARGV[1].
_RELEASE_LOCK_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
end
return 0
"""


async def acquire_lock(rds: aioredis.Redis, name: str, ttl: int = 60) -> bool:
    """Acquire a distributed lock. Returns True if acquired.

    The lock value is a random per-acquisition token, so that a lock that
    expired and was re-acquired by someone else is not deleted by our release.
    """
    token = os.urandom(16).hex()
    acquired = await rds.set(lock_key(name), token, nx=True, ex=ttl)
    if acquired:
        _LOCK_TOKENS[name] = token
    return bool(acquired)


async def release_lock(rds: aioredis.Redis, name: str) -> None:
    """Release a distributed lock taken with acquire_lock.

    Only deletes the key if it still holds this process's token. If this
    process has no token for *name*, falls back to an unconditional delete
    (the historical behavior).
    """
    token = _LOCK_TOKENS.pop(name, None)
    if token is None:
        await rds.delete(lock_key(name))
        return
    await rds.eval(_RELEASE_LOCK_SCRIPT, 1, lock_key(name), token)


async def check_rate_limit(
    rds: aioredis.Redis,
    source_type: str,
    now: datetime,
    max_per_minute: Optional[int] = None,
) -> bool:
    """Check whether the source type is within its rate limit window.

    Enforces two limits, using per-minute counters keyed on ``now``:
    1. Per-source-type limit (e.g. market_api: 20/min)
    2. Global Polygon limit across all Polygon-backed types (45/min combined)

    Returns True if the request is allowed, False if rate-limited.
    """
    limit = max_per_minute or DEFAULT_RATE_LIMITS.get(source_type, 30)
    window = now.strftime("%Y%m%d%H%M")

    # Per-source-type check
    key = rate_limit_key(source_type, window)
    count = await rds.incr(key)
    if count == 1:
        await rds.expire(key, 120)
    if count > limit:
        return False

    # Global Polygon check for source types that share the Polygon API key
    if source_type in POLYGON_SOURCE_TYPES:
        global_key = rate_limit_key("_polygon_global", window)
        global_count = await rds.incr(global_key)
        if global_count == 1:
            await rds.expire(global_key, 120)
        if global_count > POLYGON_GLOBAL_RATE_LIMIT:
            # Roll back the per-type counter since we won't actually make the call
            await rds.decr(key)
            return False

    return True


async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Fetch all active company-specific sources joined with their active companies."""
    return await pool.fetch(
        """SELECT s.id AS source_id, s.company_id, s.source_type, s.source_name,
                  s.config, s.credibility_score, c.ticker, c.legal_name
           FROM sources s
           JOIN companies c ON s.company_id = c.id
           WHERE s.active = TRUE
             AND c.active = TRUE
             AND s.source_type != 'macro_news'
           ORDER BY s.source_type, c.ticker"""
    )


async def fetch_macro_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Fetch all active macro news sources.

    Macro sources are not company-specific: they have source_type='macro_news'
    and may have company_id NULL. They are scheduled independently from
    company-specific sources.

    Requirements: 1.1
    """
    return await pool.fetch(
        """SELECT s.id AS source_id, s.company_id, s.source_type, s.source_name,
                  s.config, s.credibility_score
           FROM sources s
           WHERE s.active = TRUE
             AND s.source_type = 'macro_news'
           ORDER BY s.source_name"""
    )


async def fetch_global_market_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Fetch active market sources that are not company-specific.

    These are sources like the grouped daily endpoint that fetch data for all
    tickers in a single API call. They have company_id IS NULL and
    source_type = 'market_api'.
    """
    return await pool.fetch(
        """SELECT s.id AS source_id, s.company_id, s.source_type, s.source_name,
                  s.config, s.credibility_score
           FROM sources s
           WHERE s.active = TRUE
             AND s.source_type = 'market_api'
             AND s.company_id IS NULL
           ORDER BY s.source_name"""
    )


async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
    """Fetch all aliases for a company."""
    rows = await pool.fetch(
        "SELECT alias FROM company_aliases WHERE company_id = $1",
        company_id,
    )
    return [r["alias"] for r in rows]


async def fetch_last_run(
    pool: asyncpg.Pool, source_id: str
) -> Optional[asyncpg.Record]:
    """Fetch the most recent ingestion run (by started_at) for a source."""
    return await pool.fetchrow(
        """SELECT status, started_at, completed_at, retry_count, next_retry_at
           FROM ingestion_runs
           WHERE source_id = $1
           ORDER BY started_at DESC
           LIMIT 1""",
        source_id,
    )


async def _source_is_due(
    pool: asyncpg.Pool,
    source_id: Any,
    source_type: str,
    source_config: Optional[dict[str, Any]],
    now: datetime,
) -> bool:
    """Look up a source's most recent run and apply is_source_due to it.

    A run's completion time falls back to its start time when completed_at is
    NULL, and a NULL retry_count counts as 0.
    """
    last_run = await fetch_last_run(pool, source_id)
    if last_run is None:
        return is_source_due(
            source_type=source_type,
            source_config=source_config,
            last_completed_at=None,
            last_status=None,
            retry_count=0,
            next_retry_at=None,
            now=now,
        )
    return is_source_due(
        source_type=source_type,
        source_config=source_config,
        last_completed_at=last_run["completed_at"] or last_run["started_at"],
        last_status=last_run["status"],
        retry_count=last_run["retry_count"] or 0,
        next_retry_at=last_run["next_retry_at"],
        now=now,
    )


async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """One scheduling pass: find due sources and enqueue ingestion jobs.

    Processes, in order: company-specific sources, macro news sources, and
    global (company_id IS NULL) market sources. Each due source that passes
    its rate limit is pushed onto the ingestion queue; intraday_bars global
    sources fan out into one job per active ticker.

    Returns the number of jobs enqueued.
    """
    now = datetime.now(tz=timezone.utc)
    ingestion_queue = queue_key(QUEUE_INGESTION)
    enqueued = 0
    skipped_rate_limit = 0
    skipped_not_due = 0

    # --- Company-specific sources ---
    sources = await fetch_active_sources(pool)
    for src in sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])

        if not await _source_is_due(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for %s, skipping %s/%s",
                source_type, src["ticker"], src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        # Company aliases are shipped with the job for downstream entity matching.
        aliases = await fetch_aliases_for_company(pool, src["company_id"])
        job = build_job_payload(src, aliases, now)
        await rds.rpush(ingestion_queue, json.dumps(job))  # type: ignore[misc]
        enqueued += 1
        logger.debug(
            "Enqueued %s job for %s (%s)",
            source_type, src["ticker"], src["source_name"],
        )

    # --- Macro news sources (Requirement 1.1) ---
    macro_sources = await fetch_macro_sources(pool)
    for src in macro_sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])

        if not await _source_is_due(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for macro_news, skipping %s",
                src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        job = build_job_payload(src, [], now)
        await rds.rpush(ingestion_queue, json.dumps(job))
        enqueued += 1
        logger.debug(
            "Enqueued macro_news job for %s",
            src["source_name"],
        )

    # --- Global market sources (grouped daily, intraday bars, ...) ---
    global_market_sources = await fetch_global_market_sources(pool)
    for src in global_market_sources:
        source_type = src["source_type"]
        # Copy into a fresh dict: a NULL/invalid config becomes {} instead of
        # crashing on .get(), and setdefault below never mutates the record.
        source_config = dict(_ensure_dict(src["config"]) or {})

        # Grouped daily defaults to once per day (86400s).
        endpoint = source_config.get("endpoint", "")
        if endpoint == "grouped_daily":
            source_config.setdefault("polling_interval_seconds", 86400)

        if not await _source_is_due(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        # NOTE(review): one rate-limit slot is consumed here even when the
        # intraday_bars fan-out below produces one job per active ticker.
        if not await check_rate_limit(rds, source_type, now):
            skipped_rate_limit += 1
            continue

        job = build_job_payload(src, [], now)
        if endpoint == "intraday_bars":
            # Expand intraday source into per-ticker jobs for all active companies
            tickers = await pool.fetch(
                "SELECT ticker FROM companies WHERE active = TRUE"
            )
            payloads = [
                json.dumps({**job, "ticker": t_row["ticker"]}) for t_row in tickers
            ]
            if payloads:  # RPUSH requires at least one value
                await rds.rpush(ingestion_queue, *payloads)
            enqueued += len(payloads)
            logger.info("Enqueued %d intraday bar jobs", len(tickers))
        else:
            # Non-ticker global jobs are tagged with the "_MARKET" pseudo-ticker.
            job["ticker"] = "_MARKET"
            await rds.rpush(ingestion_queue, json.dumps(job))
            enqueued += 1
            logger.info("Enqueued grouped daily market data job")

    logger.info(
        "Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
        enqueued, skipped_not_due, skipped_rate_limit,
        len(sources) + len(macro_sources) + len(global_market_sources),
    )
    return enqueued


async def enqueue_periodic_aggregation(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Enqueue one aggregation job per active ticker.

    Called periodically (every AGGREGATION_CYCLE_INTERVAL scheduler cycles) so
    trend data stays fresh even when no new documents are being ingested.
    There is no market-hours distinction in this function.

    Returns the number of tickers enqueued (0 if there are no active companies).
    """
    rows = await pool.fetch(
        "SELECT ticker FROM companies WHERE active = TRUE ORDER BY ticker"
    )
    if not rows:
        return 0

    payloads = [json.dumps({"ticker": row["ticker"]}) for row in rows]
    # One RPUSH for the whole batch; list order matches the ORDER BY above.
    await rds.rpush(queue_key(QUEUE_AGGREGATION), *payloads)

    logger.info("Periodic aggregation: enqueued %d tickers for re-aggregation", len(payloads))
    return len(payloads)


async def main() -> None:
    """Run the scheduler loop until the process is stopped.

    Each iteration checks the pipeline toggle in Redis and skips the tick
    while it is "0". Otherwise it takes the ``scheduler_cycle`` distributed
    lock, runs schedule_cycle, and every N cycles runs the maintenance sweeps
    (stale-document recovery, extraction retries, table cleanup, periodic
    aggregation). Errors within an iteration are logged and the loop
    continues. The Postgres pool and Redis client are closed on exit.
    """
    config = load_config()
    setup_logging("scheduler", level=config.log_level, json_output=config.json_logs)

    pool = await get_pg_pool(config)
    rds = get_redis(config)

    logger.info("Scheduler started (tick=%ds)", SCHEDULER_TICK)

    pipeline_key = PIPELINE_ENABLED_KEY

    # If PIPELINE_DEFAULT_OFF is set, initialize the toggle to OFF on first boot
    # (only if the key doesn't already exist, which preserves manual overrides).
    if os.getenv("PIPELINE_DEFAULT_OFF", "").lower() in ("1", "true", "yes"):
        created = await rds.set(pipeline_key, "0", nx=True)
        if created:
            logger.info("Pipeline toggle initialized to OFF (PIPELINE_DEFAULT_OFF=true)")

    recovery_counter = 0
    retry_counter = 0
    cleanup_counter = 0
    aggregation_counter = 0
    try:
        while True:
            try:
                # Check pipeline toggle; skip the cycle if disabled. Accept
                # bytes too, in case the client does not decode responses.
                flag = await rds.get(pipeline_key)
                if flag in ("0", b"0"):
                    await asyncio.sleep(SCHEDULER_TICK)
                    continue

                if await acquire_lock(rds, "scheduler_cycle", ttl=30):
                    try:
                        await schedule_cycle(pool, rds)

                        # Run stale document recovery every ~20 cycles (~5 minutes)
                        recovery_counter += 1
                        if recovery_counter >= 20:
                            recovery_counter = 0
                            await recover_stale_documents(pool, rds)

                        # Retry extraction failures every ~40 cycles (~10 minutes)
                        retry_counter += 1
                        if retry_counter >= 40:
                            retry_counter = 0
                            await retry_failed_extractions(pool, rds)

                        # Run retention cleanup periodically (~25 minutes)
                        cleanup_counter += 1
                        if cleanup_counter >= CLEANUP_CYCLE_INTERVAL:
                            cleanup_counter = 0
                            await cleanup_all_tables(pool)

                        # Periodic re-aggregation of all tickers (~15 minutes)
                        aggregation_counter += 1
                        if aggregation_counter >= AGGREGATION_CYCLE_INTERVAL:
                            aggregation_counter = 0
                            await enqueue_periodic_aggregation(pool, rds)
                    finally:
                        await release_lock(rds, "scheduler_cycle")
            except Exception:
                logger.exception("Scheduler cycle error")

            await asyncio.sleep(SCHEDULER_TICK)
    finally:
        await pool.close()
        await rds.close()


# How long a document can sit in "parsed" before we consider it orphaned.
# Must be longer than the expected queue drain time to avoid re-enqueuing
# docs that are already queued but not yet processed.
STALE_PARSED_THRESHOLD_MINUTES: int = 240

# How long after an extraction failure before we retry
EXTRACTION_FAILED_RETRY_MINUTES: int = 60

# Prefix for per-document "already enqueued" marker keys (prevents duplicate
# enqueuing by the recovery/retry sweeps).
_ENQUEUED_SET = f"{QUEUE_PREFIX}:enqueued"

# How long an enqueued marker lives before the document can be re-enqueued
# (seconds). Derived from the stale threshold so the two stay in sync
# (240 min -> 14400 s).
_ENQUEUED_TTL = STALE_PARSED_THRESHOLD_MINUTES * 60


def _enqueued_marker_key(document_id: str) -> str:
    """Return the Redis key marking *document_id* as recently enqueued."""
    return f"{_ENQUEUED_SET}:{document_id}"


async def _claim_enqueue_marker(rds: aioredis.Redis, document_id: str) -> bool:
    """Set the enqueued marker for *document_id* if absent (SET NX + TTL).

    Returns True if this call created the marker, False if it already existed.
    """
    created = await rds.set(
        _enqueued_marker_key(document_id), "1", nx=True, ex=_ENQUEUED_TTL
    )
    return bool(created)


async def _push_document_job(
    rds: aioredis.Redis, queue: str, document_id: str, ticker: str
) -> None:
    """Append a ``{"document_id", "ticker"}`` job to *queue*."""
    await rds.rpush(queue, json.dumps({
        "document_id": document_id,
        "ticker": ticker,
    }))


def _queue_for_document_type(document_type: Optional[str]) -> str:
    """Route macro events to macro classification, everything else to extraction."""
    if document_type == "macro_event":
        return queue_key(QUEUE_MACRO_CLASSIFICATION)
    return queue_key(QUEUE_EXTRACTION)


async def _enqueue_if_new(
    rds: aioredis.Redis,
    queue: str,
    document_id: str,
    ticker: str,
) -> bool:
    """Push a job onto *queue* only if *document_id* isn't already tracked.

    Tracking uses one Redis string key per document (created with SET NX and
    a ``_ENQUEUED_TTL`` expiry), so the recovery sweeps cannot enqueue the
    same document more than once per TTL window.

    Returns True if enqueued, False if skipped as a duplicate.
    """
    if not await _claim_enqueue_marker(rds, document_id):
        return False
    await _push_document_job(rds, queue, document_id, ticker)
    return True


async def recover_stale_documents(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Re-enqueue documents stuck in 'parsed' status for extraction.

    Documents can get orphaned when Redis loses queue entries (pod restart,
    OOM, etc.). This sweep catches up to 100 documents that have been in
    'parsed' status longer than STALE_PARSED_THRESHOLD_MINUTES (and have no
    global_events row), re-enqueues them, and touches their updated_at.

    Returns the number of documents re-enqueued.
    """
    # The ticker comes from a scalar subquery so each document yields exactly
    # one row and LIMIT counts documents, not mention rows.
    rows = await pool.fetch(
        """SELECT d.id, d.document_type,
                  (SELECT dcm.ticker
                     FROM document_company_mentions dcm
                    WHERE dcm.document_id = d.id
                    ORDER BY dcm.ticker
                    LIMIT 1) AS ticker
           FROM documents d
           WHERE d.status = 'parsed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
             AND NOT EXISTS (
                 SELECT 1 FROM global_events ge WHERE ge.source_document_id = d.id
             )
           ORDER BY d.created_at ASC
           LIMIT 100""",
        STALE_PARSED_THRESHOLD_MINUTES,
    )
    if not rows:
        return 0

    doc_ids = []
    for row in rows:
        target = _queue_for_document_type(row["document_type"])
        if await _enqueue_if_new(rds, target, str(row["id"]), row["ticker"] or ""):
            doc_ids.append(row["id"])

    # Touch updated_at so these docs won't be re-enqueued until the threshold passes again
    if doc_ids:
        await pool.execute(
            "UPDATE documents SET updated_at = NOW() WHERE id = ANY($1::uuid[])",
            doc_ids,
        )

    logger.info("Recovered %d stale parsed documents for extraction", len(doc_ids))
    return len(doc_ids)


async def retry_failed_extractions(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Re-enqueue documents stuck in 'extraction_failed' for another attempt.

    Resets status to 'parsed', deletes the failed intelligence row so the
    extractor treats them as fresh, and only then pushes them onto the
    extraction (or macro classification) queue. Only retries documents whose
    last attempt was at least EXTRACTION_FAILED_RETRY_MINUTES ago to avoid
    tight retry loops.

    NOTE(review): a document that still has an enqueued marker (e.g. from a
    recovery sweep within the last _ENQUEUED_TTL seconds) is skipped here
    until that marker expires.

    Returns the number of documents re-enqueued.
    """
    rows = await pool.fetch(
        """SELECT d.id, d.document_type,
                  (SELECT dcm.ticker
                     FROM document_company_mentions dcm
                    WHERE dcm.document_id = d.id
                    ORDER BY dcm.ticker
                    LIMIT 1) AS ticker
           FROM documents d
           WHERE d.status = 'extraction_failed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
           ORDER BY d.updated_at ASC
           LIMIT 100""",
        EXTRACTION_FAILED_RETRY_MINUTES,
    )
    if not rows:
        return 0

    # Claim dedup markers first; only claimed documents are reset and pushed.
    claimed = [row for row in rows if await _claim_enqueue_marker(rds, str(row["id"]))]

    if claimed:
        doc_ids = [row["id"] for row in claimed]
        try:
            async with pool.acquire() as conn:
                async with conn.transaction():
                    # Delete failed intelligence rows so extractor starts fresh
                    await conn.execute(
                        """DELETE FROM document_intelligence
                           WHERE document_id = ANY($1::uuid[])
                             AND validation_status = 'failed'""",
                        doc_ids,
                    )
                    # Reset status to 'parsed' and touch updated_at
                    await conn.execute(
                        """UPDATE documents
                           SET status = 'parsed', updated_at = NOW()
                           WHERE id = ANY($1::uuid[])""",
                        doc_ids,
                    )
        except Exception:
            # Release the markers so the next sweep can retry these documents
            # instead of waiting out _ENQUEUED_TTL.
            await rds.delete(*(_enqueued_marker_key(str(doc_id)) for doc_id in doc_ids))
            raise

        # Push only after the reset is committed, so a job is never visible on
        # the queue while the document is still 'extraction_failed'.
        for row in claimed:
            await _push_document_job(
                rds,
                _queue_for_document_type(row["document_type"]),
                str(row["id"]),
                row["ticker"] or "",
            )

    logger.info("Retried %d extraction-failed documents", len(claimed))
    return len(claimed)


# How often to run retention cleanup (every ~100 cycles = ~25 minutes)
CLEANUP_CYCLE_INTERVAL: int = 100

# Keep competitive signals for this many days
COMPETITIVE_SIGNAL_RETENTION_DAYS: int = 30


def _rows_affected(status: Optional[str]) -> int:
    """Parse the row count from a command status string such as "DELETE 1234".

    Returns 0 for an empty/None status or one whose last token is not numeric.
    """
    if not status:
        return 0
    try:
        return int(status.split()[-1])
    except ValueError:
        return 0


async def cleanup_old_signals(pool: asyncpg.Pool) -> int:
    """Delete competitive signal records older than the retention window.

    Prevents the competitive_signal_records table from growing unbounded.
    Returns the number of rows deleted.
    """
    result = await pool.execute(
        "DELETE FROM competitive_signal_records WHERE computed_at < NOW() - INTERVAL '1 day' * $1",
        COMPETITIVE_SIGNAL_RETENTION_DAYS,
    )
    count = _rows_affected(result)
    if count > 0:
        logger.info("Cleaned up %d old competitive signal records", count)
    return count


async def cleanup_all_tables(pool: asyncpg.Pool) -> None:
    """Run retention cleanup across all tables that grow unbounded.

    Called periodically by the scheduler (~every 25 minutes). Each table has
    its own retention window. Best-effort: a failure on one table (e.g. the
    table or column does not exist) is logged and the remaining tables are
    still cleaned.
    """
    cleanups = [
        # (table, time_column, retention_days, description)
        ("competitive_signal_records", "computed_at", COMPETITIVE_SIGNAL_RETENTION_DAYS, "competitive signals"),
        ("ingestion_runs", "started_at", 14, "ingestion runs"),
        ("trading_decisions", "created_at", 90, "trading decisions"),
        ("risk_evaluations", "evaluated_at", 30, "risk evaluations"),
        ("audit_events", "created_at", 90, "audit events"),
        ("macro_impact_records", "computed_at", 60, "macro impact records"),
        ("recommendation_evidence", "created_at", 60, "recommendation evidence"),
        ("recommendations", "generated_at", 30, "old recommendations"),
        ("order_events", "created_at", 90, "order events"),
        ("model_performance_metrics", "recorded_at", 30, "model metrics"),
    ]
    total = 0
    for table, col, days, desc in cleanups:
        try:
            # Table/column names come from the hard-coded list above, never
            # from external input, so interpolating them is safe.
            result = await pool.execute(
                f"DELETE FROM {table} WHERE {col} < NOW() - INTERVAL '1 day' * $1",  # noqa: S608
                days,
            )
        except Exception as exc:
            logger.warning("Cleanup of %s.%s skipped: %s", table, col, exc)
            continue
        count = _rows_affected(result)
        if count > 0:
            total += count
            logger.info("Cleaned %d %s (>%dd old)", count, desc, days)

    if total > 0:
        logger.info("Total cleanup: %d rows deleted across all tables", total)


if __name__ == "__main__":
    asyncio.run(main())