6f54fd07fa
ci/woodpecker/push/test Pipeline was successful
ci/woodpecker/push/build-2 Pipeline was successful
ci/woodpecker/push/build-3 Pipeline was successful
ci/woodpecker/push/build-1 Pipeline was successful
ci/woodpecker/push/finalize Pipeline was successful
Build and Push / lint-and-test (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.adapters.broker_adapter name:broker-adapter]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.aggregation.worker name:aggregation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.extractor.worker name:extractor]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.ingestion.worker name:ingestion]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.lake_publisher.worker name:lake-publisher]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.parser.worker name:parser]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.recommendation.worker name:recommendation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.scheduler.app name:scheduler]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.api.app:app --host 0.0.0.0 --port 8000 name:query-api]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.risk.app:app --host 0.0.0.0 --port 8000 name:risk]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.symbol_registry.app:app --host 0.0.0.0 --port 8000 name:symbol-registry]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.trading.app:app --host 0.0.0.0 --port 8000 name:trading-engine]) (push) Has been cancelled
Build and Push / build-dashboard (push) Has been cancelled
Build and Push / build-superset (push) Has been cancelled
Build and Push / integration-test (push) Has been cancelled
Build and Push / beta-gate (push) Has been cancelled
The aggregation engine only ran when new documents were ingested, which left intraday trend data stale for long periods. The scheduler now enqueues all 50 active tickers for re-aggregation every ~15 minutes during US market hours: Mon–Fri, 6:30 AM–1:30 PM PT, which is the 9:30 AM–4:00 PM ET session plus a 30-minute buffer after the close. This keeps intraday trends updating continuously from existing signals and market price changes.
813 lines
28 KiB
Python
813 lines
28 KiB
Python
"""Scheduler - triggers ingestion cycles for tracked symbols and sources.
|
||
|
||
Polls the symbol registry for active companies and their configured sources,
|
||
respects per-source polling cadences and backoff windows, coordinates rate
|
||
limits across source types, and enqueues ingestion jobs for downstream workers.
|
||
|
||
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
|
||
"""
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
from datetime import datetime, timezone
|
||
from typing import Any, Optional
|
||
|
||
import asyncpg
|
||
import redis.asyncio as aioredis
|
||
|
||
from services.shared.config import load_config
|
||
from services.shared.db import get_pg_pool, get_redis
|
||
from services.shared.logging import setup_logging
|
||
from services.shared.redis_keys import (
|
||
PIPELINE_ENABLED_KEY,
|
||
QUEUE_AGGREGATION,
|
||
QUEUE_EXTRACTION,
|
||
QUEUE_INGESTION,
|
||
QUEUE_MACRO_CLASSIFICATION,
|
||
QUEUE_PREFIX,
|
||
lock_key,
|
||
queue_key,
|
||
rate_limit_key,
|
||
)
|
||
|
||
# Module logger. Records go to the "scheduler" logger; main() configures
# its output via setup_logging("scheduler", ...).
logger = logging.getLogger("scheduler")
|
||
|
||
|
||
def _ensure_dict(val: Any) -> Optional[dict]:
|
||
"""Coerce a JSONB value (dict or JSON string) to a Python dict."""
|
||
if val is None:
|
||
return None
|
||
if isinstance(val, dict):
|
||
return val
|
||
if isinstance(val, str):
|
||
try:
|
||
parsed = json.loads(val)
|
||
return parsed if isinstance(parsed, dict) else None
|
||
except (json.JSONDecodeError, TypeError):
|
||
return None
|
||
return None
|
||
|
||
# Default polling cadences by source class (seconds).
# Individual sources can override via config.polling_interval_seconds;
# get_cadence_for_source() applies that override (floored at 10 s) and
# falls back to 600 s for source types missing from this table.
DEFAULT_CADENCES: dict[str, int] = {
    "market_api": 300,
    "news_api": 300,
    "filings_api": 3600,
    "web_scrape": 1800,
    "broker": 30,
    "macro_news": 600,
}

# Default rate limits per source type (requests per minute).
# check_rate_limit() uses 30/min for source types missing from this table.
DEFAULT_RATE_LIMITS: dict[str, int] = {
    "market_api": 20,
    "news_api": 20,
    "filings_api": 10,
    "web_scrape": 10,
    "broker": 60,
    "macro_news": 10,
}

# Global rate limit across all Polygon-backed source types (requests per minute).
# market_api + news_api share a single Polygon API key, so we cap the combined
# throughput to stay safely under the plan limit.
POLYGON_SOURCE_TYPES: set[str] = {"market_api", "news_api"}
POLYGON_GLOBAL_RATE_LIMIT: int = 45

# Retry policy for failed sources (seconds).
# compute_backoff() doubles DEFAULT_BACKOFF_BASE per retry (exponent capped
# at 8) and clamps the result to MAX_BACKOFF.
DEFAULT_BACKOFF_BASE: int = 60
MAX_BACKOFF: int = 3600
# is_source_due() stops rescheduling a failed source once its retry_count
# reaches this value; the source then needs a manual reset.
MAX_RETRY_COUNT: int = 10

# Main loop interval (seconds): main() sleeps this long between passes.
SCHEDULER_TICK: int = 15

# Periodic aggregation: re-aggregate all tickers every N cycles during market hours
# 15s tick × 60 cycles = 15 minutes
AGGREGATION_CYCLE_INTERVAL: int = 60
|
||
|
||
|
||
def get_cadence_for_source(source_type: str, config: Optional[dict[str, Any]]) -> int:
    """Return how many seconds to wait between polls of a source.

    An integer-convertible ``config["polling_interval_seconds"]`` wins and is
    floored at 10 seconds. A missing key, empty/None config, or an
    unconvertible value falls back to DEFAULT_CADENCES for the source type,
    or 600 seconds for unknown types.
    """
    if not config or "polling_interval_seconds" not in config:
        return DEFAULT_CADENCES.get(source_type, 600)
    try:
        seconds = int(config["polling_interval_seconds"])
    except (ValueError, TypeError):
        return DEFAULT_CADENCES.get(source_type, 600)
    return max(seconds, 10)
|
||
|
||
|
||
def compute_backoff(retry_count: int) -> int:
    """Return the retry delay in seconds for the given retry count.

    The delay is DEFAULT_BACKOFF_BASE * 2**retry_count, with the exponent
    capped at 8 and the result capped at MAX_BACKOFF.
    """
    exponent = retry_count if retry_count < 8 else 8
    return min(DEFAULT_BACKOFF_BASE * 2 ** exponent, MAX_BACKOFF)
|
||
|
||
|
||
def is_source_due(
|
||
source_type: str,
|
||
source_config: Optional[dict[str, Any]],
|
||
last_completed_at: Optional[datetime],
|
||
last_status: Optional[str],
|
||
retry_count: int,
|
||
next_retry_at: Optional[datetime],
|
||
now: datetime,
|
||
) -> bool:
|
||
"""Determine whether a source is due for its next polling cycle.
|
||
|
||
Checks:
|
||
- If the source has never run, it is due.
|
||
- If the last run failed and we have a next_retry_at in the future, skip.
|
||
- If the last run failed and retry_count exceeds max, skip (needs manual reset).
|
||
- Otherwise, check if enough time has elapsed since the last completed run.
|
||
"""
|
||
# Never run before — always due
|
||
if last_completed_at is None and last_status is None:
|
||
return True
|
||
|
||
# If last run failed, respect backoff
|
||
if last_status == "failed":
|
||
if retry_count >= MAX_RETRY_COUNT:
|
||
return False
|
||
if next_retry_at:
|
||
# Normalize tz-awareness to match 'now'
|
||
if now.tzinfo is not None and next_retry_at.tzinfo is None:
|
||
nra = next_retry_at.replace(tzinfo=timezone.utc)
|
||
elif now.tzinfo is None and next_retry_at.tzinfo is not None:
|
||
nra = next_retry_at.replace(tzinfo=None)
|
||
else:
|
||
nra = next_retry_at
|
||
if now < nra:
|
||
return False
|
||
# Backoff elapsed or no next_retry_at set — allow retry
|
||
return True
|
||
|
||
# If currently running, don't double-schedule
|
||
if last_status == "running":
|
||
return False
|
||
|
||
# Normal cadence check
|
||
if last_completed_at is None:
|
||
return True
|
||
|
||
cadence = get_cadence_for_source(source_type, source_config)
|
||
# Ensure both datetimes have matching tz-awareness for subtraction
|
||
if now.tzinfo is not None and last_completed_at.tzinfo is None:
|
||
last_completed_at = last_completed_at.replace(tzinfo=timezone.utc)
|
||
elif now.tzinfo is None and last_completed_at.tzinfo is not None:
|
||
last_completed_at = last_completed_at.replace(tzinfo=None)
|
||
elapsed = (now - last_completed_at).total_seconds()
|
||
return elapsed >= cadence
|
||
|
||
|
||
def build_job_payload(
    source: Any,
    aliases: list[str],
    now: datetime,
) -> dict[str, Any]:
    """Build the JSON-serializable ingestion job payload for a source.

    Args:
        source: A source row (mapping-like with ``[]`` and ``.get``) with
            source_id, source_type, source_name, config, credibility_score
            and optionally company_id, ticker and legal_name.
        aliases: Company aliases for downstream entity matching.
        now: Scheduling timestamp, emitted in ISO-8601 form.

    Returns:
        Dict with string IDs, the parsed config ({} when missing or invalid),
        and a float credibility score.
    """
    company_id = source.get("company_id")
    credibility = source["credibility_score"]
    return {
        "source_id": str(source["source_id"]),
        "company_id": str(company_id) if company_id else None,
        "ticker": source.get("ticker") or "",
        "legal_name": source.get("legal_name") or "",
        "aliases": aliases,
        "source_type": source["source_type"],
        "source_name": source["source_name"],
        "config": _ensure_dict(source["config"]) or {},
        # Only a missing score defaults to 0.5. An explicit 0 is a real
        # score and must not be replaced, which the former truthiness check did.
        "credibility_score": float(credibility) if credibility is not None else 0.5,
        "scheduled_at": now.isoformat(),
    }
|
||
|
||
|
||
async def acquire_lock(rds: aioredis.Redis, name: str, ttl: int = 60) -> bool:
    """Try to acquire the distributed lock *name* for *ttl* seconds.

    Uses ``SET <lock_key(name)> 1 NX EX ttl``. Returns True only if this call
    created the key. redis-py returns None (not False) when NX blocks the
    write, so the reply is coerced to a real bool.

    NOTE(review): the lock value is the constant "1" and release_lock()
    deletes the key unconditionally. A holder whose work outlives *ttl* can
    therefore delete a lock another process has since acquired.
    """
    acquired = await rds.set(lock_key(name), "1", nx=True, ex=ttl)
    return bool(acquired)
|
||
|
||
|
||
async def release_lock(rds: aioredis.Redis, name: str) -> None:
    """Drop the distributed lock *name* by deleting its Redis key.

    The key is removed regardless of which process currently holds it.
    """
    key = lock_key(name)
    await rds.delete(key)
|
||
|
||
|
||
async def _incr_window_counter(rds: aioredis.Redis, key: str, ttl: int = 120) -> int:
    """Atomically INCR *key* and (re)arm its TTL; return the new count.

    INCR and EXPIRE are sent in one MULTI/EXEC transaction, so the counter
    always carries a TTL. With separate commands, a crash or connection
    error between them would leave a per-minute counter key in Redis forever.
    """
    async with rds.pipeline(transaction=True) as pipe:
        pipe.incr(key)
        pipe.expire(key, ttl)
        count, _ = await pipe.execute()
    return int(count)


async def check_rate_limit(
    rds: aioredis.Redis,
    source_type: str,
    now: datetime,
    max_per_minute: Optional[int] = None,
) -> bool:
    """Check whether the source type is within its rate limit window.

    Counters are kept per calendar minute of *now* (``%Y%m%d%H%M``).

    Enforces two limits:
      1. Per-source-type limit: *max_per_minute* if given (0 blocks every
         request), otherwise DEFAULT_RATE_LIMITS, otherwise 30/min.
      2. Global Polygon limit across all Polygon-backed types
         (POLYGON_GLOBAL_RATE_LIMIT per minute combined).

    Every call consumes one unit. When the global limit rejects a request,
    the per-type unit is handed back, since no API call will be made.

    Returns True if the request is allowed, False if rate-limited.
    """
    if max_per_minute is not None:
        limit = max_per_minute
    else:
        limit = DEFAULT_RATE_LIMITS.get(source_type, 30)
    window = now.strftime("%Y%m%d%H%M")

    # Per-source-type check
    key = rate_limit_key(source_type, window)
    if await _incr_window_counter(rds, key) > limit:
        return False

    # Global Polygon check for source types that share the Polygon API key
    if source_type in POLYGON_SOURCE_TYPES:
        global_key = rate_limit_key("_polygon_global", window)
        if await _incr_window_counter(rds, global_key) > POLYGON_GLOBAL_RATE_LIMIT:
            # Roll back the per-type counter since we won't actually make the call
            await rds.decr(key)
            return False

    return True
|
||
|
||
|
||
async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return active company-bound sources joined with their active company.

    Each row has source_id, company_id, source_type, source_name, config,
    credibility_score, ticker and legal_name. Rows are ordered by source_type,
    then ticker. Sources of type 'macro_news' are excluded.
    """
    query = """SELECT s.id AS source_id,
              s.company_id,
              s.source_type,
              s.source_name,
              s.config,
              s.credibility_score,
              c.ticker,
              c.legal_name
       FROM sources s
       JOIN companies c ON s.company_id = c.id
       WHERE s.active = TRUE AND c.active = TRUE
         AND s.source_type != 'macro_news'
       ORDER BY s.source_type, c.ticker"""
    rows = await pool.fetch(query)
    return rows
|
||
|
||
|
||
async def fetch_macro_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return every active source whose source_type is 'macro_news'.

    Macro sources are not tied to a company (company_id may be NULL), so no
    join with companies is made. The scheduler handles them in a separate
    pass from company-specific sources. Rows are ordered by source_name.

    Requirements: 1.1
    """
    query = """SELECT s.id AS source_id,
              s.company_id,
              s.source_type,
              s.source_name,
              s.config,
              s.credibility_score
       FROM sources s
       WHERE s.active = TRUE AND s.source_type = 'macro_news'
       ORDER BY s.source_name"""
    rows = await pool.fetch(query)
    return rows
|
||
|
||
|
||
async def fetch_global_market_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return active 'market_api' sources that have no company (company_id IS NULL).

    These cover endpoints such as grouped daily that fetch data for all
    tickers in one API call. Rows are ordered by source_name.
    """
    query = """SELECT s.id AS source_id,
              s.company_id,
              s.source_type,
              s.source_name,
              s.config,
              s.credibility_score
       FROM sources s
       WHERE s.active = TRUE
         AND s.source_type = 'market_api'
         AND s.company_id IS NULL
       ORDER BY s.source_name"""
    rows = await pool.fetch(query)
    return rows
|
||
|
||
|
||
async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
    """Return the alias strings stored in company_aliases for *company_id*."""
    alias_rows = await pool.fetch(
        "SELECT alias FROM company_aliases WHERE company_id = $1",
        company_id,
    )
    return [alias_row["alias"] for alias_row in alias_rows]
|
||
|
||
|
||
async def fetch_last_run(
    pool: asyncpg.Pool, source_id: str
) -> Optional[asyncpg.Record]:
    """Return the latest ingestion_runs row (by started_at) for *source_id*.

    The row has status, started_at, completed_at, retry_count and
    next_retry_at; the result is None when the source has never run.
    """
    query = """SELECT status, started_at, completed_at, retry_count, next_retry_at
       FROM ingestion_runs
       WHERE source_id = $1
       ORDER BY started_at DESC
       LIMIT 1"""
    latest = await pool.fetchrow(query, source_id)
    return latest
|
||
|
||
|
||
async def _source_is_due(
    pool: asyncpg.Pool,
    source_id: Any,
    source_type: str,
    source_config: Optional[dict[str, Any]],
    now: datetime,
) -> bool:
    """Load the source's latest ingestion run and apply is_source_due() to it."""
    last_run = await fetch_last_run(pool, source_id)

    last_completed_at = None
    last_status = None
    retry_count = 0
    next_retry_at = None
    if last_run is not None:
        last_status = last_run["status"]
        # A run still in progress has no completed_at; fall back to its start.
        last_completed_at = last_run["completed_at"] or last_run["started_at"]
        retry_count = last_run["retry_count"] or 0
        next_retry_at = last_run["next_retry_at"]

    return is_source_due(
        source_type=source_type,
        source_config=source_config,
        last_completed_at=last_completed_at,
        last_status=last_status,
        retry_count=retry_count,
        next_retry_at=next_retry_at,
        now=now,
    )


async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """One scheduling pass: find due sources and enqueue ingestion jobs.

    Three groups of sources are processed in order. Each source is checked
    for being due, then against the rate limiter, before jobs go onto the
    ingestion queue:

    1. Company-specific sources (jobs carry the company's aliases).
    2. Macro news sources (no aliases).
    3. Global market sources (company_id IS NULL):
       - "grouped_daily" defaults to a 24 h cadence and is enqueued once
         with ticker "_MARKET".
       - "intraday_bars" is expanded into one job per active company.

    Returns the number of jobs enqueued.
    """
    now = datetime.now(tz=timezone.utc)
    ingestion_queue = queue_key(QUEUE_INGESTION)
    sources = await fetch_active_sources(pool)

    enqueued = 0
    skipped_rate_limit = 0
    skipped_not_due = 0

    # --- Company-specific sources ---
    for src in sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])

        if not await _source_is_due(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for %s, skipping %s/%s",
                source_type, src["ticker"], src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        # Fetch company aliases for downstream entity matching
        aliases = await fetch_aliases_for_company(pool, src["company_id"])
        job = build_job_payload(src, aliases, now)
        await rds.rpush(ingestion_queue, json.dumps(job))  # type: ignore[misc]
        enqueued += 1
        logger.debug(
            "Enqueued %s job for %s (%s)",
            source_type, src["ticker"], src["source_name"],
        )

    # --- Schedule macro news sources (Requirement 1.1) ---
    macro_sources = await fetch_macro_sources(pool)
    for src in macro_sources:
        source_type = src["source_type"]
        source_config = _ensure_dict(src["config"])

        if not await _source_is_due(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for macro_news, skipping %s",
                src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        job = build_job_payload(src, [], now)
        await rds.rpush(ingestion_queue, json.dumps(job))  # type: ignore[misc]
        enqueued += 1
        logger.debug("Enqueued macro_news job for %s", src["source_name"])

    # --- Schedule global market sources (grouped daily, etc.) ---
    global_market_sources = await fetch_global_market_sources(pool)
    for src in global_market_sources:
        source_type = src["source_type"]
        # Work on a private copy:
        # - A NULL or invalid config previously made `.get` raise
        #   AttributeError, aborting the whole cycle.
        # - _ensure_dict may return the row's own dict, and the cadence
        #   override below must not leak into anything else.
        source_config = dict(_ensure_dict(src["config"]) or {})

        # Use a longer cadence for grouped daily (once per day = 86400s)
        endpoint = source_config.get("endpoint", "")
        if endpoint == "grouped_daily":
            source_config.setdefault("polling_interval_seconds", 86400)

        if not await _source_is_due(pool, src["source_id"], source_type, source_config, now):
            skipped_not_due += 1
            continue

        # NOTE(review): intraday_bars consumes one rate-limit unit here but
        # fans out into one job per active ticker below.
        if not await check_rate_limit(rds, source_type, now):
            skipped_rate_limit += 1
            continue

        job = build_job_payload(src, [], now)
        if endpoint == "intraday_bars":
            # Expand intraday source into per-ticker jobs for all active companies
            tickers = await pool.fetch(
                "SELECT ticker FROM companies WHERE active = TRUE"
            )
            payloads = [json.dumps({**job, "ticker": t_row["ticker"]}) for t_row in tickers]
            if payloads:
                # Single RPUSH keeps the same queue order in one round trip.
                await rds.rpush(ingestion_queue, *payloads)  # type: ignore[misc]
            enqueued += len(payloads)
            logger.info("Enqueued %d intraday bar jobs", len(tickers))
        else:
            # Global (non-company) jobs are tagged with the "_MARKET" ticker.
            job["ticker"] = "_MARKET"
            await rds.rpush(ingestion_queue, json.dumps(job))  # type: ignore[misc]
            enqueued += 1
            logger.info("Enqueued grouped daily market data job")

    logger.info(
        "Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
        enqueued, skipped_not_due, skipped_rate_limit,
        len(sources) + len(macro_sources) + len(global_market_sources),
    )
    return enqueued
|
||
|
||
|
||
async def enqueue_periodic_aggregation(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
|
||
"""Enqueue aggregation jobs for all active tickers.
|
||
|
||
Runs periodically during market hours to ensure trend data stays fresh
|
||
even when no new documents are being ingested. This gives the intraday
|
||
and 1d windows continuous updates based on existing signals and market
|
||
price changes.
|
||
"""
|
||
# Only run during US market hours (Mon-Fri, 6:30 AM - 1:30 PM PT / 13:30-20:30 UTC)
|
||
from datetime import datetime, timezone
|
||
now = datetime.now(timezone.utc)
|
||
weekday = now.weekday() # 0=Mon, 6=Sun
|
||
hour_utc = now.hour + now.minute / 60.0
|
||
|
||
if weekday >= 5: # Weekend
|
||
return 0
|
||
if hour_utc < 13.5 or hour_utc > 20.5: # Outside market hours (with 30min buffer)
|
||
return 0
|
||
|
||
# Fetch all active tickers
|
||
rows = await pool.fetch(
|
||
"SELECT ticker FROM companies WHERE active = TRUE ORDER BY ticker"
|
||
)
|
||
if not rows:
|
||
return 0
|
||
|
||
agg_queue = queue_key(QUEUE_AGGREGATION)
|
||
count = 0
|
||
for row in rows:
|
||
await rds.rpush(agg_queue, json.dumps({"ticker": row["ticker"]}))
|
||
count += 1
|
||
|
||
logger.info("Periodic aggregation: enqueued %d tickers for re-aggregation", count)
|
||
return count
|
||
|
||
|
||
async def main() -> None:
    """Run the scheduler loop until the process is stopped.

    Every SCHEDULER_TICK seconds:
    - If the pipeline toggle (PIPELINE_ENABLED_KEY) is "0", do nothing.
    - Otherwise, when this process wins the "scheduler_cycle" lock, run one
      schedule_cycle(), then advance the periodic maintenance counters and
      run each task whose interval has elapsed.

    Each step is isolated: a failing schedule_cycle() no longer prevents
    recovery, cleanup or aggregation from running, and one failing
    maintenance task does not skip the others. Failures are logged.

    The Postgres pool and Redis client are closed on exit.
    """
    config = load_config()
    setup_logging("scheduler", level=config.log_level, json_output=config.json_logs)

    pool = await get_pg_pool(config)
    rds = get_redis(config)

    try:
        logger.info("Scheduler started (tick=%ds)", SCHEDULER_TICK)
        pipeline_key = PIPELINE_ENABLED_KEY

        # If PIPELINE_DEFAULT_OFF is set, initialize the toggle to OFF on first boot
        # (only if the key doesn't already exist — preserves manual overrides)
        if os.getenv("PIPELINE_DEFAULT_OFF", "").lower() in ("1", "true", "yes"):
            created = await rds.set(pipeline_key, "0", nx=True)
            if created:
                logger.info("Pipeline toggle initialized to OFF (PIPELINE_DEFAULT_OFF=true)")

        # Maintenance tasks, counted in lock-holding ticks:
        # (interval in ticks, label for logs, zero-arg coroutine factory).
        periodic_tasks = [
            (20, "stale document recovery", lambda: recover_stale_documents(pool, rds)),  # ~5 min
            (40, "extraction retry", lambda: retry_failed_extractions(pool, rds)),  # ~10 min
            (CLEANUP_CYCLE_INTERVAL, "table cleanup", lambda: cleanup_all_tables(pool)),  # ~25 min
            # ~15 min; the task itself returns early outside market hours.
            (AGGREGATION_CYCLE_INTERVAL, "periodic aggregation",
             lambda: enqueue_periodic_aggregation(pool, rds)),
        ]
        counters = [0] * len(periodic_tasks)

        while True:
            try:
                # Check pipeline toggle — skip cycle if disabled. Accept bytes
                # too, in case the client does not decode responses.
                flag = await rds.get(pipeline_key)
                if flag in ("0", b"0"):
                    await asyncio.sleep(SCHEDULER_TICK)
                    continue

                if await acquire_lock(rds, "scheduler_cycle", ttl=30):
                    try:
                        try:
                            await schedule_cycle(pool, rds)
                        except Exception:
                            logger.exception("Scheduler cycle error")

                        for idx, (interval, label, run_task) in enumerate(periodic_tasks):
                            counters[idx] += 1
                            if counters[idx] < interval:
                                continue
                            counters[idx] = 0
                            try:
                                await run_task()
                            except Exception:
                                logger.exception("Periodic task failed: %s", label)
                    finally:
                        await release_lock(rds, "scheduler_cycle")
            except Exception:
                logger.exception("Scheduler cycle error")
            await asyncio.sleep(SCHEDULER_TICK)
    finally:
        await pool.close()
        await rds.close()
|
||
|
||
|
||
# How long a document can sit in "parsed" before we consider it orphaned.
# Must be longer than the expected queue drain time to avoid re-enqueuing
# docs that are already queued but not yet processed.
STALE_PARSED_THRESHOLD_MINUTES: int = 240

# How long after an extraction failure before we retry
EXTRACTION_FAILED_RETRY_MINUTES: int = 60

# Key prefix for per-document "already enqueued" markers. Each marker is a
# plain string key "<prefix>:<document_id>" created with SET NX and a TTL
# (not a Redis SET), which prevents duplicate enqueuing by recovery sweeps.
_ENQUEUED_SET = f"{QUEUE_PREFIX}:enqueued"

# How long an enqueued marker lives before it can be re-enqueued (seconds).
# Derived from STALE_PARSED_THRESHOLD_MINUTES so the two cannot drift apart
# (currently 14400 s = 4 hours).
_ENQUEUED_TTL: int = STALE_PARSED_THRESHOLD_MINUTES * 60
|
||
|
||
|
||
async def _enqueue_if_new(
    rds: aioredis.Redis,
    queue: str,
    document_id: str,
    ticker: str,
) -> bool:
    """Push ``{"document_id", "ticker"}`` onto *queue* unless recently enqueued.

    Deduplication uses one marker key per document,
    ``f"{_ENQUEUED_SET}:{document_id}"``, written with SET NX and a TTL of
    ``_ENQUEUED_TTL`` seconds. While the marker exists, further calls for the
    same document are no-ops.

    Returns True if the job was pushed, False if it was skipped as a duplicate.
    """
    claimed = await rds.set(
        f"{_ENQUEUED_SET}:{document_id}", "1", nx=True, ex=_ENQUEUED_TTL
    )
    if claimed:
        payload = json.dumps({
            "document_id": document_id,
            "ticker": ticker,
        })
        await rds.rpush(queue, payload)
        return True
    return False
|
||
|
||
|
||
async def recover_stale_documents(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Re-enqueue documents stuck in 'parsed' status for extraction.

    Documents can be orphaned when Redis loses queue entries (pod restart,
    OOM, etc.). This sweep looks at up to 100 candidate rows, oldest first.
    A candidate is a 'parsed' document whose updated_at is older than
    STALE_PARSED_THRESHOLD_MINUTES and which is not the source of a
    global_events row.

    - 'macro_event' documents go to the macro classification queue; all
      others go to the extraction queue.
    - _enqueue_if_new() dedups per document.
    - Documents that were actually enqueued get updated_at touched, so they
      are not picked up again until the threshold passes once more.

    Returns the number of documents re-enqueued.
    """
    stale_rows = await pool.fetch(
        """SELECT d.id, d.document_type, dcm.ticker
           FROM documents d
           LEFT JOIN document_company_mentions dcm ON d.id = dcm.document_id
           WHERE d.status = 'parsed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
             AND NOT EXISTS (
                 SELECT 1 FROM global_events ge WHERE ge.source_document_id = d.id
             )
           ORDER BY d.created_at ASC
           LIMIT 100""",
        STALE_PARSED_THRESHOLD_MINUTES,
    )
    if not stale_rows:
        return 0

    requeued_ids = []
    for record in stale_rows:
        queue_name = (
            QUEUE_MACRO_CLASSIFICATION
            if record["document_type"] == "macro_event"
            else QUEUE_EXTRACTION
        )
        pushed = await _enqueue_if_new(
            rds, queue_key(queue_name), str(record["id"]), record["ticker"] or ""
        )
        if pushed:
            requeued_ids.append(record["id"])

    if requeued_ids:
        await pool.execute(
            "UPDATE documents SET updated_at = NOW() WHERE id = ANY($1::uuid[])",
            requeued_ids,
        )

    logger.info("Recovered %d stale parsed documents for extraction", len(requeued_ids))
    return len(requeued_ids)
|
||
|
||
|
||
async def retry_failed_extractions(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Re-enqueue documents stuck in 'extraction_failed' for another attempt.

    Candidates are up to 100 rows whose updated_at is at least
    EXTRACTION_FAILED_RETRY_MINUTES old, oldest first. For each distinct
    document:

    1. Claim its dedup marker (``f"{_ENQUEUED_SET}:{document_id}"``, SET NX
       with TTL ``_ENQUEUED_TTL``). Documents whose marker already exists
       are skipped, just as _enqueue_if_new() would skip them.
    2. In one transaction, delete its failed document_intelligence rows and
       reset its status to 'parsed', touching updated_at.
    3. Only then push ``{"document_id", "ticker"}`` onto the macro
       classification queue ('macro_event') or the extraction queue.

    Resetting the database before publishing closes a race. Previously, a
    fast extractor could pick up the job while the document was still
    'extraction_failed', or finish it only for the reset to overwrite the
    new status.

    If the database step fails, the claimed markers are released so a later
    sweep can retry, and the error propagates.

    NOTE(review): a marker left by an earlier enqueue blocks retries until it
    expires, so the effective retry delay can exceed
    EXTRACTION_FAILED_RETRY_MINUTES.

    Returns the number of documents re-enqueued.
    """
    rows = await pool.fetch(
        """SELECT d.id, d.document_type, dcm.ticker
           FROM documents d
           LEFT JOIN document_company_mentions dcm ON d.id = dcm.document_id
           WHERE d.status = 'extraction_failed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
           ORDER BY d.updated_at ASC
           LIMIT 100""",
        EXTRACTION_FAILED_RETRY_MINUTES,
    )
    if not rows:
        return 0

    # document id -> (target queue key, ticker). The LEFT JOIN yields one row
    # per company mention, so only the first row for each document counts.
    claimed: dict[Any, tuple[str, str]] = {}
    seen: set[Any] = set()
    for row in rows:
        doc_id = row["id"]
        if doc_id in seen:
            continue
        seen.add(doc_id)
        marker_key = f"{_ENQUEUED_SET}:{doc_id}"
        if not await rds.set(marker_key, "1", nx=True, ex=_ENQUEUED_TTL):
            continue
        if row["document_type"] == "macro_event":
            target = queue_key(QUEUE_MACRO_CLASSIFICATION)
        else:
            target = queue_key(QUEUE_EXTRACTION)
        claimed[doc_id] = (target, row["ticker"] or "")

    if claimed:
        doc_ids = list(claimed)
        try:
            async with pool.acquire() as conn:
                async with conn.transaction():
                    # Delete failed intelligence rows so extractor starts fresh
                    await conn.execute(
                        """DELETE FROM document_intelligence
                           WHERE document_id = ANY($1::uuid[])
                             AND validation_status = 'failed'""",
                        doc_ids,
                    )
                    # Reset status to 'parsed' and touch updated_at
                    await conn.execute(
                        """UPDATE documents
                           SET status = 'parsed', updated_at = NOW()
                           WHERE id = ANY($1::uuid[])""",
                        doc_ids,
                    )
        except Exception:
            # Nothing was published; free the markers so the next sweep retries.
            await rds.delete(*(f"{_ENQUEUED_SET}:{doc_id}" for doc_id in doc_ids))
            raise

        for doc_id, (target, ticker) in claimed.items():
            # Same payload shape as _enqueue_if_new().
            await rds.rpush(target, json.dumps({
                "document_id": str(doc_id),
                "ticker": ticker,
            }))

    logger.info("Retried %d extraction-failed documents", len(claimed))
    return len(claimed)
|
||
|
||
|
||
# How often main() runs cleanup_all_tables(), in scheduler ticks
# (100 ticks x 15 s SCHEDULER_TICK = ~25 minutes).
CLEANUP_CYCLE_INTERVAL: int = 100
# Keep competitive signals for this many days (retention window used by
# cleanup_old_signals()).
COMPETITIVE_SIGNAL_RETENTION_DAYS: int = 30
|
||
|
||
|
||
async def cleanup_old_signals(pool: asyncpg.Pool) -> int:
    """Delete competitive_signal_records older than the retention window.

    The window is COMPETITIVE_SIGNAL_RETENTION_DAYS, measured on computed_at.
    This keeps the table from growing without bound.

    Returns the number of rows deleted, parsed from the command status.
    """
    status = await pool.execute(
        "DELETE FROM competitive_signal_records WHERE computed_at < NOW() - INTERVAL '1 day' * $1",
        COMPETITIVE_SIGNAL_RETENTION_DAYS,
    )
    # The command status looks like "DELETE 1234"; its last token is the row count.
    deleted = int(status.split()[-1]) if status else 0
    if deleted > 0:
        logger.info("Cleaned up %d old competitive signal records", deleted)
    return deleted
|
||
|
||
|
||
async def cleanup_all_tables(pool: asyncpg.Pool) -> None:
    """Run retention cleanup across all tables that grow unbounded.

    Called periodically by the scheduler (~every 25 minutes). Each table has
    its own retention window in days, applied to the listed timestamp column.

    Cleanup is best-effort: a failure on one table (for example a missing
    table, a renamed column or a foreign-key violation) is logged as a
    warning and the remaining tables are still processed.
    """
    cleanups = [
        # (table, time_column, retention_days, description)
        ("competitive_signal_records", "computed_at", COMPETITIVE_SIGNAL_RETENTION_DAYS, "competitive signals"),
        ("ingestion_runs", "started_at", 14, "ingestion runs"),
        ("trading_decisions", "created_at", 90, "trading decisions"),
        ("risk_evaluations", "evaluated_at", 30, "risk evaluations"),
        ("audit_events", "created_at", 90, "audit events"),
        ("macro_impact_records", "computed_at", 60, "macro impact records"),
        ("recommendation_evidence", "created_at", 60, "recommendation evidence"),
        ("recommendations", "generated_at", 30, "old recommendations"),
        ("order_events", "created_at", 90, "order events"),
        ("model_performance_metrics", "recorded_at", 30, "model metrics"),
    ]

    total = 0
    for table, col, days, desc in cleanups:
        try:
            # Table/column names come from the constant list above, never
            # from input; the retention value is a bound parameter.
            result = await pool.execute(
                f"DELETE FROM {table} WHERE {col} < NOW() - INTERVAL '1 day' * $1",  # noqa: S608
                days,
            )
            count = int(result.split()[-1]) if result else 0
        except Exception:
            # Previously swallowed silently, which hid schema drift and FK
            # violations that stop a table from ever being pruned.
            logger.warning(
                "Cleanup of %s.%s failed; skipping", table, col, exc_info=True,
            )
            continue
        if count > 0:
            total += count
            logger.info("Cleaned %d %s (>%dd old)", count, desc, days)

    if total > 0:
        logger.info("Total cleanup: %d rows deleted across all tables", total)
|
||
|
||
|
||
# Script entry point: run the async scheduler loop when the module is
# executed directly (e.g. `python -m services.scheduler.app`).
if __name__ == "__main__":
    asyncio.run(main())
|