Files
stonks-oracle/services/scheduler/app.py
T
Celes Renata de35279269 feat: retry failed extractions button on pipeline page
- POST /api/ops/pipeline/retry-failed endpoint resets extraction_failed
  docs to parsed, deletes failed intelligence rows, and re-enqueues
  them (batch of 200)
- Scheduler now auto-retries extraction_failed docs every ~10 minutes
  (100 per cycle, 60-min cooldown per doc)
- Pipeline page shows 'Retry Failed (N)' button when extraction_failed
  count > 0, with pending/success/error states
2026-04-20 08:09:29 +00:00

721 lines
24 KiB
Python

"""Scheduler - triggers ingestion cycles for tracked symbols and sources.
Polls the symbol registry for active companies and their configured sources,
respects per-source polling cadences and backoff windows, coordinates rate
limits across source types, and enqueues ingestion jobs for downstream workers.
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Any, Optional
import asyncpg
import redis.asyncio as aioredis
from services.shared.config import load_config
from services.shared.db import get_pg_pool, get_redis
from services.shared.logging import setup_logging
from services.shared.redis_keys import (
QUEUE_EXTRACTION,
QUEUE_INGESTION,
QUEUE_MACRO_CLASSIFICATION,
lock_key,
queue_key,
rate_limit_key,
)
# Named logger shared by every function in this module.
logger = logging.getLogger("scheduler")
def _ensure_dict(val: Any) -> Optional[dict]:
"""Coerce a JSONB value (dict or JSON string) to a Python dict."""
if val is None:
return None
if isinstance(val, dict):
return val
if isinstance(val, str):
try:
parsed = json.loads(val)
return parsed if isinstance(parsed, dict) else None
except (json.JSONDecodeError, TypeError):
return None
return None
# --- Polling cadence -------------------------------------------------------
# Seconds between polls for each source class. get_cadence_for_source lets a
# source override this via config["polling_interval_seconds"].
DEFAULT_CADENCES: dict[str, int] = dict(
    market_api=300,
    news_api=300,
    filings_api=3600,
    web_scrape=1800,
    broker=30,
    macro_news=600,
)

# --- Rate limiting ---------------------------------------------------------
# Per-source-type request budget, in requests per minute.
DEFAULT_RATE_LIMITS: dict[str, int] = dict(
    market_api=20,
    news_api=20,
    filings_api=10,
    web_scrape=10,
    broker=60,
    macro_news=10,
)

# market_api and news_api share one Polygon API key, so check_rate_limit also
# caps their combined throughput (requests per minute) below the plan limit.
# With the per-type defaults above (20 + 20 = 40/min) this shared cap of 45
# only binds if those per-type limits are raised or overridden.
POLYGON_SOURCE_TYPES: set[str] = {"news_api", "market_api"}
POLYGON_GLOBAL_RATE_LIMIT: int = 45

# --- Failure backoff -------------------------------------------------------
# compute_backoff doubles DEFAULT_BACKOFF_BASE seconds per retry, capped at
# MAX_BACKOFF; is_source_due stops retrying a failed source once its
# retry_count reaches MAX_RETRY_COUNT (it then needs a manual reset).
DEFAULT_BACKOFF_BASE: int = 60
MAX_BACKOFF: int = 3600
MAX_RETRY_COUNT: int = 10

# --- Main loop -------------------------------------------------------------
# Seconds main() sleeps between scheduling passes.
SCHEDULER_TICK: int = 15
def get_cadence_for_source(source_type: str, config: Optional[dict[str, Any]]) -> int:
    """Return how many seconds should pass between polls of a source.

    A non-empty *config* containing ``polling_interval_seconds`` wins. Its
    value is coerced with ``int()`` and floored at 10 seconds. If the key is
    absent, or its value cannot be coerced, the default for *source_type*
    from DEFAULT_CADENCES is used (600 for unknown types).
    """
    if not config or "polling_interval_seconds" not in config:
        return DEFAULT_CADENCES.get(source_type, 600)
    try:
        requested = int(config["polling_interval_seconds"])
    except (ValueError, TypeError):
        # Unusable override: behave as if none was given.
        return DEFAULT_CADENCES.get(source_type, 600)
    # Never poll faster than every 10 seconds.
    return max(10, requested)
def compute_backoff(retry_count: int) -> int:
    """Return the number of seconds to wait before the next retry.

    The delay is DEFAULT_BACKOFF_BASE doubled once per retry. The exponent
    stops growing after 8 doublings, and the result is clamped to
    MAX_BACKOFF.
    """
    doublings = min(retry_count, 8)
    uncapped_delay = DEFAULT_BACKOFF_BASE * (2 ** doublings)
    return min(uncapped_delay, MAX_BACKOFF)
def _as_utc(moment: datetime) -> datetime:
    """Return *moment* as an aware UTC datetime.

    Naive values are taken to be UTC. Aware values are converted (not merely
    relabelled), so timestamps carrying non-UTC offsets compare correctly.
    """
    if moment.tzinfo is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment.astimezone(timezone.utc)


def is_source_due(
    source_type: str,
    source_config: Optional[dict[str, Any]],
    last_completed_at: Optional[datetime],
    last_status: Optional[str],
    retry_count: int,
    next_retry_at: Optional[datetime],
    now: datetime,
) -> bool:
    """Decide whether a source should be polled at *now*.

    Rules, in order:
    - Never run before (no completion time and no status): due.
    - Last run ``failed``:
      - not due once ``retry_count`` reaches MAX_RETRY_COUNT (needs a
        manual reset);
      - not due while ``next_retry_at`` is still in the future;
      - otherwise due.
    - Last run ``running``: not due (avoid double-scheduling).
    - Otherwise: due once at least get_cadence_for_source(...) seconds
      have elapsed since *last_completed_at* (due immediately if it is
      None).

    Datetimes may be naive or aware. Naive values are treated as UTC, and
    aware values are converted to UTC before comparing.
    """
    # Never run before — always due
    if last_completed_at is None and last_status is None:
        return True

    if last_status == "failed":
        if retry_count >= MAX_RETRY_COUNT:
            return False
        # Backoff elapsed, or no next_retry_at recorded — allow retry
        return next_retry_at is None or _as_utc(now) >= _as_utc(next_retry_at)

    # NOTE(review): a run that never leaves 'running' (e.g. its worker died)
    # keeps this source from ever being scheduled again; nothing here
    # times that state out.
    if last_status == "running":
        return False

    if last_completed_at is None:
        return True

    cadence = get_cadence_for_source(source_type, source_config)
    elapsed = (_as_utc(now) - _as_utc(last_completed_at)).total_seconds()
    return elapsed >= cadence
def build_job_payload(
    source: Any,
    aliases: list[str],
    now: datetime,
) -> dict[str, Any]:
    """Build the JSON-serializable ingestion job payload for a source row.

    *source* is a row from one of the ``fetch_*_sources`` queries. It must
    support ``[]`` and ``.get()``, as asyncpg Records do. Macro and global
    market rows have no ``ticker``/``legal_name`` columns; those fields
    become "". A missing company_id becomes None, and the config is
    normalized to a dict ({} when absent or invalid).

    ``credibility_score`` falls back to 0.5 only when the column is NULL.
    An explicit 0 is preserved rather than silently promoted to the default.
    """
    raw_score = source["credibility_score"]
    return {
        "source_id": str(source["source_id"]),
        "company_id": str(source["company_id"]) if source.get("company_id") else None,
        "ticker": source.get("ticker") or "",
        "legal_name": source.get("legal_name") or "",
        "aliases": aliases,
        "source_type": source["source_type"],
        "source_name": source["source_name"],
        "config": _ensure_dict(source["config"]) or {},
        "credibility_score": float(raw_score) if raw_score is not None else 0.5,
        "scheduled_at": now.isoformat(),
    }
async def acquire_lock(rds: aioredis.Redis, name: str, ttl: int = 60) -> bool:
    """Try to take the distributed lock *name* for *ttl* seconds.

    Issues ``SET <lock_key(name)> "1" NX EX ttl``. redis-py returns True
    when the key was written and None when NX blocked it because someone
    else holds the lock. The result is coerced so the declared ``bool``
    actually holds.

    NOTE(review): the stored value is the constant "1", and release_lock
    deletes the key unconditionally. If a holder outlives *ttl*, it can
    delete a lock since acquired by another process. A per-holder token
    with compare-and-delete would close that gap.
    """
    return bool(await rds.set(lock_key(name), "1", nx=True, ex=ttl))
async def release_lock(rds: aioredis.Redis, name: str) -> None:
    """Drop the distributed lock *name* by deleting its Redis key.

    The delete is unconditional: it does not check who holds the lock.
    """
    key = lock_key(name)
    await rds.delete(key)
async def check_rate_limit(
    rds: aioredis.Redis,
    source_type: str,
    now: datetime,
    max_per_minute: Optional[int] = None,
) -> bool:
    """Count one request against the current minute and say if it may proceed.

    Uses fixed one-minute windows keyed by ``now`` (YYYYmmddHHMM) and Redis
    counters that expire 120 s after creation. Two budgets are enforced:

    1. Per source type: *max_per_minute*, else DEFAULT_RATE_LIMITS, else 30.
    2. For POLYGON_SOURCE_TYPES only: a shared POLYGON_GLOBAL_RATE_LIMIT.
       When this budget rejects a request, the per-type counter is
       decremented again, because no call will actually be made.

    Returns True if allowed, False if rate-limited.
    """
    per_type_cap = max_per_minute or DEFAULT_RATE_LIMITS.get(source_type, 30)
    minute_bucket = now.strftime("%Y%m%d%H%M")

    type_key = rate_limit_key(source_type, minute_bucket)
    type_hits = await rds.incr(type_key)
    if type_hits == 1:
        await rds.expire(type_key, 120)
    if type_hits > per_type_cap:
        return False

    if source_type not in POLYGON_SOURCE_TYPES:
        return True

    shared_key = rate_limit_key("_polygon_global", minute_bucket)
    shared_hits = await rds.incr(shared_key)
    if shared_hits == 1:
        await rds.expire(shared_key, 120)
    if shared_hits <= POLYGON_GLOBAL_RATE_LIMIT:
        return True

    # Over the shared Polygon budget: undo this request's per-type count.
    await rds.decr(type_key)
    return False
async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return active, non-macro sources that belong to an active company.

    The inner JOIN drops sources with no company (company_id NULL). Each
    row carries source_id, company_id, source_type, source_name, config,
    credibility_score, ticker and legal_name. Rows are ordered by source
    type, then ticker.
    """
    query = """SELECT s.id AS source_id,
              s.company_id,
              s.source_type,
              s.source_name,
              s.config,
              s.credibility_score,
              c.ticker,
              c.legal_name
       FROM sources s
       JOIN companies c ON s.company_id = c.id
       WHERE s.active = TRUE AND c.active = TRUE
         AND s.source_type != 'macro_news'
       ORDER BY s.source_type, c.ticker"""
    active_rows = await pool.fetch(query)
    return active_rows
async def fetch_macro_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return every active source with source_type 'macro_news'.

    Macro sources are not tied to a company (company_id may be NULL), so
    there is no join with companies and the rows have no ticker/legal_name.
    They are scheduled separately from company-specific sources. Rows are
    ordered by source name.

    Requirements: 1.1
    """
    query = """SELECT s.id AS source_id,
              s.company_id,
              s.source_type,
              s.source_name,
              s.config,
              s.credibility_score
       FROM sources s
       WHERE s.active = TRUE AND s.source_type = 'macro_news'
       ORDER BY s.source_name"""
    macro_rows = await pool.fetch(query)
    return macro_rows
async def fetch_global_market_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return active market_api sources that are not tied to any company.

    These (company_id IS NULL) cover endpoints such as grouped daily, which
    fetch data for all tickers in one API call. Rows have no
    ticker/legal_name columns and are ordered by source name.
    """
    query = """SELECT s.id AS source_id,
              s.company_id,
              s.source_type,
              s.source_name,
              s.config,
              s.credibility_score
       FROM sources s
       WHERE s.active = TRUE
         AND s.source_type = 'market_api'
         AND s.company_id IS NULL
       ORDER BY s.source_name"""
    market_rows = await pool.fetch(query)
    return market_rows
async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
    """Return the alias strings recorded for *company_id* (possibly empty)."""
    alias_rows = await pool.fetch(
        "SELECT alias FROM company_aliases WHERE company_id = $1",
        company_id,
    )
    names = [alias_row["alias"] for alias_row in alias_rows]
    return names
async def fetch_last_run(
    pool: asyncpg.Pool, source_id: str
) -> Optional[asyncpg.Record]:
    """Return the newest ingestion run (by started_at) for *source_id*.

    The row has status, started_at, completed_at, retry_count and
    next_retry_at. Returns None when the source has never run.
    """
    query = """SELECT status, started_at, completed_at, retry_count, next_retry_at
       FROM ingestion_runs
       WHERE source_id = $1
       ORDER BY started_at DESC
       LIMIT 1"""
    latest = await pool.fetchrow(query, source_id)
    return latest
async def _is_due_now(
    pool: asyncpg.Pool,
    source_id: Any,
    source_type: str,
    source_config: Optional[dict[str, Any]],
    now: datetime,
) -> bool:
    """Look up a source's latest ingestion run and apply is_source_due to it."""
    last_run = await fetch_last_run(pool, source_id)
    last_completed_at = None
    last_status = None
    retry_count = 0
    next_retry_at = None
    if last_run:
        last_status = last_run["status"]
        # A run without a completion time is timed from when it started.
        last_completed_at = last_run["completed_at"] or last_run["started_at"]
        retry_count = last_run["retry_count"] or 0
        next_retry_at = last_run["next_retry_at"]
    return is_source_due(
        source_type=source_type,
        source_config=source_config,
        last_completed_at=last_completed_at,
        last_status=last_status,
        retry_count=retry_count,
        next_retry_at=next_retry_at,
        now=now,
    )


async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """One scheduling pass: find due sources and enqueue ingestion jobs.

    Three groups are scheduled in order, all against the same ``now``:

    1. company-specific sources, with the company's aliases attached;
    2. macro news sources (Requirement 1.1);
    3. global market sources (company_id NULL). An ``intraday_bars``
       endpoint fans out into one job per active company; any other
       endpoint becomes a single job with ticker ``_MARKET``.
       ``grouped_daily`` defaults to a once-a-day cadence.

    A source is enqueued only if it is due (is_source_due) and passes
    check_rate_limit. Returns the number of jobs enqueued.
    """
    now = datetime.now(tz=timezone.utc)
    ingestion_queue = queue_key(QUEUE_INGESTION)
    enqueued = 0
    skipped_rate_limit = 0
    skipped_not_due = 0

    # --- Company-specific sources ---
    sources = await fetch_active_sources(pool)
    for src in sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])
        if not await _is_due_now(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue
        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for %s, skipping %s/%s",
                source_type, src["ticker"], src["source_name"],
            )
            skipped_rate_limit += 1
            continue
        # Company aliases are shipped with the job for downstream entity matching
        aliases = await fetch_aliases_for_company(pool, src["company_id"])
        job = build_job_payload(src, aliases, now)
        await rds.rpush(ingestion_queue, json.dumps(job))  # type: ignore[misc]
        enqueued += 1
        logger.debug(
            "Enqueued %s job for %s (%s)",
            source_type, src["ticker"], src["source_name"],
        )

    # --- Macro news sources (Requirement 1.1) ---
    macro_sources = await fetch_macro_sources(pool)
    for src in macro_sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])
        if not await _is_due_now(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue
        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for macro_news, skipping %s",
                src["source_name"],
            )
            skipped_rate_limit += 1
            continue
        job = build_job_payload(src, [], now)
        await rds.rpush(ingestion_queue, json.dumps(job))
        enqueued += 1
        logger.debug(
            "Enqueued macro_news job for %s", src["source_name"],
        )

    # --- Global market sources (grouped daily, intraday bars, etc.) ---
    global_market_sources = await fetch_global_market_sources(pool)
    for src in global_market_sources:
        source_type = src["source_type"]
        # `or {}`: a NULL or non-object config must not raise AttributeError
        # on .get() below — that would abort this pass on every tick.
        source_config = _ensure_dict(src["config"]) or {}
        endpoint = source_config.get("endpoint", "")
        # grouped_daily defaults to a once-per-day cadence (86400 s)
        if endpoint == "grouped_daily":
            source_config.setdefault("polling_interval_seconds", 86400)
        if not await _is_due_now(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue
        if not await check_rate_limit(rds, source_type, now):
            skipped_rate_limit += 1
            continue

        job = build_job_payload(src, [], now)
        if endpoint == "intraday_bars":
            # NOTE(review): only one rate-limit token was taken above, but
            # one job per active company is enqueued here, so this fan-out is
            # not bounded by check_rate_limit.
            tickers = await pool.fetch(
                "SELECT ticker FROM companies WHERE active = TRUE"
            )
            for t_row in tickers:
                ticker_job = dict(job)
                ticker_job["ticker"] = t_row["ticker"]
                await rds.rpush(ingestion_queue, json.dumps(ticker_job))
                enqueued += 1
            logger.info("Enqueued %d intraday bar jobs", len(tickers))
        else:
            job["ticker"] = "_MARKET"
            await rds.rpush(ingestion_queue, json.dumps(job))
            enqueued += 1
            logger.info("Enqueued grouped daily market data job")

    logger.info(
        "Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
        enqueued, skipped_not_due, skipped_rate_limit, len(sources) + len(macro_sources) + len(global_market_sources),
    )
    return enqueued
async def main() -> None:
    """Run the scheduler loop until cancelled.

    Every SCHEDULER_TICK seconds (plus the time the work itself takes), the
    process that wins the "scheduler_cycle" Redis lock (30 s TTL) runs one
    scheduling pass. On tick counters it then also runs the periodic
    maintenance jobs:

    * stale parsed-document recovery every 20 lock-holding ticks,
    * extraction-failure retry every 40 lock-holding ticks,
    * table retention cleanup every CLEANUP_CYCLE_INTERVAL lock-holding ticks.

    Each step runs in isolation. An exception in one step is logged and does
    not prevent later steps from running, or their counters from advancing,
    in the same tick. The Postgres pool and Redis client are closed on exit.
    """
    config = load_config()
    setup_logging("scheduler", level=config.log_level, json_output=config.json_logs)
    pool = await get_pg_pool(config)
    rds = get_redis(config)
    logger.info("Scheduler started (tick=%ds)", SCHEDULER_TICK)

    async def run_isolated(label: str, job: Any) -> None:
        # Await one step, logging instead of propagating its failure so a
        # persistently failing step cannot starve the steps after it.
        try:
            await job
        except Exception:
            logger.exception("Scheduler task failed: %s", label)

    recovery_counter = 0
    retry_counter = 0
    cleanup_counter = 0
    try:
        while True:
            try:
                if await acquire_lock(rds, "scheduler_cycle", ttl=30):
                    try:
                        await run_isolated("schedule_cycle", schedule_cycle(pool, rds))
                        # Stale document recovery every ~20 ticks (~5 minutes)
                        recovery_counter += 1
                        if recovery_counter >= 20:
                            recovery_counter = 0
                            await run_isolated(
                                "recover_stale_documents",
                                recover_stale_documents(pool, rds),
                            )
                        # Extraction-failure retry every ~40 ticks (~10 minutes)
                        retry_counter += 1
                        if retry_counter >= 40:
                            retry_counter = 0
                            await run_isolated(
                                "retry_failed_extractions",
                                retry_failed_extractions(pool, rds),
                            )
                        # Retention cleanup every CLEANUP_CYCLE_INTERVAL ticks (~25 minutes)
                        cleanup_counter += 1
                        if cleanup_counter >= CLEANUP_CYCLE_INTERVAL:
                            cleanup_counter = 0
                            await run_isolated("cleanup_all_tables", cleanup_all_tables(pool))
                    finally:
                        await release_lock(rds, "scheduler_cycle")
            except Exception:
                # Lock/Redis failures end up here; the loop keeps going.
                logger.exception("Scheduler cycle error")
            await asyncio.sleep(SCHEDULER_TICK)
    finally:
        await pool.close()
        await rds.close()
# Minutes a document may sit in 'parsed' (measured by updated_at) before
# recover_stale_documents treats it as orphaned and re-enqueues it.
STALE_PARSED_THRESHOLD_MINUTES: int = 30
# Minimum minutes since an 'extraction_failed' document was last updated
# before retry_failed_extractions resets and re-enqueues it.
EXTRACTION_FAILED_RETRY_MINUTES: int = 60
async def recover_stale_documents(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Re-enqueue documents that look orphaned in 'parsed' status.

    Queue entries can be lost (pod restart, OOM, etc.), leaving a document
    in 'parsed' with nothing left to process it. This sweep selects up to
    100 rows for such documents, oldest ``created_at`` first. A document
    qualifies when its ``updated_at`` is older than
    STALE_PARSED_THRESHOLD_MINUTES and no ``global_events`` row is sourced
    from it. Each row is pushed back onto a queue:

    * ``document_type == 'macro_event'`` -> macro classification queue
    * anything else -> extraction queue

    The selected documents then get ``updated_at = NOW()``, so they are not
    picked again until the threshold passes once more.

    The LEFT JOIN on document_company_mentions yields one row per mentioned
    ticker. A multi-company document is therefore pushed once per ticker,
    and the 100-row limit counts those rows. NOTE(review): confirm the
    extractor expects per-ticker jobs rather than one per document.

    Returns the number of jobs pushed.
    """
    stale_rows = await pool.fetch(
        """SELECT d.id, d.document_type, dcm.ticker
           FROM documents d
           LEFT JOIN document_company_mentions dcm ON d.id = dcm.document_id
           WHERE d.status = 'parsed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
             AND NOT EXISTS (
                 SELECT 1 FROM global_events ge WHERE ge.source_document_id = d.id
             )
           ORDER BY d.created_at ASC
           LIMIT 100""",
        STALE_PARSED_THRESHOLD_MINUTES,
    )
    if not stale_rows:
        return 0

    touched_ids = []
    for record in stale_rows:
        queue_name = (
            QUEUE_MACRO_CLASSIFICATION
            if record["document_type"] == "macro_event"
            else QUEUE_EXTRACTION
        )
        job = {"document_id": str(record["id"]), "ticker": record["ticker"] or ""}
        await rds.rpush(queue_key(queue_name), json.dumps(job))
        touched_ids.append(record["id"])

    # Refresh updated_at so this batch waits a full threshold before it can be swept again.
    if touched_ids:
        await pool.execute(
            "UPDATE documents SET updated_at = NOW() WHERE id = ANY($1::uuid[])",
            touched_ids,
        )

    pushed = len(touched_ids)
    logger.info("Recovered %d stale parsed documents for extraction", pushed)
    return pushed
async def retry_failed_extractions(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Give documents stuck in 'extraction_failed' another extraction attempt.

    Candidates are documents in 'extraction_failed' whose ``updated_at`` is
    at least EXTRACTION_FAILED_RETRY_MINUTES old. At most 100 candidate rows
    are taken per call, oldest ``updated_at`` first. Then:

    1. In one transaction, flip ``status`` back to 'parsed' (touching
       ``updated_at``) and delete their ``document_intelligence`` rows with
       ``validation_status = 'failed'``, so the extractor starts fresh. The
       UPDATE only matches rows still in 'extraction_failed'. A document
       whose status changed since the SELECT (for example, a concurrent
       retry already reset it) is not claimed, and no job is pushed for it.
    2. Only after that commits, push one job per claimed row onto the
       macro classification queue (``document_type == 'macro_event'``) or
       the extraction queue.

    Resetting before enqueueing means a worker that pops a job immediately
    sees a 'parsed' document with no stale failed intelligence row. If the
    push step fails, the documents remain 'parsed', where
    recover_stale_documents can pick them up later.

    As in recover_stale_documents, the LEFT JOIN yields one row per
    mentioned ticker, so a document may be pushed once per ticker.

    Returns the number of jobs pushed.
    """
    rows = await pool.fetch(
        """SELECT d.id, d.document_type, dcm.ticker
           FROM documents d
           LEFT JOIN document_company_mentions dcm ON d.id = dcm.document_id
           WHERE d.status = 'extraction_failed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
           ORDER BY d.updated_at ASC
           LIMIT 100""",
        EXTRACTION_FAILED_RETRY_MINUTES,
    )
    if not rows:
        return 0

    # Ids can repeat (one row per mentioned ticker); the SQL only needs each once.
    candidate_ids = list({row["id"] for row in rows})

    async with pool.acquire() as conn:
        async with conn.transaction():
            claimed = await conn.fetch(
                """UPDATE documents
                   SET status = 'parsed', updated_at = NOW()
                   WHERE id = ANY($1::uuid[])
                     AND status = 'extraction_failed'
                   RETURNING id""",
                candidate_ids,
            )
            claimed_ids = [r["id"] for r in claimed]
            if claimed_ids:
                # Delete failed intelligence rows so the extractor starts fresh
                await conn.execute(
                    """DELETE FROM document_intelligence
                       WHERE document_id = ANY($1::uuid[])
                         AND validation_status = 'failed'""",
                    claimed_ids,
                )

    claimed_set = set(claimed_ids)
    enqueued = 0
    for row in rows:
        if row["id"] not in claimed_set:
            continue
        if row["document_type"] == "macro_event":
            target = queue_key(QUEUE_MACRO_CLASSIFICATION)
        else:
            target = queue_key(QUEUE_EXTRACTION)
        await rds.rpush(target, json.dumps({
            "document_id": str(row["id"]),
            "ticker": row["ticker"] or "",
        }))
        enqueued += 1

    logger.info("Retried %d extraction-failed documents", enqueued)
    return enqueued
# main() runs cleanup_all_tables once every this many lock-holding scheduler
# ticks (~25 minutes at 15 s per tick, plus time spent in each cycle).
CLEANUP_CYCLE_INTERVAL: int = 100
# Retention window (days) that cleanup_old_signals applies to
# competitive_signal_records.
COMPETITIVE_SIGNAL_RETENTION_DAYS: int = 30
async def cleanup_old_signals(pool: asyncpg.Pool) -> int:
    """Delete competitive signal records past the retention window.

    Removes rows of competitive_signal_records whose ``computed_at`` is older
    than COMPETITIVE_SIGNAL_RETENTION_DAYS, so the table cannot grow without
    bound. main() in this module does not call it; cleanup_all_tables
    covers the same table.

    Returns the number of rows deleted.
    """
    status = await pool.execute(
        "DELETE FROM competitive_signal_records WHERE computed_at < NOW() - INTERVAL '1 day' * $1",
        COMPETITIVE_SIGNAL_RETENTION_DAYS,
    )
    # asyncpg returns the command tag, e.g. "DELETE 1234"; the last token is the row count.
    deleted = 0
    if status:
        deleted = int(status.split()[-1])
    if deleted > 0:
        logger.info("Cleaned up %d old competitive signal records", deleted)
    return deleted
async def cleanup_all_tables(pool: asyncpg.Pool) -> None:
    """Apply retention cleanup to every table that would otherwise grow unbounded.

    Each ``(table, timestamp column, retention days)`` entry below is
    processed independently and best-effort. A failure on one table is
    logged, and the loop continues with the next:

    * a missing table or column (schema differences between deployments)
      is logged at DEBUG;
    * any other error (for example a constraint violation) is logged at
      WARNING with its traceback instead of being silently discarded.
    """
    cleanups = [
        # (table, time_column, retention_days, description)
        ("competitive_signal_records", "computed_at", 30, "competitive signals"),
        ("ingestion_runs", "started_at", 14, "ingestion runs"),
        ("trading_decisions", "created_at", 90, "trading decisions"),
        ("risk_evaluations", "evaluated_at", 30, "risk evaluations"),
        ("audit_events", "created_at", 90, "audit events"),
        ("macro_impact_records", "computed_at", 60, "macro impact records"),
        ("recommendation_evidence", "created_at", 60, "recommendation evidence"),
        # NOTE(review): recommendations are kept 30 days but their evidence 60;
        # if recommendation_evidence has a non-cascading FK to recommendations,
        # this delete fails for 30-60 day old rows (now visible in the logs).
        ("recommendations", "generated_at", 30, "old recommendations"),
        ("order_events", "created_at", 90, "order events"),
        ("model_performance_metrics", "recorded_at", 30, "model metrics"),
    ]
    total = 0
    for table, col, days, desc in cleanups:
        try:
            # Table/column names come from the fixed list above, never from input.
            result = await pool.execute(
                f"DELETE FROM {table} WHERE {col} < NOW() - INTERVAL '1 day' * $1",  # noqa: S608
                days,
            )
            # Command tag looks like "DELETE 1234"; the last token is the row count.
            count = int(result.split()[-1]) if result else 0
        except (
            asyncpg.exceptions.UndefinedTableError,
            asyncpg.exceptions.UndefinedColumnError,
        ):
            logger.debug("Skipping cleanup of %s: table or column does not exist", table)
            continue
        except Exception:
            logger.warning("Cleanup of %s (%s) failed", table, desc, exc_info=True)
            continue
        if count > 0:
            total += count
            logger.info("Cleaned %d %s (>%dd old)", count, desc, days)
    if total > 0:
        logger.info("Total cleanup: %d rows deleted across all tables", total)
if __name__ == "__main__":
    # Script entry point: run the scheduler loop until interrupted.
    asyncio.run(main())