bc077bfcc8
ci/woodpecker/push/test Pipeline was successful
ci/woodpecker/push/build-2 Pipeline was successful
ci/woodpecker/push/build-3 Pipeline was successful
ci/woodpecker/push/build-1 Pipeline was successful
ci/woodpecker/push/finalize Pipeline was successful
Build and Push / lint-and-test (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.adapters.broker_adapter name:broker-adapter]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.aggregation.worker name:aggregation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.extractor.worker name:extractor]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.ingestion.worker name:ingestion]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.lake_publisher.worker name:lake-publisher]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.parser.worker name:parser]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.recommendation.worker name:recommendation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.scheduler.app name:scheduler]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.api.app:app --host 0.0.0.0 --port 8000 name:query-api]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.risk.app:app --host 0.0.0.0 --port 8000 name:risk]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.symbol_registry.app:app --host 0.0.0.0 --port 8000 name:symbol-registry]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.trading.app:app --host 0.0.0.0 --port 8000 name:trading-engine]) (push) Has been cancelled
Build and Push / build-dashboard (push) Has been cancelled
Build and Push / build-superset (push) Has been cancelled
Build and Push / integration-test (push) Has been cancelled
Build and Push / beta-gate (push) Has been cancelled
- Migration 038: trading_reports table + report-summarizer agent seed
- 6 reporting modules: models, collector, sections, validator, summarizer, generator
- API endpoints: GET /api/reports (paginated, filterable), GET /api/reports/{id}
- Frontend hooks: useReports, useReport with TanStack Query
- Scheduler: daily (after 16:30 ET) and weekly (Saturday) report triggers
- Redis queue consumer for async report generation with retry/dedup
- 5 property-based tests (chunking, serialization, validation, accuracy, deltas)
- 109 unit/integration tests across all modules
- 6 frontend hook tests with MSW mocks
973 lines
33 KiB
Python
973 lines
33 KiB
Python
"""Scheduler - triggers ingestion cycles for tracked symbols and sources.
|
||
|
||
Polls the symbol registry for active companies and their configured sources,
|
||
respects per-source polling cadences and backoff windows, coordinates rate
|
||
limits across source types, and enqueues ingestion jobs for downstream workers.
|
||
|
||
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
|
||
"""
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
from datetime import datetime, timedelta, timezone
|
||
from typing import Any, Optional
|
||
from zoneinfo import ZoneInfo
|
||
|
||
import asyncpg
|
||
import redis.asyncio as aioredis
|
||
|
||
from services.shared.config import load_config
|
||
from services.shared.db import get_pg_pool, get_redis
|
||
from services.shared.logging import setup_logging
|
||
from services.shared.redis_keys import (
|
||
PIPELINE_ENABLED_KEY,
|
||
QUEUE_AGGREGATION,
|
||
QUEUE_EXTRACTION,
|
||
QUEUE_INGESTION,
|
||
QUEUE_MACRO_CLASSIFICATION,
|
||
QUEUE_PREFIX,
|
||
QUEUE_REPORT_GENERATION,
|
||
lock_key,
|
||
queue_key,
|
||
rate_limit_key,
|
||
)
|
||
|
||
# Named logger shared by every function in this module.
logger = logging.getLogger("scheduler")
|
||
|
||
|
||
def _ensure_dict(val: Any) -> Optional[dict]:
    """Normalize a JSONB column value into a Python dict.

    Dicts are returned unchanged. Strings are parsed as JSON and returned only
    when they decode to a JSON object. Anything else, including None,
    unparseable text and non-object JSON, yields None.
    """
    if isinstance(val, dict):
        return val
    if not isinstance(val, str):
        # None and every non-string, non-dict value end up here.
        return None
    try:
        decoded = json.loads(val)
    except (json.JSONDecodeError, TypeError):
        return None
    return decoded if isinstance(decoded, dict) else None
|
||
|
||
# Default polling cadences by source class (seconds).
# Individual sources can override via config.polling_interval_seconds;
# source types not listed here fall back to 600 seconds in
# get_cadence_for_source().
DEFAULT_CADENCES: dict[str, int] = {
    "market_api": 300,
    "news_api": 300,
    "filings_api": 3600,
    "web_scrape": 1800,
    "broker": 30,
    "macro_news": 600,
}

# Default rate limits per source type (requests per minute).
# check_rate_limit() falls back to 30/min for source types not listed here.
DEFAULT_RATE_LIMITS: dict[str, int] = {
    "market_api": 20,
    "news_api": 20,
    "filings_api": 10,
    "web_scrape": 10,
    "broker": 60,
    "macro_news": 10,
}

# Global rate limit across all Polygon-backed source types (requests per minute).
# market_api + news_api share a single Polygon API key, so we cap the combined
# throughput to stay safely under the plan limit.
POLYGON_SOURCE_TYPES: set[str] = {"market_api", "news_api"}
POLYGON_GLOBAL_RATE_LIMIT: int = 45

# Exponential backoff for failed sources (see compute_backoff): the delay
# starts at DEFAULT_BACKOFF_BASE seconds and is capped at MAX_BACKOFF seconds.
DEFAULT_BACKOFF_BASE: int = 60
MAX_BACKOFF: int = 3600
# Once the latest failed run reaches this retry count, is_source_due() stops
# scheduling the source (it needs a manual reset).
MAX_RETRY_COUNT: int = 10

# Main loop interval (seconds)
SCHEDULER_TICK: int = 15

# Periodic aggregation: re-aggregate all active tickers every N scheduler cycles
# (the cadence is the same at all times; nothing here checks market hours).
# 15s tick × 60 cycles = 15 minutes
AGGREGATION_CYCLE_INTERVAL: int = 60
|
||
|
||
|
||
def get_cadence_for_source(source_type: str, config: Optional[dict[str, Any]]) -> int:
    """Return how many seconds should pass between polls of a source.

    A usable ``polling_interval_seconds`` entry in *config* wins, floored at
    10 seconds. Otherwise, when the entry is missing or not int-convertible,
    the per-type default from DEFAULT_CADENCES is used (600s for unknown types).
    """
    if not config or "polling_interval_seconds" not in config:
        return DEFAULT_CADENCES.get(source_type, 600)
    try:
        seconds = int(config["polling_interval_seconds"])
    except (ValueError, TypeError):
        return DEFAULT_CADENCES.get(source_type, 600)
    # Never poll more often than once every 10 seconds.
    return max(10, seconds)
|
||
|
||
|
||
def compute_backoff(retry_count: int) -> int:
    """Return the retry delay in seconds for the given retry count.

    The delay doubles per retry from DEFAULT_BACKOFF_BASE. The exponent stops
    growing after 8 retries, and the result never exceeds MAX_BACKOFF.
    """
    exponent = min(retry_count, 8)
    return min(DEFAULT_BACKOFF_BASE * 2 ** exponent, MAX_BACKOFF)
|
||
|
||
|
||
def is_source_due(
    source_type: str,
    source_config: Optional[dict[str, Any]],
    last_completed_at: Optional[datetime],
    last_status: Optional[str],
    retry_count: int,
    next_retry_at: Optional[datetime],
    now: datetime,
) -> bool:
    """Decide whether a source should be polled at *now*.

    Rules, applied in order:
      - No previous run at all (no status and no timestamp): due.
      - Last run failed: not due once retry_count reaches MAX_RETRY_COUNT
        (needs a manual reset); not due while next_retry_at is still in the
        future; otherwise due. No cadence check applies in this case.
      - Last run still running: not due (avoids double scheduling).
      - No completion timestamp: due.
      - Otherwise: due once the source's cadence has elapsed since
        last_completed_at.

    Naive timestamps are treated as UTC when *now* is timezone-aware. Aware
    timestamps have their tzinfo dropped when *now* is naive.
    """

    def _match_tz(moment: datetime) -> datetime:
        """Give *moment* the same naive/aware flavour as *now*."""
        if now.tzinfo is not None and moment.tzinfo is None:
            return moment.replace(tzinfo=timezone.utc)
        if now.tzinfo is None and moment.tzinfo is not None:
            return moment.replace(tzinfo=None)
        return moment

    if last_status is None and last_completed_at is None:
        return True

    if last_status == "failed":
        if retry_count >= MAX_RETRY_COUNT:
            return False
        # Blocked only while an explicit retry time lies in the future.
        return not (next_retry_at and now < _match_tz(next_retry_at))

    if last_status == "running":
        return False

    if last_completed_at is None:
        return True

    elapsed = (now - _match_tz(last_completed_at)).total_seconds()
    return elapsed >= get_cadence_for_source(source_type, source_config)
|
||
|
||
|
||
def build_job_payload(
    source: Any,
    aliases: list[str],
    now: datetime,
) -> dict[str, Any]:
    """Build the JSON-serializable ingestion job payload for a source row.

    Args:
        source: A source row (asyncpg Record or mapping). It must provide
            ``source_id``, ``source_type``, ``source_name``, ``config`` and
            ``credibility_score``. ``company_id``, ``ticker`` and
            ``legal_name`` are read with ``.get`` and may be absent or NULL.
        aliases: Company aliases passed through for downstream entity matching.
        now: Scheduling timestamp, serialized with ``isoformat()``.

    Returns:
        The payload dict. ``config`` is always a dict (``{}`` when missing or
        unparseable). ``credibility_score`` falls back to 0.5 only when the
        column is NULL.
    """
    company_id = source.get("company_id")
    credibility = source["credibility_score"]
    return {
        "source_id": str(source["source_id"]),
        "company_id": str(company_id) if company_id else None,
        "ticker": source.get("ticker") or "",
        "legal_name": source.get("legal_name") or "",
        "aliases": aliases,
        "source_type": source["source_type"],
        "source_name": source["source_name"],
        "config": _ensure_dict(source["config"]) or {},
        # Test against None, not truthiness: a stored score of 0 is a
        # legitimate "least credible" value and must not become the default.
        "credibility_score": float(credibility) if credibility is not None else 0.5,
        "scheduled_at": now.isoformat(),
    }
|
||
|
||
|
||
# Random token identifying this process as a lock owner. It is stored as the
# lock value so release_lock() only deletes a lock this process still holds.
_LOCK_OWNER_TOKEN: str = os.urandom(16).hex()

# Compare-and-delete, executed atomically inside Redis: delete KEYS[1] only if
# its value still equals ARGV[1] (our owner token).
_RELEASE_LOCK_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
    return redis.call("del", KEYS[1])
end
return 0
"""


async def acquire_lock(rds: aioredis.Redis, name: str, ttl: int = 60) -> bool:
    """Try to acquire the distributed lock *name* for *ttl* seconds.

    Uses ``SET key token NX EX ttl``, so only one holder owns the lock at a
    time and the lock self-expires if the holder never releases it.

    Returns:
        True if this call acquired the lock, False if it is already held.
        redis-py returns None for a failed NX set; this is normalized to a
        real bool.
    """
    acquired = await rds.set(lock_key(name), _LOCK_OWNER_TOKEN, nx=True, ex=ttl)
    return bool(acquired)


async def release_lock(rds: aioredis.Redis, name: str) -> None:
    """Release the distributed lock *name* if this process still owns it.

    Suppose the lock already expired (e.g. the guarded work outlived the TTL)
    and another process then acquired it. The stored token then differs and
    the key is left alone. An unconditional DEL here would release the other
    holder's lock and allow overlapping cycles.
    """
    await rds.eval(_RELEASE_LOCK_SCRIPT, 1, lock_key(name), _LOCK_OWNER_TOKEN)
|
||
|
||
|
||
async def check_rate_limit(
    rds: aioredis.Redis,
    source_type: str,
    now: datetime,
    max_per_minute: Optional[int] = None,
) -> bool:
    """Take one slot from the per-minute rate-limit budget(s) for *source_type*.

    Two fixed-window counters, keyed by the UTC minute of *now*, are checked:
      1. The per-source-type counter (e.g. market_api: 20/min). Its limit is
         *max_per_minute* when given, else DEFAULT_RATE_LIMITS, else 30.
      2. For Polygon-backed types only, a shared counter capped at
         POLYGON_GLOBAL_RATE_LIMIT. When this cap is hit, the per-type
         increment is undone because no call will be made.

    Counter keys expire 120 seconds after their first increment.

    Returns True if the request may proceed, False if it is rate-limited.
    """
    per_type_cap = max_per_minute or DEFAULT_RATE_LIMITS.get(source_type, 30)
    minute_bucket = now.strftime("%Y%m%d%H%M")

    async def _hit(counter_key: str) -> int:
        """Increment a per-minute counter, arming its expiry on first use."""
        value = await rds.incr(counter_key)
        if value == 1:
            await rds.expire(counter_key, 120)
        return value

    type_key = rate_limit_key(source_type, minute_bucket)
    if await _hit(type_key) > per_type_cap:
        return False

    if source_type not in POLYGON_SOURCE_TYPES:
        return True

    polygon_key = rate_limit_key("_polygon_global", minute_bucket)
    if await _hit(polygon_key) <= POLYGON_GLOBAL_RATE_LIMIT:
        return True

    # Shared budget exhausted: give the per-type slot back.
    await rds.decr(type_key)
    return False
|
||
|
||
|
||
async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return every active, non-macro source attached to an active company.

    Each row has the source columns (id aliased as ``source_id``) plus the
    owning company's ``ticker`` and ``legal_name``. Rows are ordered by
    source type, then ticker.
    """
    query = """SELECT s.id AS source_id,
                  s.company_id,
                  s.source_type,
                  s.source_name,
                  s.config,
                  s.credibility_score,
                  c.ticker,
                  c.legal_name
           FROM sources s
           JOIN companies c ON s.company_id = c.id
           WHERE s.active = TRUE AND c.active = TRUE
             AND s.source_type != 'macro_news'
           ORDER BY s.source_type, c.ticker"""
    rows = await pool.fetch(query)
    return rows
|
||
|
||
|
||
async def fetch_macro_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return all active ``macro_news`` sources, ordered by source name.

    Macro sources are not tied to a company: the query does not join
    ``companies``, so ``company_id`` may be NULL. They are scheduled
    separately from company-specific sources.

    Requirements: 1.1
    """
    query = """SELECT s.id AS source_id,
                  s.company_id,
                  s.source_type,
                  s.source_name,
                  s.config,
                  s.credibility_score
           FROM sources s
           WHERE s.active = TRUE AND s.source_type = 'macro_news'
           ORDER BY s.source_name"""
    rows = await pool.fetch(query)
    return rows
|
||
|
||
|
||
async def fetch_global_market_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
    """Return active ``market_api`` sources that have no owning company.

    These cover the whole market at once (e.g. the grouped daily endpoint
    that returns every ticker in a single API call). Rows are ordered by
    source name.
    """
    query = """SELECT s.id AS source_id,
                  s.company_id,
                  s.source_type,
                  s.source_name,
                  s.config,
                  s.credibility_score
           FROM sources s
           WHERE s.active = TRUE
             AND s.source_type = 'market_api'
             AND s.company_id IS NULL
           ORDER BY s.source_name"""
    rows = await pool.fetch(query)
    return rows
|
||
|
||
|
||
async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
    """Return every alias string recorded for *company_id* (unordered)."""
    aliases: list[str] = []
    for record in await pool.fetch(
        "SELECT alias FROM company_aliases WHERE company_id = $1",
        company_id,
    ):
        aliases.append(record["alias"])
    return aliases
|
||
|
||
|
||
async def fetch_last_run(
    pool: asyncpg.Pool, source_id: str
) -> Optional[asyncpg.Record]:
    """Return the latest ingestion run (by ``started_at``) for a source, or None.

    Only the fields the scheduler needs are selected: status, started_at,
    completed_at, retry_count and next_retry_at.
    """
    sql = """SELECT status, started_at, completed_at, retry_count, next_retry_at
           FROM ingestion_runs
           WHERE source_id = $1
           ORDER BY started_at DESC
           LIMIT 1"""
    latest = await pool.fetchrow(sql, source_id)
    return latest
|
||
|
||
|
||
async def _load_run_state(pool: asyncpg.Pool, source_id: Any) -> dict[str, Any]:
    """Return a source's latest-run fields as ``is_source_due`` keyword arguments."""
    last_run = await fetch_last_run(pool, source_id)
    if last_run is None:
        return {
            "last_completed_at": None,
            "last_status": None,
            "retry_count": 0,
            "next_retry_at": None,
        }
    return {
        # A run without a completion time is timed from when it started.
        "last_completed_at": last_run["completed_at"] or last_run["started_at"],
        "last_status": last_run["status"],
        "retry_count": last_run["retry_count"] or 0,
        "next_retry_at": last_run["next_retry_at"],
    }


async def _source_is_due(
    pool: asyncpg.Pool,
    src: Any,
    source_config: Optional[dict[str, Any]],
    now: datetime,
) -> bool:
    """Look up *src*'s latest ingestion run and apply ``is_source_due`` to it."""
    state = await _load_run_state(pool, src["source_id"])
    return is_source_due(
        source_type=src["source_type"],
        source_config=source_config,
        now=now,
        **state,
    )


async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """One scheduling pass: find due sources and enqueue ingestion jobs.

    Three source groups are scheduled in order, all against the same ``now``:

    1. Company-specific sources (active, non-macro, with an active company).
       Their jobs carry the company's aliases.
    2. Macro news sources (not company-specific; no aliases).
    3. Global market sources (``market_api`` with no company).
       ``grouped_daily`` defaults to a 24h cadence. ``intraday_bars`` is
       fanned out into one job per active company ticker. Any other endpoint
       is enqueued once with ticker ``"_MARKET"``.

    A source is pushed onto the ingestion queue only if ``is_source_due``
    accepts it and then ``check_rate_limit`` grants a slot.

    Args:
        pool: Postgres pool used for source, run and alias lookups.
        rds: Redis client for rate limiting and the ingestion queue.

    Returns:
        The number of jobs enqueued (an intraday fan-out counts one per ticker).
    """
    now = datetime.now(tz=timezone.utc)
    ingestion_queue = queue_key(QUEUE_INGESTION)

    enqueued = 0
    skipped_rate_limit = 0
    skipped_not_due = 0

    # --- Company-specific sources ---
    sources = await fetch_active_sources(pool)
    for src in sources:
        source_type = src["source_type"]
        if not await _source_is_due(pool, src, _ensure_dict(src["config"]), now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for %s, skipping %s/%s",
                source_type, src["ticker"], src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        # Fetch company aliases for downstream entity matching
        aliases = await fetch_aliases_for_company(pool, src["company_id"])
        job = build_job_payload(src, aliases, now)
        await rds.rpush(ingestion_queue, json.dumps(job))  # type: ignore[misc]
        enqueued += 1
        logger.debug(
            "Enqueued %s job for %s (%s)",
            source_type, src["ticker"], src["source_name"],
        )

    # --- Macro news sources (Requirement 1.1) ---
    macro_sources = await fetch_macro_sources(pool)
    for src in macro_sources:
        if not await _source_is_due(pool, src, _ensure_dict(src["config"]), now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, src["source_type"], now):
            logger.warning(
                "Rate limit hit for macro_news, skipping %s",
                src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        job = build_job_payload(src, [], now)
        await rds.rpush(ingestion_queue, json.dumps(job))  # type: ignore[misc]
        enqueued += 1
        logger.debug("Enqueued macro_news job for %s", src["source_name"])

    # --- Global market sources (grouped daily, intraday bars, ...) ---
    global_market_sources = await fetch_global_market_sources(pool)
    active_tickers: Optional[list[str]] = None  # loaded lazily, once per cycle
    for src in global_market_sources:
        source_type = src["source_type"]
        # `or {}`: a NULL/unparseable config must not raise AttributeError on
        # .get() below and abort the rest of the cycle. dict(...) copies, so
        # the cadence default set below never leaks into the row's own config.
        source_config = dict(_ensure_dict(src["config"]) or {})
        endpoint = source_config.get("endpoint", "")
        if endpoint == "grouped_daily":
            # One call covers every ticker: poll once per day unless the
            # source config sets its own interval.
            source_config.setdefault("polling_interval_seconds", 86400)

        if not await _source_is_due(pool, src, source_config, now):
            skipped_not_due += 1
            continue

        if not await check_rate_limit(rds, source_type, now):
            logger.warning(
                "Rate limit hit for %s, skipping global source %s",
                source_type, src["source_name"],
            )
            skipped_rate_limit += 1
            continue

        job = build_job_payload(src, [], now)
        if endpoint == "intraday_bars":
            # Expand the intraday source into per-ticker jobs for all active
            # companies. NOTE(review): only one rate-limit slot is taken above
            # for the whole fan-out; confirm the ingestion worker throttles
            # the per-ticker API calls itself.
            if active_tickers is None:
                ticker_rows = await pool.fetch(
                    "SELECT ticker FROM companies WHERE active = TRUE"
                )
                active_tickers = [row["ticker"] for row in ticker_rows]
            for ticker in active_tickers:
                ticker_job = {**job, "ticker": ticker}
                await rds.rpush(ingestion_queue, json.dumps(ticker_job))  # type: ignore[misc]
                enqueued += 1
            logger.info("Enqueued %d intraday bar jobs", len(active_tickers))
        else:
            # Every other global endpoint runs once for the whole market.
            job["ticker"] = "_MARKET"
            await rds.rpush(ingestion_queue, json.dumps(job))  # type: ignore[misc]
            enqueued += 1
            logger.info(
                "Enqueued global market data job for %s (endpoint=%s)",
                src["source_name"], endpoint or "<none>",
            )

    logger.info(
        "Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
        enqueued, skipped_not_due, skipped_rate_limit,
        len(sources) + len(macro_sources) + len(global_market_sources),
    )
    return enqueued
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Report generation: queue consumer + scheduled triggers
# Requirements: 6.1, 6.2, 6.3, 6.4, 6.5
# ---------------------------------------------------------------------------

# Eastern Time zone for market-close checks. It is an IANA zone, so
# daylight-saving shifts (EST/EDT) are applied automatically.
_ET = ZoneInfo("America/New_York")

# How often to drain the report generation queue, in scheduler cycles.
# 15s tick × 4 cycles = ~1 minute
REPORT_CONSUMER_CYCLE_INTERVAL: int = 4

# How often to evaluate the daily/weekly report triggers, in scheduler cycles.
# 15s tick × 20 cycles = ~5 minutes
REPORT_SCHEDULE_CYCLE_INTERVAL: int = 20

# Redis key prefix for report schedule dedupe markers (one key per report
# type + period).
_REPORT_DEDUPE_PREFIX = f"{QUEUE_PREFIX}:report_dedupe"
# Marker lifetime in seconds; while it exists, the triggers will not enqueue
# the same report period again.
_REPORT_DEDUPE_TTL = 86400  # 24 hours — prevents re-enqueuing same report within a day
|
||
|
||
|
||
def _report_dedupe_key(report_type: str, period_start: str, period_end: str) -> str:
    """Return the Redis marker key that dedupes one report type/period.

    Layout: ``<_REPORT_DEDUPE_PREFIX>:<report_type>:<period_start>:<period_end>``.
    """
    return "{}:{}:{}:{}".format(
        _REPORT_DEDUPE_PREFIX, report_type, period_start, period_end
    )
|
||
|
||
|
||
async def consume_report_generation_jobs(
    pool: asyncpg.Pool,
    rds: aioredis.Redis,
    *,
    max_jobs: int = 5,
) -> int:
    """Pop and process jobs from the report generation queue.

    Pops up to *max_jobs* entries (default 5) per call, so a backlog cannot
    monopolize the scheduler loop. Each entry must decode to a JSON object;
    anything else is logged and dropped. Valid jobs are handed to
    ``process_report_job`` from ``services.reporting.generator``.

    A job whose processing raises is logged and dropped; this function does
    not re-enqueue it. NOTE(review): confirm whether process_report_job
    performs its own retries.

    Args:
        pool: Postgres pool passed through to ``process_report_job``.
        rds: Redis client used to pop jobs.
        max_jobs: Maximum number of queue entries to pop in this call.

    Returns:
        The number of jobs processed successfully.
    """
    # Imported lazily, presumably to keep the reporting stack out of the
    # scheduler's import path until it is needed.
    from services.reporting.generator import process_report_job

    report_queue = queue_key(QUEUE_REPORT_GENERATION)
    processed = 0

    for _ in range(max_jobs):
        raw = await rds.lpop(report_queue)
        if raw is None:
            break

        try:
            job = json.loads(raw)
        except (ValueError, TypeError):
            # ValueError covers JSONDecodeError and undecodable bytes payloads.
            job = None
        if not isinstance(job, dict):
            # Also rejects valid JSON that is not an object (list, string,
            # number), which would otherwise crash on job.get() below.
            logger.error("Invalid report generation job payload: %s", raw)
            continue

        logger.info(
            "Processing report generation job: type=%s period=%s to %s",
            job.get("report_type"),
            job.get("period_start"),
            job.get("period_end"),
        )

        try:
            await process_report_job(pool, job)
            processed += 1
        except Exception:
            logger.exception(
                "Failed to process report generation job: %s", job,
            )

    if processed > 0:
        logger.info("Processed %d report generation jobs", processed)
    return processed
|
||
|
||
|
||
async def maybe_enqueue_daily_report(
    rds: aioredis.Redis,
    now_et: datetime,
) -> bool:
    """Enqueue a daily report job if it's after 16:30 ET on a weekday.

    The report covers the single calendar day of *now_et*
    (``period_start == period_end``). A Redis dedupe marker (SET NX with
    _REPORT_DEDUPE_TTL) ensures each day is enqueued at most once. If the
    push itself fails, the marker is removed so a later check can retry.
    Only Saturday and Sunday are skipped; exchange holidays are not excluded.

    Args:
        rds: Redis client.
        now_et: Current time, expected to be in US/Eastern.

    Returns:
        True if a job was enqueued, False otherwise.

    Raises:
        Any Redis error from the push, re-raised after releasing the marker.
    """
    # Only on weekdays (Mon=0 .. Fri=4)
    if now_et.weekday() > 4:
        return False

    # Only after 16:30 ET
    if now_et.hour < 16 or (now_et.hour == 16 and now_et.minute < 30):
        return False

    today = now_et.date()
    period_start = today.isoformat()
    period_end = today.isoformat()

    dedupe = _report_dedupe_key("daily", period_start, period_end)
    created = await rds.set(dedupe, "1", nx=True, ex=_REPORT_DEDUPE_TTL)
    if not created:
        return False

    job = json.dumps({
        "report_type": "daily",
        "period_start": period_start,
        "period_end": period_end,
    })
    try:
        await rds.rpush(queue_key(QUEUE_REPORT_GENERATION), job)
    except Exception:
        # Without this, a transient push failure would leave the marker in
        # place and suppress today's report for the whole TTL.
        await rds.delete(dedupe)
        raise
    logger.info("Enqueued daily report for %s", period_start)
    return True
|
||
|
||
|
||
async def maybe_enqueue_weekly_report(
    rds: aioredis.Redis,
    now_et: datetime,
) -> bool:
    """Enqueue a weekly report job on Saturday.

    Covers the previous Monday through Friday (today - 5 days .. today - 1
    day). A Redis dedupe marker (SET NX with _REPORT_DEDUPE_TTL) ensures
    each week is enqueued at most once. If the push itself fails, the marker
    is removed so a later check can retry.

    Args:
        rds: Redis client.
        now_et: Current time, expected to be in US/Eastern.

    Returns:
        True if a job was enqueued, False otherwise.

    Raises:
        Any Redis error from the push, re-raised after releasing the marker.
    """
    # Only on Saturday (weekday() == 5)
    if now_et.weekday() != 5:
        return False

    today = now_et.date()
    # Previous Monday = today - 5 days, previous Friday = today - 1 day
    period_start = (today - timedelta(days=5)).isoformat()
    period_end = (today - timedelta(days=1)).isoformat()

    dedupe = _report_dedupe_key("weekly", period_start, period_end)
    created = await rds.set(dedupe, "1", nx=True, ex=_REPORT_DEDUPE_TTL)
    if not created:
        return False

    job = json.dumps({
        "report_type": "weekly",
        "period_start": period_start,
        "period_end": period_end,
    })
    try:
        await rds.rpush(queue_key(QUEUE_REPORT_GENERATION), job)
    except Exception:
        # Release the marker so the next trigger check can retry the push.
        await rds.delete(dedupe)
        raise
    logger.info(
        "Enqueued weekly report for %s to %s", period_start, period_end,
    )
    return True
|
||
|
||
|
||
async def check_report_schedule(rds: aioredis.Redis) -> None:
    """Evaluate the daily and weekly report triggers against the current ET time.

    Both triggers see the same Eastern Time timestamp. The daily trigger runs
    first, then the weekly one.
    """
    current_et = datetime.now(tz=_ET)
    for trigger in (maybe_enqueue_daily_report, maybe_enqueue_weekly_report):
        await trigger(rds, current_et)
|
||
|
||
|
||
async def enqueue_periodic_aggregation(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Enqueue one aggregation job per active ticker.

    Keeps trend data fresh even when no new documents are being ingested.
    This function does not check market hours; how often it runs is decided
    entirely by the caller.

    All payloads (``{"ticker": ...}``, in ticker order) are appended with a
    single RPUSH rather than one round-trip per ticker.

    Returns:
        The number of tickers enqueued (0 when there are no active companies).
    """
    rows = await pool.fetch(
        "SELECT ticker FROM companies WHERE active = TRUE ORDER BY ticker"
    )
    if not rows:
        return 0

    payloads = [json.dumps({"ticker": row["ticker"]}) for row in rows]
    await rds.rpush(queue_key(QUEUE_AGGREGATION), *payloads)  # type: ignore[misc]

    count = len(payloads)
    logger.info("Periodic aggregation: enqueued %d tickers for re-aggregation", count)
    return count
|
||
|
||
|
||
async def _run_isolated(description: str, coro: Any) -> None:
    """Await *coro*; log (and swallow) any Exception so sibling tasks still run."""
    try:
        await coro
    except Exception:
        logger.exception("Scheduler task failed: %s", description)


async def main() -> None:
    """Run the scheduler loop until cancelled.

    On startup, if PIPELINE_DEFAULT_OFF is truthy and the pipeline toggle key
    does not exist yet, the toggle is initialized to "0" (off).

    Every SCHEDULER_TICK seconds:
      - if the toggle reads "0", the tick is skipped;
      - otherwise, if the "scheduler_cycle" Redis lock is acquired (30s TTL),
        one scheduling pass runs, followed by each periodic task whose cycle
        counter has elapsed (stale-document recovery, extraction retry,
        retention cleanup, periodic aggregation, report consumption and
        report triggers), and the lock is released.

    Each task runs in isolation: a failure is logged and does not prevent
    later tasks in the same tick. Before, a failing scheduling pass starved
    every periodic task. NOTE(review): tasks such as report generation may
    outlast the 30s lock TTL; confirm that is acceptable for the deployment.

    The Postgres pool and Redis client are closed on exit.
    """
    config = load_config()
    setup_logging("scheduler", level=config.log_level, json_output=config.json_logs)

    pool = await get_pg_pool(config)
    rds = get_redis(config)

    logger.info("Scheduler started (tick=%ds)", SCHEDULER_TICK)
    pipeline_key = PIPELINE_ENABLED_KEY

    # If PIPELINE_DEFAULT_OFF is set, initialize the toggle to OFF on first boot
    # (only if the key doesn't already exist — preserves manual overrides)
    if os.getenv("PIPELINE_DEFAULT_OFF", "").lower() in ("1", "true", "yes"):
        created = await rds.set(pipeline_key, "0", nx=True)
        if created:
            logger.info("Pipeline toggle initialized to OFF (PIPELINE_DEFAULT_OFF=true)")

    # (interval in cycles, label, factory returning a fresh coroutine),
    # executed in this order after each scheduling pass. At the 15s tick:
    # recovery ~5 min, extraction retry ~10 min.
    periodic_tasks: list[tuple[int, str, Any]] = [
        (20, "stale document recovery", lambda: recover_stale_documents(pool, rds)),
        (40, "extraction retry", lambda: retry_failed_extractions(pool, rds)),
        (CLEANUP_CYCLE_INTERVAL, "retention cleanup", lambda: cleanup_all_tables(pool)),
        (AGGREGATION_CYCLE_INTERVAL, "periodic aggregation",
         lambda: enqueue_periodic_aggregation(pool, rds)),
        (REPORT_CONSUMER_CYCLE_INTERVAL, "report generation consumer",
         lambda: consume_report_generation_jobs(pool, rds)),
        (REPORT_SCHEDULE_CYCLE_INTERVAL, "report schedule check",
         lambda: check_report_schedule(rds)),
    ]
    counters = [0] * len(periodic_tasks)

    try:
        while True:
            try:
                # Check pipeline toggle — skip cycle if disabled. Accept bytes
                # too, in case the client is not configured to decode responses.
                flag = await rds.get(pipeline_key)
                if flag in ("0", b"0"):
                    await asyncio.sleep(SCHEDULER_TICK)
                    continue

                if await acquire_lock(rds, "scheduler_cycle", ttl=30):
                    try:
                        await _run_isolated("schedule cycle", schedule_cycle(pool, rds))
                        for idx, (interval, label, factory) in enumerate(periodic_tasks):
                            counters[idx] += 1
                            if counters[idx] >= interval:
                                counters[idx] = 0
                                await _run_isolated(label, factory())
                    finally:
                        await release_lock(rds, "scheduler_cycle")
            except Exception:
                logger.exception("Scheduler cycle error")

            await asyncio.sleep(SCHEDULER_TICK)
    finally:
        # Close Redis even if closing the pool fails.
        try:
            await pool.close()
        finally:
            await rds.close()
|
||
|
||
|
||
# How long (minutes) a document can sit in "parsed" before we consider it
# orphaned. Must be longer than the expected queue drain time to avoid
# re-enqueuing docs that are already queued but not yet processed.
STALE_PARSED_THRESHOLD_MINUTES: int = 240

# How long (minutes) after an extraction failure before we retry.
EXTRACTION_FAILED_RETRY_MINUTES: int = 60

# Key prefix for per-document "already enqueued" markers. Each document gets
# its own string key (`<prefix>:<document_id>`); it is not a Redis SET.
_ENQUEUED_SET = f"{QUEUE_PREFIX}:enqueued"

# How long an enqueued marker lives before the document can be re-enqueued
# (seconds). Derived from the stale threshold so the two cannot drift apart.
_ENQUEUED_TTL: int = STALE_PARSED_THRESHOLD_MINUTES * 60  # 14400s = 4 hours
|
||
|
||
|
||
async def _enqueue_if_new(
    rds: aioredis.Redis,
    queue: str,
    document_id: str,
    ticker: str,
) -> bool:
    """Push ``{"document_id", "ticker"}`` onto *queue* unless recently enqueued.

    Dedupe uses one marker key per document (``{_ENQUEUED_SET}:{document_id}``),
    created with SET NX and a TTL of _ENQUEUED_TTL seconds. While the marker
    exists, further calls for the same document are skipped. If the push
    fails, the marker is deleted before re-raising, so a later sweep can retry
    instead of waiting out the TTL.

    Returns:
        True if the job was pushed, False if skipped as a duplicate.
    """
    marker_key = f"{_ENQUEUED_SET}:{document_id}"
    # SET NX returns True only if the key was created (not already present)
    added = await rds.set(marker_key, "1", nx=True, ex=_ENQUEUED_TTL)
    if not added:
        return False

    payload = json.dumps({
        "document_id": document_id,
        "ticker": ticker,
    })
    try:
        await rds.rpush(queue, payload)
    except Exception:
        await rds.delete(marker_key)
        raise
    return True
|
||
|
||
|
||
async def recover_stale_documents(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Re-enqueue documents that have sat in 'parsed' past the stale threshold.

    Queue entries can be lost (e.g. pod restart, OOM), leaving documents
    orphaned. Up to 100 candidate rows are examined (oldest ``created_at``
    first). Candidates are documents in 'parsed' whose ``updated_at`` is older
    than STALE_PARSED_THRESHOLD_MINUTES and that no ``global_events`` row
    references as its source.

    Macro events go to the macro classification queue, everything else to
    the extraction queue, each through the per-document dedupe marker.
    Documents actually pushed get ``updated_at`` touched, so they are not
    picked again until the threshold passes once more.

    Returns:
        The number of documents re-enqueued.
    """
    stale_rows = await pool.fetch(
        """SELECT d.id, d.document_type, dcm.ticker
           FROM documents d
           LEFT JOIN document_company_mentions dcm ON d.id = dcm.document_id
           WHERE d.status = 'parsed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
             AND NOT EXISTS (
                 SELECT 1 FROM global_events ge WHERE ge.source_document_id = d.id
             )
           ORDER BY d.created_at ASC
           LIMIT 100""",
        STALE_PARSED_THRESHOLD_MINUTES,
    )
    if not stale_rows:
        return 0

    macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
    extraction_queue = queue_key(QUEUE_EXTRACTION)

    touched = []
    for row in stale_rows:
        destination = macro_queue if row["document_type"] == "macro_event" else extraction_queue
        if await _enqueue_if_new(rds, destination, str(row["id"]), row["ticker"] or ""):
            touched.append(row["id"])

    if touched:
        await pool.execute(
            "UPDATE documents SET updated_at = NOW() WHERE id = ANY($1::uuid[])",
            touched,
        )

    logger.info("Recovered %d stale parsed documents for extraction", len(touched))
    return len(touched)
|
||
|
||
|
||
async def retry_failed_extractions(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
    """Re-enqueue documents stuck in 'extraction_failed' for another attempt.

    Handles up to 100 candidate rows whose ``updated_at`` is at least
    EXTRACTION_FAILED_RETRY_MINUTES old (oldest first), in two steps:

    1. In one transaction, delete their failed ``document_intelligence`` rows
       (so the extractor treats them as fresh). In the same transaction,
       reset their status to 'parsed' with a fresh ``updated_at``.
    2. Only then push each onto the extraction queue (macro events onto the
       macro classification queue) through the per-document dedupe marker.

    The DB reset happens first because a worker may pop a job as soon as it
    is pushed. Enqueuing before the reset would let the worker see the
    document still 'extraction_failed' with its failed intelligence row.

    A document whose dedupe marker is still live is reset but not pushed
    again. NOTE(review): it is expected to be picked up later by the
    stale-parsed recovery sweep; confirm that is the desired path.

    Returns:
        The number of documents pushed onto a queue.
    """
    rows = await pool.fetch(
        """SELECT d.id, d.document_type, dcm.ticker
           FROM documents d
           LEFT JOIN document_company_mentions dcm ON d.id = dcm.document_id
           WHERE d.status = 'extraction_failed'
             AND d.updated_at < NOW() - INTERVAL '1 minute' * $1
           ORDER BY d.updated_at ASC
           LIMIT 100""",
        EXTRACTION_FAILED_RETRY_MINUTES,
    )
    if not rows:
        return 0

    # The LEFT JOIN yields one row per company mention; reset each doc once.
    doc_ids = list(dict.fromkeys(row["id"] for row in rows))

    async with pool.acquire() as conn:
        async with conn.transaction():
            # Delete failed intelligence rows so extractor starts fresh
            await conn.execute(
                """DELETE FROM document_intelligence
                   WHERE document_id = ANY($1::uuid[])
                     AND validation_status = 'failed'""",
                doc_ids,
            )
            # Reset status to 'parsed' and touch updated_at
            await conn.execute(
                """UPDATE documents
                   SET status = 'parsed', updated_at = NOW()
                   WHERE id = ANY($1::uuid[])""",
                doc_ids,
            )

    enqueued = 0
    for row in rows:
        if row["document_type"] == "macro_event":
            target = queue_key(QUEUE_MACRO_CLASSIFICATION)
        else:
            target = queue_key(QUEUE_EXTRACTION)
        # Duplicate rows for the same doc are skipped by the dedupe marker.
        if await _enqueue_if_new(rds, target, str(row["id"]), row["ticker"] or ""):
            enqueued += 1

    logger.info("Retried %d extraction-failed documents", enqueued)
    return enqueued
|
||
|
||
|
||
# Number of scheduler cycles between retention-cleanup runs
# (15s tick × 100 cycles = ~25 minutes).
CLEANUP_CYCLE_INTERVAL: int = 100
# Keep competitive signal records for this many days.
COMPETITIVE_SIGNAL_RETENTION_DAYS: int = 30
|
||
|
||
|
||
async def cleanup_old_signals(pool: asyncpg.Pool) -> int:
    """Purge competitive signal records older than the retention window.

    Deletes ``competitive_signal_records`` rows whose ``computed_at`` is more
    than COMPETITIVE_SIGNAL_RETENTION_DAYS days old, keeping the table bounded.

    Returns:
        The number of rows deleted, parsed from the command status tag
        (e.g. "DELETE 1234").
    """
    status = await pool.execute(
        "DELETE FROM competitive_signal_records WHERE computed_at < NOW() - INTERVAL '1 day' * $1",
        COMPETITIVE_SIGNAL_RETENTION_DAYS,
    )
    deleted = 0
    if status:
        # The trailing token of the status tag is the affected row count.
        deleted = int(status.split()[-1])
    if deleted > 0:
        logger.info("Cleaned up %d old competitive signal records", deleted)
    return deleted
|
||
|
||
|
||
async def cleanup_all_tables(pool: asyncpg.Pool) -> None:
    """Run retention cleanup across all tables that grow unbounded.

    For each (table, time column, retention days) entry, deletes rows whose
    time column is older than the retention window. Failures are isolated
    per table. A missing table or column is logged at DEBUG, since some
    deployments may not have every table. Any other error is logged at
    WARNING with its traceback. Either way, the remaining tables are still
    cleaned.
    """
    cleanups: list[tuple[str, str, int, str]] = [
        # (table, time_column, retention_days, description)
        ("competitive_signal_records", "computed_at", COMPETITIVE_SIGNAL_RETENTION_DAYS,
         "competitive signals"),
        ("ingestion_runs", "started_at", 14, "ingestion runs"),
        ("trading_decisions", "created_at", 90, "trading decisions"),
        ("risk_evaluations", "evaluated_at", 30, "risk evaluations"),
        ("audit_events", "created_at", 90, "audit events"),
        ("macro_impact_records", "computed_at", 60, "macro impact records"),
        ("recommendation_evidence", "created_at", 60, "recommendation evidence"),
        ("recommendations", "generated_at", 30, "old recommendations"),
        ("order_events", "created_at", 90, "order events"),
        ("model_performance_metrics", "recorded_at", 30, "model metrics"),
    ]

    total = 0
    for table, col, days, desc in cleanups:
        try:
            # Identifiers come only from the fixed list above (never external
            # input); the retention interval is a bound parameter.
            result = await pool.execute(
                f"DELETE FROM {table} WHERE {col} < NOW() - INTERVAL '1 day' * $1",  # noqa: S608
                days,
            )
        except (
            asyncpg.exceptions.UndefinedTableError,
            asyncpg.exceptions.UndefinedColumnError,
        ):
            logger.debug("Skipping cleanup of %s: table or column does not exist", table)
            continue
        except Exception:
            logger.warning(
                "Cleanup of %s failed; continuing with remaining tables", table,
                exc_info=True,
            )
            continue

        # Status tag looks like "DELETE 1234".
        count = int(result.split()[-1]) if result else 0
        if count > 0:
            total += count
            logger.info("Cleaned %d %s (>%dd old)", count, desc, days)

    if total > 0:
        logger.info("Total cleanup: %d rows deleted across all tables", total)
|
||
|
||
|
||
# Script entry point: run the scheduler loop under asyncio.
if __name__ == "__main__":
    asyncio.run(main())
|