627 lines
21 KiB
Python
627 lines
21 KiB
Python
"""Scheduler - triggers ingestion cycles for tracked symbols and sources.
|
|
|
|
Polls the symbol registry for active companies and their configured sources,
|
|
respects per-source polling cadences and backoff windows, coordinates rate
|
|
limits across source types, and enqueues ingestion jobs for downstream workers.
|
|
|
|
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
|
|
"""
|
|
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Any, Optional

import asyncpg
import redis.asyncio as aioredis

from services.shared.config import load_config
from services.shared.db import get_pg_pool, get_redis
from services.shared.logging import setup_logging
from services.shared.redis_keys import (
    QUEUE_EXTRACTION,
    QUEUE_INGESTION,
    QUEUE_MACRO_CLASSIFICATION,
    lock_key,
    queue_key,
    rate_limit_key,
)
|
|
|
|
# Module-wide logger. Its name matches the service name that main() passes
# to setup_logging().
logger: logging.Logger = logging.getLogger("scheduler")
|
|
|
|
|
|
def _ensure_dict(val: Any) -> Optional[dict]:
    """Normalise a JSONB column value into a plain dict.

    A JSONB value may arrive either already decoded (a ``dict``) or as raw
    JSON text. A dict is returned as-is (same object). A string is decoded,
    and the result is kept only if it is a JSON object. Anything else,
    including ``None``, undecodable text, or text decoding to a non-object,
    yields ``None``.
    """
    if isinstance(val, dict):
        return val
    if not isinstance(val, str):
        # None and every other non-text type fall through to "no dict".
        return None
    try:
        decoded = json.loads(val)
    except (json.JSONDecodeError, TypeError):
        return None
    return decoded if isinstance(decoded, dict) else None
|
|
|
|
# Default polling cadence per source type, in seconds. A source can override
# it with config["polling_interval_seconds"] (see get_cadence_for_source).
DEFAULT_CADENCES: dict[str, int] = dict(
    market_api=300,
    news_api=300,
    filings_api=3600,
    web_scrape=1800,
    broker=30,
    macro_news=600,
)

# Default per-source-type rate limits, in requests per minute.
DEFAULT_RATE_LIMITS: dict[str, int] = dict(
    market_api=20,
    news_api=20,
    filings_api=10,
    web_scrape=10,
    broker=60,
    macro_news=10,
)

# market_api and news_api share one Polygon API key, so their combined
# throughput (requests per minute) is capped separately to stay safely
# under the plan limit.
POLYGON_SOURCE_TYPES: set[str] = {"market_api", "news_api"}
POLYGON_GLOBAL_RATE_LIMIT: int = 45

# Retry backoff for failed sources, in seconds. Once a source's retry count
# reaches MAX_RETRY_COUNT it is no longer scheduled until reset.
DEFAULT_BACKOFF_BASE: int = 60
MAX_BACKOFF: int = 3600
MAX_RETRY_COUNT: int = 10

# Pause between iterations of the main scheduler loop, in seconds.
SCHEDULER_TICK: int = 15
|
|
|
|
|
|
def get_cadence_for_source(source_type: str, config: Optional[dict[str, Any]]) -> int:
    """Return the polling interval, in seconds, for a source.

    A usable ``config["polling_interval_seconds"]`` takes precedence and is
    clamped to a floor of 10 seconds. If the config is missing, lacks the
    key, or holds a value that cannot be converted to an integer, the
    default cadence for ``source_type`` is used, or 600 seconds for an
    unknown type.

    Args:
        source_type: The source's type, e.g. ``"market_api"``.
        config: The source's decoded JSONB config, or ``None``.

    Returns:
        The polling interval in whole seconds.
    """
    if config and "polling_interval_seconds" in config:
        try:
            return max(10, int(config["polling_interval_seconds"]))
        except (ValueError, TypeError, OverflowError):
            # OverflowError: json.loads accepts the non-standard "Infinity"
            # literal, and int(float("inf")) raises OverflowError. Treat it
            # like any other unusable value instead of crashing the cycle.
            pass
    return DEFAULT_CADENCES.get(source_type, 600)
|
|
|
|
|
|
def compute_backoff(retry_count: int) -> int:
    """Exponential backoff with a cap. Returns seconds to wait.

    The delay is ``DEFAULT_BACKOFF_BASE * 2**retry_count``, capped at
    ``MAX_BACKOFF``. The exponent is clamped to the range [0, 8]:

    - The upper bound keeps the intermediate product small; with the current
      constants the cap is reached well before it.
    - The lower bound stops a negative count from producing a fractional
      float delay shorter than the base.

    Args:
        retry_count: Number of retries already attempted.

    Returns:
        Seconds to wait before the next retry, as an int.
    """
    exponent = min(max(retry_count, 0), 8)
    delay = DEFAULT_BACKOFF_BASE * (2 ** exponent)
    return min(delay, MAX_BACKOFF)
|
|
|
|
|
|
def _as_naive_utc(value: datetime) -> datetime:
    """Return ``value`` as a naive datetime expressed in UTC.

    Aware values are converted to UTC before their tzinfo is dropped, so a
    non-UTC offset is honoured rather than silently discarded. Naive values
    are assumed to already be UTC and are returned unchanged.
    """
    if value.tzinfo is None:
        return value
    return value.astimezone(timezone.utc).replace(tzinfo=None)


def is_source_due(
    source_type: str,
    source_config: Optional[dict[str, Any]],
    last_completed_at: Optional[datetime],
    last_status: Optional[str],
    retry_count: int,
    next_retry_at: Optional[datetime],
    now: datetime,
) -> bool:
    """Determine whether a source is due for its next polling cycle.

    Checks, in order:
      - If the source has never run (no timestamp and no status), it is due.
      - If the last run failed:
          - if retry_count has reached MAX_RETRY_COUNT, skip (needs a manual
            reset);
          - if next_retry_at is still in the future, skip;
          - otherwise allow the retry.
      - If the last run is still "running", skip to avoid double-scheduling.
      - With no completion timestamp, it is due.
      - Otherwise it is due once the source's cadence has elapsed since
        ``last_completed_at``.

    All timestamps, including ``now``, may be naive (treated as UTC) or
    timezone-aware. They are normalised to naive UTC before comparison.

    Returns:
        True if the source should be scheduled now.
    """
    # Never run before: always due.
    if last_completed_at is None and last_status is None:
        return True

    # Compare everything as naive UTC so aware DB timestamps and a naive or
    # aware ``now`` can be mixed without TypeError or offset errors.
    now = _as_naive_utc(now)

    # If the last run failed, respect the retry budget and backoff window.
    if last_status == "failed":
        if retry_count >= MAX_RETRY_COUNT:
            return False
        if next_retry_at and now < _as_naive_utc(next_retry_at):
            return False
        # Backoff elapsed or no next_retry_at set: allow retry.
        return True

    # If currently running, don't double-schedule.
    if last_status == "running":
        return False

    # Normal cadence check.
    if last_completed_at is None:
        return True

    cadence = get_cadence_for_source(source_type, source_config)
    elapsed = (now - _as_naive_utc(last_completed_at)).total_seconds()
    return elapsed >= cadence
|
|
|
|
|
|
def build_job_payload(
    source: Any,
    aliases: list[str],
    now: datetime,
) -> dict[str, Any]:
    """Build the ingestion job payload for a source.

    Args:
        source: A row from one of the ``fetch_*_sources`` queries. Columns a
            query does not select (e.g. ``ticker``/``legal_name`` for macro
            and global sources) are read with ``.get`` and become ``""``.
        aliases: Company aliases for downstream entity matching (may be empty).
        now: The scheduling timestamp, serialised with ``isoformat()``.

    Returns:
        A JSON-serialisable dict describing the ingestion job. ``config`` is
        always a dict. ``credibility_score`` falls back to 0.5 only when the
        score is missing (None).
    """
    credibility = source["credibility_score"]
    return {
        "source_id": str(source["source_id"]),
        "company_id": str(source["company_id"]) if source.get("company_id") else None,
        "ticker": source.get("ticker") or "",
        "legal_name": source.get("legal_name") or "",
        "aliases": aliases,
        "source_type": source["source_type"],
        "source_name": source["source_name"],
        "config": _ensure_dict(source["config"]) or {},
        # Test against None rather than truthiness so a genuine score of 0
        # is kept as 0.0 instead of being replaced by the neutral default.
        "credibility_score": float(credibility) if credibility is not None else 0.5,
        "scheduled_at": now.isoformat(),
    }
|
|
|
|
|
|
async def acquire_lock(rds: aioredis.Redis, name: str, ttl: int = 60) -> bool:
    """Acquire a distributed lock. Returns True if acquired.

    Issues ``SET <lock_key(name)> "1" NX EX ttl``. The key is written only if
    absent and expires after ``ttl`` seconds, so a crashed holder cannot
    block others forever.

    The client may return ``None`` rather than ``False`` when the key already
    exists, so the reply is coerced to a real bool to honour the annotation.
    """
    acquired = await rds.set(lock_key(name), "1", nx=True, ex=ttl)
    return bool(acquired)
|
|
|
|
|
|
async def release_lock(rds: aioredis.Redis, name: str) -> None:
    """Release a distributed lock by deleting its Redis key.

    NOTE(review): the delete is unconditional. It does not check that the
    caller still owns the lock, since acquire_lock stores the constant "1".
    If the holder outlives the lock's TTL and another process acquires it,
    this call removes the other process's lock.
    """
    key = lock_key(name)
    await rds.delete(key)
|
|
|
|
|
|
async def check_rate_limit(
    rds: aioredis.Redis,
    source_type: str,
    now: datetime,
    max_per_minute: Optional[int] = None,
) -> bool:
    """Count one request against the rate limits for ``source_type``.

    Counters live in fixed one-minute windows keyed by ``now`` formatted as
    ``%Y%m%d%H%M``. Each counter gets a 120-second TTL when it is created.

    Two limits are enforced:
      1. Per source type: ``max_per_minute`` if given (and non-zero), else
         DEFAULT_RATE_LIMITS for the type, else 30.
      2. For Polygon-backed types, POLYGON_GLOBAL_RATE_LIMIT shared across
         all of them.

    Returns:
        True if the request is allowed, False if it is rate-limited. When
        only the Polygon limit rejects it, the per-type counter is
        decremented again, since the call will not actually be made.
    """
    async def bump(counter_key: str) -> int:
        # INCR the counter; start its TTL the first time the window appears.
        value = await rds.incr(counter_key)
        if value == 1:
            await rds.expire(counter_key, 120)
        return value

    per_type_limit = max_per_minute or DEFAULT_RATE_LIMITS.get(source_type, 30)
    minute_bucket = now.strftime("%Y%m%d%H%M")

    type_key = rate_limit_key(source_type, minute_bucket)
    if await bump(type_key) > per_type_limit:
        return False

    if source_type not in POLYGON_SOURCE_TYPES:
        return True

    polygon_key = rate_limit_key("_polygon_global", minute_bucket)
    if await bump(polygon_key) > POLYGON_GLOBAL_RATE_LIMIT:
        await rds.decr(type_key)
        return False
    return True
|
|
|
|
|
|
async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return every active, non-macro source attached to an active company.

    The inner JOIN on ``companies`` drops sources whose ``company_id`` is
    NULL, so company-less (global) sources are not returned here. Each row
    carries the source columns plus the company's ticker and legal name.
    Rows are ordered by source type, then ticker.
    """
    query = """SELECT s.id AS source_id,
              s.company_id,
              s.source_type,
              s.source_name,
              s.config,
              s.credibility_score,
              c.ticker,
              c.legal_name
       FROM sources s
       JOIN companies c ON s.company_id = c.id
       WHERE s.active = TRUE AND c.active = TRUE
         AND s.source_type != 'macro_news'
       ORDER BY s.source_type, c.ticker"""
    rows = await pool.fetch(query)
    return rows
|
|
|
|
|
|
async def fetch_macro_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return all active macro news sources, ordered by source name.

    Macro sources have source_type='macro_news' and are not tied to a company
    (company_id may be NULL). No company columns are joined in, so these rows
    have no ticker or legal name. They are scheduled separately from
    company-specific sources.

    Requirements: 1.1
    """
    query = """SELECT s.id AS source_id,
              s.company_id,
              s.source_type,
              s.source_name,
              s.config,
              s.credibility_score
       FROM sources s
       WHERE s.active = TRUE AND s.source_type = 'macro_news'
       ORDER BY s.source_name"""
    rows = await pool.fetch(query)
    return rows
|
|
|
|
|
|
async def fetch_global_market_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return active market_api sources with no company, ordered by name.

    These are sources such as the grouped daily endpoint, which fetch data
    for all tickers in a single API call (company_id IS NULL,
    source_type = 'market_api').
    """
    query = """SELECT s.id AS source_id,
              s.company_id,
              s.source_type,
              s.source_name,
              s.config,
              s.credibility_score
       FROM sources s
       WHERE s.active = TRUE
         AND s.source_type = 'market_api'
         AND s.company_id IS NULL
       ORDER BY s.source_name"""
    rows = await pool.fetch(query)
    return rows
|
|
|
|
|
|
async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
    """Return every alias string recorded for ``company_id``.

    The query has no ORDER BY, so the order of the result is whatever the
    database returns.
    """
    query = "SELECT alias FROM company_aliases WHERE company_id = $1"
    alias_rows = await pool.fetch(query, company_id)
    return [alias_row["alias"] for alias_row in alias_rows]
|
|
|
|
|
|
async def fetch_last_run(
    pool: asyncpg.Pool, source_id: str
) -> Optional[asyncpg.Record]:
    """Return the newest ingestion run for ``source_id``, or None if none exist.

    "Newest" means highest ``started_at``. Only the status, timing and retry
    columns are selected.

    NOTE(review): Postgres sorts NULLs first under DESC, so a row with a NULL
    started_at would be picked ahead of any started run. Confirm the column
    is NOT NULL, or whether ``NULLS LAST`` is intended.
    """
    query = """SELECT status, started_at, completed_at, retry_count, next_retry_at
       FROM ingestion_runs
       WHERE source_id = $1
       ORDER BY started_at DESC
       LIMIT 1"""
    latest = await pool.fetchrow(query, source_id)
    return latest
|
|
|
|
|
|
async def _is_source_due_now(
    pool: asyncpg.Pool,
    src: Any,
    source_config: Optional[dict[str, Any]],
    now: datetime,
) -> bool:
    """Load a source's most recent ingestion run and apply is_source_due()."""
    last_run = await fetch_last_run(pool, src["source_id"])

    last_completed_at = None
    last_status = None
    retry_count = 0
    next_retry_at = None
    if last_run:
        last_status = last_run["status"]
        # A run that never completed is timed from when it started.
        last_completed_at = last_run["completed_at"] or last_run["started_at"]
        retry_count = last_run["retry_count"] or 0
        next_retry_at = last_run["next_retry_at"]

    return is_source_due(
        source_type=src["source_type"],
        source_config=source_config,
        last_completed_at=last_completed_at,
        last_status=last_status,
        retry_count=retry_count,
        next_retry_at=next_retry_at,
        now=now,
    )


async def _enqueue_ingestion_job(rds: aioredis.Redis, job: dict[str, Any]) -> None:
    """Append one JSON-encoded job to the tail of the ingestion queue."""
    await rds.rpush(queue_key(QUEUE_INGESTION), json.dumps(job))  # type: ignore[misc]


async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """One scheduling pass: find due sources and enqueue ingestion jobs.

    Three groups of sources are processed, in this order:

    1. Active non-macro sources of active companies. Their jobs carry the
       company's aliases for downstream entity matching.
    2. Active ``macro_news`` sources (Requirement 1.1). These jobs carry no
       aliases.
    3. Active company-less ``market_api`` sources (e.g. grouped daily). These
       jobs are tagged with ticker ``"_MARKET"``, and a ``grouped_daily``
       endpoint defaults to a once-per-day cadence.

    A source is skipped if it is not due (see ``is_source_due``) or if its
    source type is over its rate limit (see ``check_rate_limit``).

    Returns:
        The number of jobs enqueued.
    """
    # Naive UTC, matching how is_source_due compares timestamps and how
    # build_job_payload serialises scheduled_at.
    now = datetime.now(timezone.utc).replace(tzinfo=None)

    enqueued = 0
    skipped_rate_limit = 0
    skipped_not_due = 0

    # --- Company-specific sources ---
    sources = await fetch_active_sources(pool)
    for src in sources:
        source_type = src["source_type"]

        if not await _is_source_due_now(pool, src, _ensure_dict(src["config"]), now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for %s, skipping %s/%s",
                source_type, src["ticker"], src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        # Company aliases are passed along for downstream entity matching.
        aliases = await fetch_aliases_for_company(pool, src["company_id"])
        await _enqueue_ingestion_job(rds, build_job_payload(src, aliases, now))
        enqueued += 1

        logger.debug(
            "Enqueued %s job for %s (%s)",
            source_type, src["ticker"], src["source_name"],
        )

    # --- Macro news sources (Requirement 1.1) ---
    macro_sources = await fetch_macro_sources(pool)
    for src in macro_sources:
        if not await _is_source_due_now(pool, src, _ensure_dict(src["config"]), now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, src["source_type"], now):
            logger.warning(
                "Rate limit hit for macro_news, skipping %s",
                src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        await _enqueue_ingestion_job(rds, build_job_payload(src, [], now))
        enqueued += 1

        logger.debug(
            "Enqueued macro_news job for %s", src["source_name"],
        )

    # --- Global market sources (grouped daily, etc.) ---
    global_market_sources = await fetch_global_market_sources(pool)
    for src in global_market_sources:
        # Use a private copy for two reasons:
        # - A NULL or non-object config must not raise AttributeError and
        #   abort the whole cycle.
        # - The cadence default injected below must not leak into the job
        #   payload. _ensure_dict returns an already-decoded dict by
        #   reference, and build_job_payload re-reads src["config"].
        source_config = dict(_ensure_dict(src["config"]) or {})

        # Grouped daily fetches every ticker in one call; poll once per day.
        if source_config.get("endpoint", "") == "grouped_daily":
            source_config.setdefault("polling_interval_seconds", 86400)

        if not await _is_source_due_now(pool, src, source_config, now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, src["source_type"], now):
            skipped_rate_limit += 1
            continue

        # Global sources carry ticker="_MARKET" instead of a company ticker.
        job = build_job_payload(src, [], now)
        job["ticker"] = "_MARKET"
        await _enqueue_ingestion_job(rds, job)
        enqueued += 1

        logger.info("Enqueued global market data job for %s", src["source_name"])

    logger.info(
        "Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
        enqueued, skipped_not_due, skipped_rate_limit,
        len(sources) + len(macro_sources) + len(global_market_sources),
    )

    return enqueued
|
|
|
|
|
|
async def main() -> None:
    """Run the scheduler loop until the process is stopped.

    On every iteration the instance tries to take the ``scheduler_cycle``
    Redis lock with a 30-second TTL. If it gets the lock, it runs one
    scheduling pass and then releases the lock. Counting only those locked
    passes, it also:

    - recovers stale parsed documents every 20 passes;
    - runs table retention cleanup every CLEANUP_CYCLE_INTERVAL passes.

    Errors are logged and the loop sleeps SCHEDULER_TICK seconds before the
    next try. The Postgres pool and Redis client are closed when the loop
    exits.

    NOTE(review): if a locked pass takes longer than the 30s TTL, the lock
    can expire while this instance is still working. release_lock's
    unconditional delete may then drop a lock that another instance has
    since acquired.
    """
    cfg = load_config()
    setup_logging("scheduler", level=cfg.log_level, json_output=cfg.json_logs)

    pool = await get_pg_pool(cfg)
    rds = get_redis(cfg)

    logger.info("Scheduler started (tick=%ds)", SCHEDULER_TICK)
    passes_since_recovery = 0
    passes_since_cleanup = 0

    try:
        while True:
            try:
                got_lock = await acquire_lock(rds, "scheduler_cycle", ttl=30)
                if got_lock:
                    try:
                        await schedule_cycle(pool, rds)

                        # 20 passes at a 15s tick is roughly every 5 minutes.
                        passes_since_recovery += 1
                        if passes_since_recovery >= 20:
                            passes_since_recovery = 0
                            await recover_stale_documents(pool, rds)

                        # Retention cleanup, roughly every 25 minutes.
                        passes_since_cleanup += 1
                        if passes_since_cleanup >= CLEANUP_CYCLE_INTERVAL:
                            passes_since_cleanup = 0
                            await cleanup_all_tables(pool)
                    finally:
                        await release_lock(rds, "scheduler_cycle")
            except Exception:
                logger.exception("Scheduler cycle error")

            await asyncio.sleep(SCHEDULER_TICK)
    finally:
        await pool.close()
        await rds.close()
|
|
|
|
|
|
# Minutes a document may remain in "parsed" status before
# recover_stale_documents() treats it as orphaned and re-enqueues it.
STALE_PARSED_THRESHOLD_MINUTES: int = 30
|
|
|
|
|
|
async def recover_stale_documents(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Re-enqueue documents stuck in 'parsed' status for further processing.

    Documents can be orphaned when Redis loses queue entries (pod restart,
    OOM, etc.). This sweep selects documents that:

    - have been in 'parsed' status for longer than
      STALE_PARSED_THRESHOLD_MINUTES, and
    - have no ``global_events`` row referencing them.

    Each selected document is pushed to the macro classification queue if
    its ``document_type`` is ``macro_event``, and to the extraction queue
    otherwise. Its ``updated_at`` is then bumped so it is not picked up again
    until the threshold passes once more.

    The LEFT JOIN yields one row per company mention. A document mentioned
    for several tickers is therefore pushed once per ticker, or once with
    ticker "" if it has no mentions. The LIMIT of 100 applies to these rows,
    not to distinct documents.

    Returns:
        The number of queue messages pushed (one per selected row).
    """
    stale_rows = await pool.fetch(
        """SELECT d.id, d.document_type, dcm.ticker
           FROM documents d
           LEFT JOIN document_company_mentions dcm ON d.id = dcm.document_id
           WHERE d.status = 'parsed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
             AND NOT EXISTS (
                 SELECT 1 FROM global_events ge WHERE ge.source_document_id = d.id
             )
           ORDER BY d.created_at ASC
           LIMIT 100""",
        STALE_PARSED_THRESHOLD_MINUTES,
    )
    if not stale_rows:
        return 0

    touched_ids = []
    for row in stale_rows:
        queue_name = (
            QUEUE_MACRO_CLASSIFICATION
            if row["document_type"] == "macro_event"
            else QUEUE_EXTRACTION
        )
        message = {"document_id": str(row["id"]), "ticker": row["ticker"] or ""}
        await rds.rpush(queue_key(queue_name), json.dumps(message))
        touched_ids.append(row["id"])

    # Bump updated_at so these docs are not re-enqueued until the threshold
    # passes again. touched_ids is non-empty here because stale_rows was.
    await pool.execute(
        "UPDATE documents SET updated_at = NOW() WHERE id = ANY($1::uuid[])",
        touched_ids,
    )

    recovered = len(touched_ids)
    # The log wording says "for extraction" even when some rows were routed
    # to the macro classification queue.
    logger.info("Recovered %d stale parsed documents for extraction", recovered)
    return recovered
|
|
|
|
|
|
# Number of locked scheduler passes between retention cleanups
# (cleanup_all_tables). 100 passes at a 15s tick is roughly 25 minutes.
CLEANUP_CYCLE_INTERVAL: int = 100

# Retention window for competitive signal records, in days.
COMPETITIVE_SIGNAL_RETENTION_DAYS: int = 30
|
|
|
|
|
|
async def cleanup_old_signals(pool: asyncpg.Pool) -> int:
    """Delete competitive signal records older than the retention window.

    Rows whose ``computed_at`` is older than COMPETITIVE_SIGNAL_RETENTION_DAYS
    are removed, which keeps competitive_signal_records from growing without
    bound.

    Returns:
        The number of rows deleted.
    """
    status = await pool.execute(
        "DELETE FROM competitive_signal_records WHERE computed_at < NOW() - INTERVAL '1 day' * $1",
        COMPETITIVE_SIGNAL_RETENTION_DAYS,
    )

    # The status string looks like "DELETE 1234"; the last token is the count.
    deleted = 0
    if status:
        deleted = int(status.split()[-1])

    if deleted > 0:
        logger.info("Cleaned up %d old competitive signal records", deleted)
    return deleted
|
|
|
|
|
|
async def cleanup_all_tables(pool: asyncpg.Pool) -> None:
    """Run retention cleanup across all tables that grow unbounded.

    Called periodically by the scheduler, every CLEANUP_CYCLE_INTERVAL locked
    cycles. Each table has its own time column and retention window.

    Failures never abort the sweep, and the remaining tables are still
    processed:

    - A missing table or time column is expected in some deployments, so it
      is logged at DEBUG and skipped.
    - Any other error is logged as a warning with its traceback, instead of
      being silently swallowed.
    """
    cleanups = [
        # (table, time_column, retention_days, description)
        ("competitive_signal_records", "computed_at", COMPETITIVE_SIGNAL_RETENTION_DAYS, "competitive signals"),
        ("ingestion_runs", "started_at", 14, "ingestion runs"),
        ("trading_decisions", "created_at", 90, "trading decisions"),
        ("risk_evaluations", "evaluated_at", 30, "risk evaluations"),
        ("audit_events", "created_at", 90, "audit events"),
        ("macro_impact_records", "computed_at", 60, "macro impact records"),
        ("recommendation_evidence", "created_at", 60, "recommendation evidence"),
        ("recommendations", "generated_at", 30, "old recommendations"),
        ("order_events", "created_at", 90, "order events"),
        ("model_performance_metrics", "recorded_at", 30, "model metrics"),
    ]

    total = 0
    for table, col, days, desc in cleanups:
        try:
            # Table and column names come from the fixed list above, never
            # from external input, so interpolating them is safe here.
            result = await pool.execute(
                f"DELETE FROM {table} WHERE {col} < NOW() - INTERVAL '1 day' * $1",  # noqa: S608
                days,
            )
        except (
            asyncpg.exceptions.UndefinedTableError,
            asyncpg.exceptions.UndefinedColumnError,
        ):
            logger.debug("Skipping cleanup of %s: table or column %s missing", table, col)
            continue
        except Exception:
            logger.warning("Cleanup of %s failed; skipping", table, exc_info=True)
            continue

        # The status string looks like "DELETE 1234"; the last token is the count.
        count = int(result.split()[-1]) if result else 0
        if count > 0:
            total += count
            logger.info("Cleaned %d %s (>%dd old)", count, desc, days)

    if total > 0:
        logger.info("Total cleanup: %d rows deleted across all tables", total)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async scheduler loop (main) on a new event loop.
    asyncio.run(main())
|