Files
stonks-oracle/services/scheduler/app.py
T
Celes Renata bc077bfcc8
ci/woodpecker/push/test Pipeline was successful
ci/woodpecker/push/build-2 Pipeline was successful
ci/woodpecker/push/build-3 Pipeline was successful
ci/woodpecker/push/build-1 Pipeline was successful
ci/woodpecker/push/finalize Pipeline was successful
Build and Push / lint-and-test (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.adapters.broker_adapter name:broker-adapter]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.aggregation.worker name:aggregation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.extractor.worker name:extractor]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.ingestion.worker name:ingestion]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.lake_publisher.worker name:lake-publisher]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.parser.worker name:parser]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.recommendation.worker name:recommendation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.scheduler.app name:scheduler]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.api.app:app --host 0.0.0.0 --port 8000 name:query-api]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.risk.app:app --host 0.0.0.0 --port 8000 name:risk]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.symbol_registry.app:app --host 0.0.0.0 --port 8000 name:symbol-registry]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.trading.app:app --host 0.0.0.0 --port 8000 name:trading-engine]) (push) Has been cancelled
Build and Push / build-dashboard (push) Has been cancelled
Build and Push / build-superset (push) Has been cancelled
Build and Push / integration-test (push) Has been cancelled
Build and Push / beta-gate (push) Has been cancelled
feat: trading feedback engine — periodic performance reports with AI summarization
- Migration 038: trading_reports table + report-summarizer agent seed
- 6 reporting modules: models, collector, sections, validator, summarizer, generator
- API endpoints: GET /api/reports (paginated, filterable), GET /api/reports/{id}
- Frontend hooks: useReports, useReport with TanStack Query
- Scheduler: daily (after 16:30 ET) and weekly (Saturday) report triggers
- Redis queue consumer for async report generation with retry/dedup
- 5 property-based tests (chunking, serialization, validation, accuracy, deltas)
- 109 unit/integration tests across all modules
- 6 frontend hook tests with MSW mocks
2026-05-01 22:13:09 +00:00

973 lines
33 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Scheduler - triggers ingestion cycles for tracked symbols and sources.
Polls the symbol registry for active companies and their configured sources,
respects per-source polling cadences and backoff windows, coordinates rate
limits across source types, and enqueues ingestion jobs for downstream workers.
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""
import asyncio
import json
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Any, Optional
from zoneinfo import ZoneInfo
import asyncpg
import redis.asyncio as aioredis
from services.shared.config import load_config
from services.shared.db import get_pg_pool, get_redis
from services.shared.logging import setup_logging
from services.shared.redis_keys import (
PIPELINE_ENABLED_KEY,
QUEUE_AGGREGATION,
QUEUE_EXTRACTION,
QUEUE_INGESTION,
QUEUE_MACRO_CLASSIFICATION,
QUEUE_PREFIX,
QUEUE_REPORT_GENERATION,
lock_key,
queue_key,
rate_limit_key,
)
# Module-level logger for this service; main() configures logging under the
# same "scheduler" name via setup_logging().
logger = logging.getLogger("scheduler")
def _ensure_dict(val: Any) -> Optional[dict]:
    """Normalize a JSONB-style value into a dict.

    Args:
        val: A dict, a JSON-encoded string, or anything else.

    Returns:
        ``val`` itself when it is already a dict; the decoded object when
        ``val`` is a string containing a JSON object; ``None`` in every other
        case (``None`` input, non-string input, invalid JSON, or JSON that
        decodes to something other than an object).
    """
    if isinstance(val, dict):
        return val
    if not isinstance(val, str):
        # Covers None as well as any unexpected type.
        return None
    try:
        decoded = json.loads(val)
    except (json.JSONDecodeError, TypeError):
        return None
    if isinstance(decoded, dict):
        return decoded
    return None
# Default polling cadences by source class (seconds).
# Individual sources can override via config.polling_interval_seconds
# (get_cadence_for_source applies the override with a 10-second floor and
# falls back to 600 seconds for source types not listed here).
DEFAULT_CADENCES: dict[str, int] = {
    "market_api": 300,
    "news_api": 300,
    "filings_api": 3600,
    "web_scrape": 1800,
    "broker": 30,
    "macro_news": 600,
}
# Default rate limits per source type (requests per minute).
# check_rate_limit falls back to 30/min for source types not listed here.
DEFAULT_RATE_LIMITS: dict[str, int] = {
    "market_api": 20,
    "news_api": 20,
    "filings_api": 10,
    "web_scrape": 10,
    "broker": 60,
    "macro_news": 10,
}
# Global rate limit across all Polygon-backed source types (requests per minute).
# market_api + news_api share a single Polygon API key, so we cap the combined
# throughput to stay safely under the plan limit.
# Note: with the default per-type limits above (20 + 20 = 40) the combined cap
# of 45 can only be reached when a caller passes a larger max_per_minute to
# check_rate_limit.
POLYGON_SOURCE_TYPES: set[str] = {"market_api", "news_api"}
POLYGON_GLOBAL_RATE_LIMIT: int = 45
# Exponential backoff parameters used by compute_backoff (seconds):
# base delay and upper bound for the wait before retrying a failed source.
DEFAULT_BACKOFF_BASE: int = 60
MAX_BACKOFF: int = 3600
# Once the latest failed run has reached this retry count, is_source_due stops
# scheduling the source (it needs a manual reset).
MAX_RETRY_COUNT: int = 10
# Main loop interval (seconds)
SCHEDULER_TICK: int = 15
# Periodic aggregation: re-aggregate all tickers every N scheduler cycles.
# 15s tick × 60 cycles = 15 minutes. main() applies this interval on every
# cycle it runs; it does not check market hours.
AGGREGATION_CYCLE_INTERVAL: int = 60
def get_cadence_for_source(source_type: str, config: Optional[dict[str, Any]]) -> int:
    """Resolve the polling interval, in seconds, for one source.

    Args:
        source_type: Source class used to look up the default cadence.
        config: The source's config mapping, or ``None``.

    Returns:
        ``config["polling_interval_seconds"]`` converted to ``int`` and raised
        to at least 10 when the key is present and convertible; otherwise the
        entry for ``source_type`` in ``DEFAULT_CADENCES`` (600 if the type is
        unknown).
    """
    if not config or "polling_interval_seconds" not in config:
        return DEFAULT_CADENCES.get(source_type, 600)
    try:
        requested = int(config["polling_interval_seconds"])
    except (ValueError, TypeError):
        # Unusable override (e.g. None or a non-numeric string): use the default.
        return DEFAULT_CADENCES.get(source_type, 600)
    # Never poll more often than every 10 seconds, whatever the override says.
    return max(10, requested)
def compute_backoff(retry_count: int) -> int:
    """Return the capped exponential backoff delay, in seconds.

    The delay is ``DEFAULT_BACKOFF_BASE`` doubled once per retry (the exponent
    itself is limited to 8) and never exceeds ``MAX_BACKOFF``.
    """
    exponent = min(retry_count, 8)
    uncapped = DEFAULT_BACKOFF_BASE * (2 ** exponent)
    return min(uncapped, MAX_BACKOFF)
def _align_tz(value: datetime, reference: datetime) -> datetime:
    """Return *value* with the same tz-awareness as *reference*.

    Naive datetimes are treated as UTC. An aware value that must become naive
    is converted to UTC first, so the instant it denotes is preserved rather
    than just having its offset dropped.
    """
    if reference.tzinfo is not None and value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    if reference.tzinfo is None and value.tzinfo is not None:
        return value.astimezone(timezone.utc).replace(tzinfo=None)
    return value


def is_source_due(
    source_type: str,
    source_config: Optional[dict[str, Any]],
    last_completed_at: Optional[datetime],
    last_status: Optional[str],
    retry_count: int,
    next_retry_at: Optional[datetime],
    now: datetime,
) -> bool:
    """Determine whether a source is due for its next polling cycle.

    Rules, in order:
    - A source with no recorded run (no status and no timestamp) is due.
    - After a failed run the source is not due while ``next_retry_at`` lies in
      the future, nor once ``retry_count`` has reached ``MAX_RETRY_COUNT``
      (manual reset needed); otherwise the retry is allowed immediately.
    - A source whose latest run is still ``"running"`` is never due.
    - Otherwise the source is due when at least the cadence returned by
      ``get_cadence_for_source`` has elapsed since ``last_completed_at``.

    Args:
        source_type: Source class, used for the default cadence.
        source_config: Source config mapping (may override the cadence).
        last_completed_at: Timestamp of the latest run, or ``None``.
        last_status: Status of the latest run, or ``None``.
        retry_count: Retry count recorded on the latest run.
        next_retry_at: Earliest time a failed source may be retried, if set.
        now: Current time; naive and aware values are both accepted.

    Returns:
        ``True`` when a new ingestion job should be scheduled.
    """
    # Never run before — always due.
    if last_completed_at is None and last_status is None:
        return True

    if last_status == "failed":
        # Still inside the backoff window.
        if next_retry_at is not None and now < _align_tz(next_retry_at, now):
            return False
        # Backoff elapsed (or none recorded): retry unless the cap is reached.
        return retry_count < MAX_RETRY_COUNT

    # If currently running, don't double-schedule.
    if last_status == "running":
        return False

    # Normal cadence check.
    if last_completed_at is None:
        return True
    cadence = get_cadence_for_source(source_type, source_config)
    elapsed = (now - _align_tz(last_completed_at, now)).total_seconds()
    return elapsed >= cadence
def build_job_payload(
    source: Any,
    aliases: list[str],
    now: datetime,
) -> dict[str, Any]:
    """Build the ingestion job payload for a source.

    Args:
        source: Row-like object supporting ``[]`` and ``.get()``. It must
            provide ``source_id``, ``source_type``, ``source_name``, ``config``
            and ``credibility_score``; ``company_id``, ``ticker`` and
            ``legal_name`` are optional.
        aliases: Company aliases to pass along with the job.
        now: Scheduling timestamp, emitted as ISO-8601 in ``scheduled_at``.

    Returns:
        A JSON-serializable dict describing the job. Missing ``ticker`` /
        ``legal_name`` become empty strings, a missing or unparseable config
        becomes ``{}``, and a NULL credibility score becomes 0.5.
    """
    company_id = source.get("company_id")
    credibility = source["credibility_score"]
    return {
        "source_id": str(source["source_id"]),
        "company_id": str(company_id) if company_id else None,
        "ticker": source.get("ticker") or "",
        "legal_name": source.get("legal_name") or "",
        "aliases": aliases,
        "source_type": source["source_type"],
        "source_name": source["source_name"],
        "config": _ensure_dict(source["config"]) or {},
        # Only a missing score gets the neutral default; an explicit score of
        # 0 is a real value and must not be replaced by 0.5.
        "credibility_score": float(credibility) if credibility is not None else 0.5,
        "scheduled_at": now.isoformat(),
    }
async def acquire_lock(rds: aioredis.Redis, name: str, ttl: int = 60) -> bool:
    """Try to acquire a distributed lock.

    Sets the key ``lock_key(name)`` only if it does not exist yet, with an
    expiry of ``ttl`` seconds.

    Returns:
        ``True`` if the lock was acquired, ``False`` otherwise.
    """
    # SET ... NX reports failure as a falsy non-bool value; coerce so the
    # result matches the declared return type.
    return bool(await rds.set(lock_key(name), "1", nx=True, ex=ttl))
async def release_lock(rds: aioredis.Redis, name: str) -> None:
    """Release a distributed lock by deleting its key.

    The key is deleted unconditionally; no ownership check is performed.
    """
    key = lock_key(name)
    await rds.delete(key)
async def check_rate_limit(
    rds: aioredis.Redis,
    source_type: str,
    now: datetime,
    max_per_minute: Optional[int] = None,
) -> bool:
    """Count one request against the rate limits and report whether it fits.

    Two per-minute counters are kept in Redis, keyed by the minute of ``now``:
    1. A per-source-type counter, limited to ``max_per_minute`` or, when that
       is not given, the type's entry in ``DEFAULT_RATE_LIMITS`` (30 if absent).
    2. For types in ``POLYGON_SOURCE_TYPES``, a shared counter limited to
       ``POLYGON_GLOBAL_RATE_LIMIT``.

    Returns:
        ``True`` if the request is allowed, ``False`` if it is rate-limited.
    """
    minute_bucket = now.strftime("%Y%m%d%H%M")
    allowed = max_per_minute or DEFAULT_RATE_LIMITS.get(source_type, 30)

    # Per-source-type counter. The expiry is set when the counter is created.
    type_key = rate_limit_key(source_type, minute_bucket)
    type_hits = await rds.incr(type_key)
    if type_hits == 1:
        await rds.expire(type_key, 120)
    if type_hits > allowed:
        return False

    if source_type not in POLYGON_SOURCE_TYPES:
        return True

    # Shared counter for source types that use the same Polygon API key.
    shared_key = rate_limit_key("_polygon_global", minute_bucket)
    shared_hits = await rds.incr(shared_key)
    if shared_hits == 1:
        await rds.expire(shared_key, 120)
    if shared_hits <= POLYGON_GLOBAL_RATE_LIMIT:
        return True

    # The call will not be made, so give the per-type slot back.
    await rds.decr(type_key)
    return False
async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return active, company-bound sources together with ticker and legal name.

    Only sources whose company is also active are returned, and ``macro_news``
    sources are excluded. Rows are ordered by source type, then ticker.
    """
    query = """SELECT s.id AS source_id,
                  s.company_id,
                  s.source_type,
                  s.source_name,
                  s.config,
                  s.credibility_score,
                  c.ticker,
                  c.legal_name
           FROM sources s
           JOIN companies c ON s.company_id = c.id
           WHERE s.active = TRUE AND c.active = TRUE
             AND s.source_type != 'macro_news'
           ORDER BY s.source_type, c.ticker"""
    records = await pool.fetch(query)
    return records
async def fetch_macro_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return every active source with ``source_type = 'macro_news'``.

    These sources are selected without joining ``companies``, so rows whose
    ``company_id`` is NULL are included. They are scheduled separately from
    the company-specific sources. Rows are ordered by source name.

    Requirements: 1.1
    """
    query = """SELECT s.id AS source_id,
                  s.company_id,
                  s.source_type,
                  s.source_name,
                  s.config,
                  s.credibility_score
           FROM sources s
           WHERE s.active = TRUE AND s.source_type = 'macro_news'
           ORDER BY s.source_name"""
    records = await pool.fetch(query)
    return records
async def fetch_global_market_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return active ``market_api`` sources that belong to no company.

    Selects sources with ``company_id IS NULL``, e.g. endpoints that cover all
    tickers in one call. Rows are ordered by source name.
    """
    query = """SELECT s.id AS source_id,
                  s.company_id,
                  s.source_type,
                  s.source_name,
                  s.config,
                  s.credibility_score
           FROM sources s
           WHERE s.active = TRUE
             AND s.source_type = 'market_api'
             AND s.company_id IS NULL
           ORDER BY s.source_name"""
    records = await pool.fetch(query)
    return records
async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
    """Return the ``alias`` values stored in ``company_aliases`` for one company."""
    query = "SELECT alias FROM company_aliases WHERE company_id = $1"
    alias_rows = await pool.fetch(query, company_id)
    return [alias_row["alias"] for alias_row in alias_rows]
async def fetch_last_run(
    pool: asyncpg.Pool, source_id: str
) -> Optional[asyncpg.Record]:
    """Return the ingestion run with the latest ``started_at`` for a source.

    Returns:
        A record with ``status``, ``started_at``, ``completed_at``,
        ``retry_count`` and ``next_retry_at``, or ``None`` if the source has
        no recorded runs.
    """
    query = """SELECT status, started_at, completed_at, retry_count, next_retry_at
           FROM ingestion_runs
           WHERE source_id = $1
           ORDER BY started_at DESC
           LIMIT 1"""
    latest = await pool.fetchrow(query, source_id)
    return latest
async def _source_due_now(
    pool: asyncpg.Pool,
    src: Any,
    source_config: Optional[dict[str, Any]],
    now: datetime,
) -> bool:
    """Load the latest ingestion run for *src* and apply is_source_due()."""
    last_run = await fetch_last_run(pool, src["source_id"])
    last_status = None
    last_completed_at = None
    retry_count = 0
    next_retry_at = None
    if last_run:
        last_status = last_run["status"]
        # A run without completed_at falls back to its start time.
        last_completed_at = last_run["completed_at"] or last_run["started_at"]
        retry_count = last_run["retry_count"] or 0
        next_retry_at = last_run["next_retry_at"]
    return is_source_due(
        source_type=src["source_type"],
        source_config=source_config,
        last_completed_at=last_completed_at,
        last_status=last_status,
        retry_count=retry_count,
        next_retry_at=next_retry_at,
        now=now,
    )


async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """One scheduling pass: find due sources and enqueue ingestion jobs.

    Three groups of sources are processed in order: company-specific sources,
    macro news sources, and global (company-less) market sources. For each
    source the latest ingestion run decides whether it is due; due sources
    are then checked against the rate limits and, if allowed, pushed onto the
    ingestion queue as JSON job payloads.

    Args:
        pool: PostgreSQL connection pool.
        rds: Redis client used for rate-limit counters and the job queue.

    Returns:
        The number of jobs enqueued.
    """
    now = datetime.now(tz=timezone.utc)
    ingestion_queue = queue_key(QUEUE_INGESTION)
    enqueued = 0
    skipped_rate_limit = 0
    skipped_not_due = 0

    # --- Company-specific sources ---
    sources = await fetch_active_sources(pool)
    for src in sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])
        if not await _source_due_now(pool, src, source_config, now):
            skipped_not_due += 1
            continue
        # Rate limit check
        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for %s, skipping %s/%s",
                source_type, src["ticker"], src["source_name"],
            )
            skipped_rate_limit += 1
            continue
        # Fetch company aliases for downstream entity matching
        aliases = await fetch_aliases_for_company(pool, src["company_id"])
        job = build_job_payload(src, aliases, now)
        await rds.rpush(ingestion_queue, json.dumps(job))  # type: ignore[misc]
        enqueued += 1
        logger.debug(
            "Enqueued %s job for %s (%s)",
            source_type, src["ticker"], src["source_name"],
        )

    # --- Schedule macro news sources (Requirement 1.1) ---
    macro_sources = await fetch_macro_sources(pool)
    for src in macro_sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])
        if not await _source_due_now(pool, src, source_config, now):
            skipped_not_due += 1
            continue
        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for macro_news, skipping %s",
                src["source_name"],
            )
            skipped_rate_limit += 1
            continue
        job = build_job_payload(src, [], now)
        await rds.rpush(ingestion_queue, json.dumps(job))
        enqueued += 1
        logger.debug(
            "Enqueued macro_news job for %s", src["source_name"],
        )

    # --- Schedule global market sources (grouped daily, etc.) ---
    global_market_sources = await fetch_global_market_sources(pool)
    for src in global_market_sources:
        source_type = src["source_type"]
        # A NULL or unparseable config must not abort the whole cycle, so
        # fall back to an empty mapping before reading the endpoint.
        source_config = _ensure_dict(src["config"]) or {}
        # Use a longer cadence for grouped daily (once per day = 86400s)
        endpoint = source_config.get("endpoint", "")
        if endpoint == "grouped_daily":
            source_config.setdefault("polling_interval_seconds", 86400)
        if not await _source_due_now(pool, src, source_config, now):
            skipped_not_due += 1
            continue
        if not await check_rate_limit(rds, source_type, now):
            skipped_rate_limit += 1
            continue
        job = build_job_payload(src, [], now)
        if endpoint == "intraday_bars":
            # Expand intraday source into per-ticker jobs for all active companies
            tickers = await pool.fetch(
                "SELECT ticker FROM companies WHERE active = TRUE"
            )
            for t_row in tickers:
                ticker_job = {**job, "ticker": t_row["ticker"]}
                await rds.rpush(ingestion_queue, json.dumps(ticker_job))
                enqueued += 1
            logger.info("Enqueued %d intraday bar jobs", len(tickers))
        else:
            # Global sources are not tied to a company; use a placeholder ticker.
            job["ticker"] = "_MARKET"
            await rds.rpush(ingestion_queue, json.dumps(job))
            enqueued += 1
            logger.info("Enqueued grouped daily market data job")

    total_sources = len(sources) + len(macro_sources) + len(global_market_sources)
    logger.info(
        "Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
        enqueued, skipped_not_due, skipped_rate_limit, total_sources,
    )
    return enqueued
# ---------------------------------------------------------------------------
# Report generation: queue consumer + scheduled triggers
# Requirements: 6.1, 6.2, 6.3, 6.4, 6.5
# ---------------------------------------------------------------------------
# Eastern Time zone for market-close checks (used by check_report_schedule)
_ET = ZoneInfo("America/New_York")
# How often to check the report generation queue (every N cycles)
# 15s tick × 4 cycles = ~1 minute
REPORT_CONSUMER_CYCLE_INTERVAL: int = 4
# How often to check report scheduling triggers (every N cycles)
# 15s tick × 20 cycles = ~5 minutes
REPORT_SCHEDULE_CYCLE_INTERVAL: int = 20
# Redis key prefix for report schedule dedup markers; the full key is built
# by _report_dedupe_key from report type and period bounds.
_REPORT_DEDUPE_PREFIX = f"{QUEUE_PREFIX}:report_dedupe"
# Expiry (seconds) of a dedup marker.
_REPORT_DEDUPE_TTL = 86400  # 24 hours — prevents re-enqueuing same report within a day
def _report_dedupe_key(report_type: str, period_start: str, period_end: str) -> str:
    """Return the Redis key that marks one report period as already scheduled.

    The key is ``_REPORT_DEDUPE_PREFIX`` followed by the report type and the
    two period bounds, joined with colons.
    """
    return "{}:{}:{}:{}".format(
        _REPORT_DEDUPE_PREFIX, report_type, period_start, period_end
    )
async def consume_report_generation_jobs(
    pool: asyncpg.Pool,
    rds: aioredis.Redis,
    *,
    max_jobs: int = 5,
) -> int:
    """Pop and process jobs from the report generation queue.

    At most ``max_jobs`` jobs (default 5) are popped per invocation to avoid
    blocking the scheduler loop. Each job is deserialized and handed to
    ``process_report_job`` from the reporting generator module.

    A payload that cannot be decoded, or that decodes to something other than
    a JSON object, is logged and dropped. A job whose processing raises is
    logged with its traceback; it has already been removed from the queue and
    is not pushed back here.

    Args:
        pool: PostgreSQL connection pool passed through to the job handler.
        rds: Redis client the queue is read from.
        max_jobs: Upper bound on the number of queue entries popped.

    Returns:
        The number of jobs processed successfully.
    """
    # Imported lazily, as in the original, so the reporting package is only
    # loaded when report jobs are actually consumed.
    from services.reporting.generator import process_report_job

    report_queue = queue_key(QUEUE_REPORT_GENERATION)
    processed = 0
    for _ in range(max_jobs):
        raw = await rds.lpop(report_queue)
        if raw is None:
            break
        try:
            job = json.loads(raw)
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and undecodable bytes.
            logger.error("Invalid report generation job payload: %s", raw)
            continue
        if not isinstance(job, dict):
            # Valid JSON but not an object: job.get() below would raise and
            # abort the whole consumer run.
            logger.error("Invalid report generation job payload: %s", raw)
            continue
        logger.info(
            "Processing report generation job: type=%s period=%s to %s",
            job.get("report_type"),
            job.get("period_start"),
            job.get("period_end"),
        )
        try:
            await process_report_job(pool, job)
            processed += 1
        except Exception:
            logger.exception(
                "Failed to process report generation job: %s", job,
            )
    if processed > 0:
        logger.info("Processed %d report generation jobs", processed)
    return processed
async def maybe_enqueue_daily_report(
    rds: aioredis.Redis,
    now_et: datetime,
) -> bool:
    """Enqueue today's daily report once it is 16:30 or later on a weekday.

    The report period is the single calendar day of ``now_et``. A Redis
    marker keyed by that period ensures the job is enqueued at most once.

    Args:
        rds: Redis client used for the dedupe marker and the report queue.
        now_et: Current time; the caller passes Eastern Time.

    Returns:
        ``True`` if a job was enqueued, ``False`` otherwise.
    """
    is_weekday = now_et.weekday() <= 4  # Mon=0 .. Fri=4
    past_cutoff = (now_et.hour, now_et.minute) >= (16, 30)
    if not is_weekday or not past_cutoff:
        return False

    day = now_et.date().isoformat()
    period_start = day
    period_end = day

    # The marker is created atomically; a falsy result means it already exists.
    marker = _report_dedupe_key("daily", period_start, period_end)
    if not await rds.set(marker, "1", nx=True, ex=_REPORT_DEDUPE_TTL):
        return False

    payload = {
        "report_type": "daily",
        "period_start": period_start,
        "period_end": period_end,
    }
    await rds.rpush(queue_key(QUEUE_REPORT_GENERATION), json.dumps(payload))
    logger.info("Enqueued daily report for %s", period_start)
    return True
async def maybe_enqueue_weekly_report(
    rds: aioredis.Redis,
    now_et: datetime,
) -> bool:
    """Enqueue the weekly report when ``now_et`` falls on a Saturday.

    The report period runs from five days before that Saturday (Monday) to
    one day before it (Friday). A Redis marker keyed by that period ensures
    the job is enqueued at most once.

    Args:
        rds: Redis client used for the dedupe marker and the report queue.
        now_et: Current time; the caller passes Eastern Time.

    Returns:
        ``True`` if a job was enqueued, ``False`` otherwise.
    """
    if now_et.weekday() != 5:  # 5 == Saturday
        return False

    saturday = now_et.date()
    monday = saturday - timedelta(days=5)
    friday = saturday - timedelta(days=1)
    period_start = monday.isoformat()
    period_end = friday.isoformat()

    # The marker is created atomically; a falsy result means it already exists.
    marker = _report_dedupe_key("weekly", period_start, period_end)
    if not await rds.set(marker, "1", nx=True, ex=_REPORT_DEDUPE_TTL):
        return False

    payload = {
        "report_type": "weekly",
        "period_start": period_start,
        "period_end": period_end,
    }
    await rds.rpush(queue_key(QUEUE_REPORT_GENERATION), json.dumps(payload))
    logger.info(
        "Enqueued weekly report for %s to %s", period_start, period_end,
    )
    return True
async def check_report_schedule(rds: aioredis.Redis) -> None:
    """Run the daily and weekly report triggers against the current time.

    The current time is taken in the ``_ET`` zone (America/New_York) and
    passed to the daily trigger first, then to the weekly trigger. Each
    trigger decides on its own whether a job should be enqueued.
    """
    eastern_now = datetime.now(tz=_ET)
    for trigger in (maybe_enqueue_daily_report, maybe_enqueue_weekly_report):
        await trigger(rds, eastern_now)
async def enqueue_periodic_aggregation(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Enqueue one aggregation job per active ticker.

    Keeps trend data fresh even when no new documents are being ingested.
    Every active company's ticker is pushed onto the aggregation queue as
    ``{"ticker": ...}``, in ticker order. How often this runs is decided by
    the caller; this function applies no time-of-day logic itself.

    Returns:
        The number of tickers enqueued (0 when there are no active companies).
    """
    active = await pool.fetch(
        "SELECT ticker FROM companies WHERE active = TRUE ORDER BY ticker"
    )
    if not active:
        return 0

    target_queue = queue_key(QUEUE_AGGREGATION)
    for record in active:
        payload = json.dumps({"ticker": record["ticker"]})
        await rds.rpush(target_queue, payload)

    total = len(active)
    logger.info("Periodic aggregation: enqueued %d tickers for re-aggregation", total)
    return total
async def main() -> None:
    """Run the scheduler loop until the process is stopped.

    Each tick (``SCHEDULER_TICK`` seconds):
    - If the pipeline toggle key holds "0", the tick is skipped.
    - Otherwise the "scheduler_cycle" lock is attempted; when acquired, one
      ``schedule_cycle`` pass runs, followed by the periodic maintenance
      tasks whose cycle counters have reached their interval.

    Any exception raised during a tick is logged and the loop continues.
    The database pool and the Redis client are closed when the loop exits.
    """
    config = load_config()
    setup_logging("scheduler", level=config.log_level, json_output=config.json_logs)
    pool = await get_pg_pool(config)
    rds = get_redis(config)
    logger.info("Scheduler started (tick=%ds)", SCHEDULER_TICK)
    pipeline_key = PIPELINE_ENABLED_KEY

    # If PIPELINE_DEFAULT_OFF is set, initialize the toggle to OFF on first boot
    # (only if the key doesn't already exist — preserves manual overrides)
    if os.getenv("PIPELINE_DEFAULT_OFF", "").lower() in ("1", "true", "yes"):
        created = await rds.set(pipeline_key, "0", nx=True)
        if created:
            logger.info("Pipeline toggle initialized to OFF (PIPELINE_DEFAULT_OFF=true)")

    # (interval in cycles, task) — evaluated in this order after every
    # successful schedule_cycle. Counters only advance on cycles where this
    # instance held the lock.
    periodic_tasks = (
        # Stale document recovery every ~20 cycles (~5 minutes)
        (20, lambda: recover_stale_documents(pool, rds)),
        # Retry extraction failures every ~40 cycles (~10 minutes)
        (40, lambda: retry_failed_extractions(pool, rds)),
        # Retention cleanup (~25 minutes)
        (CLEANUP_CYCLE_INTERVAL, lambda: cleanup_all_tables(pool)),
        # Periodic aggregation (~15 minutes)
        (AGGREGATION_CYCLE_INTERVAL, lambda: enqueue_periodic_aggregation(pool, rds)),
        # Consume report generation jobs (~1 minute)
        (REPORT_CONSUMER_CYCLE_INTERVAL, lambda: consume_report_generation_jobs(pool, rds)),
        # Check report schedule triggers (~5 minutes)
        (REPORT_SCHEDULE_CYCLE_INTERVAL, lambda: check_report_schedule(rds)),
    )
    counters = [0] * len(periodic_tasks)

    try:
        while True:
            try:
                # Check pipeline toggle — skip cycle if disabled
                flag = await rds.get(pipeline_key)
                if isinstance(flag, bytes):
                    # The client may return raw bytes when it is not configured
                    # to decode responses; normalize so "0" is still recognized.
                    flag = flag.decode("utf-8", "replace")
                if flag == "0":
                    await asyncio.sleep(SCHEDULER_TICK)
                    continue
                # NOTE(review): the lock TTL is 30s and release_lock deletes the
                # key unconditionally; a cycle that runs longer than the TTL
                # could overlap with another instance — confirm this is acceptable.
                if await acquire_lock(rds, "scheduler_cycle", ttl=30):
                    try:
                        await schedule_cycle(pool, rds)
                        for idx, (interval, run_task) in enumerate(periodic_tasks):
                            counters[idx] += 1
                            if counters[idx] >= interval:
                                # Reset before running, so a failing task waits
                                # a full interval before its next attempt.
                                counters[idx] = 0
                                await run_task()
                    finally:
                        await release_lock(rds, "scheduler_cycle")
            except Exception:
                logger.exception("Scheduler cycle error")
            await asyncio.sleep(SCHEDULER_TICK)
    finally:
        # Close Redis even if closing the pool fails.
        try:
            await pool.close()
        finally:
            await rds.close()
# How long a document can sit in "parsed" before we consider it orphaned
# Must be longer than the expected queue drain time to avoid re-enqueuing
# docs that are already queued but not yet processed.
STALE_PARSED_THRESHOLD_MINUTES: int = 240
# How long after an extraction failure before we retry
EXTRACTION_FAILED_RETRY_MINUTES: int = 60
# Key prefix for per-document "already enqueued" markers (prevents duplicate
# enqueuing); _enqueue_if_new appends ":<document_id>" to form each key.
_ENQUEUED_SET = f"{QUEUE_PREFIX}:enqueued"
# How long an enqueued marker lives before the document can be re-enqueued (seconds)
_ENQUEUED_TTL = 14400  # 4 hours — matches STALE_PARSED_THRESHOLD_MINUTES
async def _enqueue_if_new(
    rds: aioredis.Redis,
    queue: str,
    document_id: str,
    ticker: str,
) -> bool:
    """Push a job onto *queue* unless *document_id* already has a marker.

    A per-document marker key (``_ENQUEUED_SET`` plus the document id) is
    created only if it does not exist yet, with an expiry of ``_ENQUEUED_TTL``
    seconds. Only when the marker was newly created is the job pushed, so
    repeated recovery sweeps do not enqueue the same document again while the
    marker is alive.

    Returns:
        ``True`` if the job was enqueued, ``False`` if it was skipped as a
        duplicate.
    """
    marker_key = f"{_ENQUEUED_SET}:{document_id}"
    claimed = await rds.set(marker_key, "1", nx=True, ex=_ENQUEUED_TTL)
    if claimed:
        payload = json.dumps({
            "document_id": document_id,
            "ticker": ticker,
        })
        await rds.rpush(queue, payload)
        return True
    return False
async def recover_stale_documents(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Re-enqueue documents that have been stuck in 'parsed' status.

    Selects up to 100 rows (oldest first) for documents in 'parsed' status
    whose ``updated_at`` is older than ``STALE_PARSED_THRESHOLD_MINUTES`` and
    that have no row in ``global_events``. Documents of type ``macro_event``
    go to the macro classification queue, all others to the extraction
    queue. Documents that were actually enqueued get ``updated_at`` bumped so
    they are not picked up again until the threshold passes once more.

    Returns:
        The number of documents re-enqueued.
    """
    # NOTE(review): the LEFT JOIN can yield several rows per document (one per
    # company mention); the enqueue marker lets only the first one through.
    query = """SELECT d.id, d.document_type, dcm.ticker
           FROM documents d
           LEFT JOIN document_company_mentions dcm ON d.id = dcm.document_id
           WHERE d.status = 'parsed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
             AND NOT EXISTS (
                 SELECT 1 FROM global_events ge WHERE ge.source_document_id = d.id
             )
           ORDER BY d.created_at ASC
           LIMIT 100"""
    stale_rows = await pool.fetch(query, STALE_PARSED_THRESHOLD_MINUTES)
    if not stale_rows:
        return 0

    requeued_ids = []
    for record in stale_rows:
        is_macro = record["document_type"] == "macro_event"
        target = queue_key(QUEUE_MACRO_CLASSIFICATION if is_macro else QUEUE_EXTRACTION)
        if await _enqueue_if_new(rds, target, str(record["id"]), record["ticker"] or ""):
            requeued_ids.append(record["id"])

    # Touch updated_at so these docs won't be re-enqueued until the threshold passes again
    if requeued_ids:
        await pool.execute(
            "UPDATE documents SET updated_at = NOW() WHERE id = ANY($1::uuid[])",
            requeued_ids,
        )

    enqueued = len(requeued_ids)
    logger.info("Recovered %d stale parsed documents for extraction", enqueued)
    return enqueued
async def retry_failed_extractions(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Re-enqueue documents in 'extraction_failed' status for another attempt.

    Selects up to 100 rows (least recently updated first) for documents whose
    ``updated_at`` is older than ``EXTRACTION_FAILED_RETRY_MINUTES``, which
    avoids tight retry loops. Each document is pushed onto the macro
    classification queue (type ``macro_event``) or the extraction queue
    (everything else). For the documents actually enqueued, the failed
    ``document_intelligence`` rows are then deleted and the document status
    is reset to 'parsed' with a fresh ``updated_at``.

    Returns:
        The number of documents re-enqueued.
    """
    # NOTE(review): jobs are pushed before the database state is reset, and
    # the DELETE and UPDATE run as separate statements — confirm consumers
    # tolerate seeing the job before the reset lands.
    query = """SELECT d.id, d.document_type, dcm.ticker
           FROM documents d
           LEFT JOIN document_company_mentions dcm ON d.id = dcm.document_id
           WHERE d.status = 'extraction_failed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
           ORDER BY d.updated_at ASC
           LIMIT 100"""
    failed_rows = await pool.fetch(query, EXTRACTION_FAILED_RETRY_MINUTES)
    if not failed_rows:
        return 0

    retried_ids = []
    for record in failed_rows:
        is_macro = record["document_type"] == "macro_event"
        target = queue_key(QUEUE_MACRO_CLASSIFICATION if is_macro else QUEUE_EXTRACTION)
        if await _enqueue_if_new(rds, target, str(record["id"]), record["ticker"] or ""):
            retried_ids.append(record["id"])

    if retried_ids:
        # Delete failed intelligence rows so extractor starts fresh
        await pool.execute(
            """DELETE FROM document_intelligence
               WHERE document_id = ANY($1::uuid[])
                 AND validation_status = 'failed'""",
            retried_ids,
        )
        # Reset status to 'parsed' and touch updated_at
        await pool.execute(
            """UPDATE documents
               SET status = 'parsed', updated_at = NOW()
               WHERE id = ANY($1::uuid[])""",
            retried_ids,
        )

    enqueued = len(retried_ids)
    logger.info("Retried %d extraction-failed documents", enqueued)
    return enqueued
# How often main() runs the retention cleanup (cleanup_all_tables):
# every 100 cycles × 15s tick = ~25 minutes
CLEANUP_CYCLE_INTERVAL: int = 100
# Keep competitive signals for this many days (used by cleanup_old_signals)
COMPETITIVE_SIGNAL_RETENTION_DAYS: int = 30
async def cleanup_old_signals(pool: asyncpg.Pool) -> int:
    """Delete competitive signal records that are past the retention window.

    Removes rows from ``competitive_signal_records`` whose ``computed_at`` is
    older than ``COMPETITIVE_SIGNAL_RETENTION_DAYS`` days, which keeps the
    table from growing without bound.

    Returns:
        The number of rows deleted.
    """
    status = await pool.execute(
        "DELETE FROM competitive_signal_records WHERE computed_at < NOW() - INTERVAL '1 day' * $1",
        COMPETITIVE_SIGNAL_RETENTION_DAYS,
    )
    # The command status looks like "DELETE 1234"; the last token is the row count.
    if status:
        deleted = int(status.split()[-1])
    else:
        deleted = 0
    if deleted > 0:
        logger.info("Cleaned up %d old competitive signal records", deleted)
    return deleted
async def cleanup_all_tables(pool: asyncpg.Pool) -> None:
    """Run retention cleanup across all tables that grow unbounded.

    Called periodically by the scheduler (~every 25 minutes). Each table has
    its own retention window; rows whose time column is older than that
    window are deleted.

    Cleanup is best-effort per table: a failure on one table is logged and
    the remaining tables are still processed.
    """
    cleanups = [
        # (table, time_column, retention_days, description)
        (
            "competitive_signal_records",
            "computed_at",
            COMPETITIVE_SIGNAL_RETENTION_DAYS,
            "competitive signals",
        ),
        ("ingestion_runs", "started_at", 14, "ingestion runs"),
        ("trading_decisions", "created_at", 90, "trading decisions"),
        ("risk_evaluations", "evaluated_at", 30, "risk evaluations"),
        ("audit_events", "created_at", 90, "audit events"),
        ("macro_impact_records", "computed_at", 60, "macro impact records"),
        ("recommendation_evidence", "created_at", 60, "recommendation evidence"),
        ("recommendations", "generated_at", 30, "old recommendations"),
        ("order_events", "created_at", 90, "order events"),
        ("model_performance_metrics", "recorded_at", 30, "model metrics"),
    ]
    total = 0
    for table, col, days, desc in cleanups:
        try:
            # Table and column names come from the fixed list above, never
            # from external input, so interpolating them is safe here.
            result = await pool.execute(
                f"DELETE FROM {table} WHERE {col} < NOW() - INTERVAL '1 day' * $1",  # noqa: S608
                days,
            )
            # The command status looks like "DELETE 1234".
            count = int(result.split()[-1]) if result else 0
            if count > 0:
                total += count
                logger.info("Cleaned %d %s (>%dd old)", count, desc, days)
        except (asyncpg.UndefinedTableError, asyncpg.UndefinedColumnError) as exc:
            # Table might not exist or column name might differ — skip quietly.
            logger.debug("Skipping cleanup of %s: %s", table, exc)
        except Exception:
            # Anything else is unexpected; keep going but leave a trace
            # instead of hiding the failure.
            logger.warning(
                "Cleanup of %s failed; continuing with remaining tables",
                table,
                exc_info=True,
            )
    if total > 0:
        logger.info("Total cleanup: %d rows deleted across all tables", total)
# Script entry point: start the scheduler loop when this module is executed directly.
if __name__ == "__main__":
    asyncio.run(main())