"""Outcome Evaluator — matches predictions with realized market outcomes.

Runs periodically to evaluate prediction snapshots whose horizon has elapsed.
For each snapshot, fetches future prices at the horizon endpoint and computes
returns, excess returns, directional accuracy, and profitability across all
five evaluation horizons (1h, 6h, 1d, 7d, 30d).

Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 4.10
"""

from __future__ import annotations

import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta

import asyncpg

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Maps each evaluation-horizon label to the time that must elapse after a
# snapshot's generated_at before that horizon can be evaluated. The keys are
# also the values stored in the outcome's ``horizon`` field.
HORIZON_DURATIONS: dict[str, timedelta] = {
    "1h": timedelta(hours=1),
    "6h": timedelta(hours=6),
    "1d": timedelta(days=1),
    "7d": timedelta(days=7),
    "30d": timedelta(days=30),
}


# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------


@dataclass
class PredictionOutcome:
    """Realized outcome for a prediction at a specific horizon.

    One instance describes a single (prediction, horizon) pair. The benchmark
    fields (SPY and sector ETF) are None when the corresponding baseline or
    future price was not available at evaluation time; the excess-return
    fields are None whenever the matching benchmark return is None.
    """

    id: str  # UUID
    prediction_id: str  # id of the evaluated prediction snapshot
    evaluated_at: datetime  # when this outcome was computed
    horizon: str  # 1h, 6h, 1d, 7d, 30d
    future_price: float  # ticker close at or before generated_at + horizon
    future_return: float  # simple return of the ticker vs. price at prediction
    spy_future_price: float | None  # SPY close at or before the horizon endpoint
    spy_return: float | None  # simple return of SPY over the same window
    sector_etf_future_price: float | None  # sector ETF close at the endpoint
    sector_etf_return: float | None  # simple return of the sector ETF
    excess_return_vs_spy: float | None  # future_return - spy_return
    excess_return_vs_sector: float | None  # future_return - sector_etf_return
    direction_correct: bool  # predicted direction matched the return's sign
    profitable: bool  # predicted action matched the return's sign
    metadata: dict = field(default_factory=dict)  # serialized to JSON on insert


# ---------------------------------------------------------------------------
# SQL statements
# ---------------------------------------------------------------------------

# Find matured predictions: snapshots where generated_at + horizon_duration <= NOW()
# and no outcome has been recorded yet for that (prediction_id, horizon) pair.
# We evaluate ALL 5 horizons for each snapshot, not just the snapshot's own horizon.
# $1 is the horizon duration (interval), $2 is the horizon label.
_MATURED_PREDICTIONS_SQL = """
    SELECT ps.id, ps.generated_at, ps.ticker, ps.horizon AS snapshot_horizon,
           ps.direction, ps.action,
           ps.price_at_prediction,
           ps.spy_price_at_prediction,
           ps.sector_etf_price_at_prediction
    FROM prediction_snapshots ps
    WHERE ps.generated_at + $1::interval <= NOW()
      AND NOT EXISTS (
          SELECT 1 FROM prediction_outcomes po
          WHERE po.prediction_id = ps.id AND po.horizon = $2
      )
"""

# Fetch the close price for a ticker at or before a specific time.
# Uses the closest bar before or at the target time.
# NOTE(review): there is no lower bound on captured_at, so if bar ingestion
# lags, the "closest" bar may predate the prediction itself — confirm that
# this is acceptable for outcome scoring.
_CLOSE_AT_TIME_SQL = """
    SELECT (data->>'c')::float AS close
    FROM market_snapshots
    WHERE ticker = $1
      AND snapshot_type = 'bar'
      AND data->>'c' IS NOT NULL
      AND captured_at <= $2
    ORDER BY captured_at DESC
    LIMIT 1
"""

_INSERT_OUTCOME_SQL = """
    INSERT INTO prediction_outcomes (
        id, prediction_id, evaluated_at, horizon,
        future_price, future_return,
        spy_future_price, spy_return,
        sector_etf_future_price, sector_etf_return,
        excess_return_vs_spy, excess_return_vs_sector,
        direction_correct, profitable, metadata
    ) VALUES (
        $1::uuid, $2::uuid, $3, $4,
        $5, $6,
        $7, $8,
        $9, $10,
        $11, $12,
        $13, $14, $15::jsonb
    )
"""


# ---------------------------------------------------------------------------
# Price fetching at a specific time
# ---------------------------------------------------------------------------


async def _fetch_close_at_time(
    pool: asyncpg.Pool,
    ticker: str,
    target_time: datetime,
) -> float | None:
    """Fetch the close price for a ticker at or before a specific time.

    Args:
        pool: Connection pool used to run the lookup query.
        ticker: Symbol whose bar data is queried.
        target_time: Upper bound (inclusive) on the bar's capture time.

    Returns:
        The close of the most recent matching bar, or None if no market data
        is available at or before the target time.
    """
    row = await pool.fetchrow(_CLOSE_AT_TIME_SQL, ticker, target_time)
    if row is None:
        return None
    return row["close"]


# ---------------------------------------------------------------------------
# Sector ETF lookup (reuse pattern from prediction_snapshot)
# ---------------------------------------------------------------------------

# Company sector name -> ticker of the ETF used as that sector's benchmark.
_SECTOR_ETF_MAP: dict[str, str] = {
    "Technology": "XLK",
    "Consumer Cyclical": "XLY",
    "Financial Services": "XLF",
    "Healthcare": "XLV",
    "Energy": "XLE",
    "Communication Services": "XLC",
    "Industrials": "XLI",
    "Consumer Defensive": "XLP",
    "Real Estate": "XLRE",
    "Utilities": "XLU",
}

_COMPANY_SECTOR_SQL = """
    SELECT sector FROM companies WHERE ticker = $1 AND active = TRUE LIMIT 1
"""


async def _fetch_sector_etf_ticker(pool: asyncpg.Pool, ticker: str) -> str | None:
    """Look up the sector ETF ticker for a company ticker.

    Returns None when the company has no active row, its sector is NULL, or
    the sector has no entry in ``_SECTOR_ETF_MAP``.
    """
    row = await pool.fetchrow(_COMPANY_SECTOR_SQL, ticker)
    if row is None or row["sector"] is None:
        return None
    return _SECTOR_ETF_MAP.get(row["sector"])


# ---------------------------------------------------------------------------
# Return computation helpers
# ---------------------------------------------------------------------------


def _has_usable_price(price: float | None) -> bool:
    """Return True when *price* can serve as a return baseline (not None, not zero)."""
    return price is not None and price != 0.0


def _compute_return(current_price: float, future_price: float) -> float:
    """Compute simple return: (future - current) / current.

    Both prices are coerced to float first. asyncpg decodes NUMERIC columns
    as ``decimal.Decimal``, which cannot be mixed with float in arithmetic;
    the coercion keeps the computation working for either column type.

    Returns 0.0 when the current price is zero (the return is undefined).
    """
    base = float(current_price)
    if base == 0.0:
        return 0.0
    return (float(future_price) - base) / base


def _signal_matches_return(
    signal: str | None,
    *,
    up: str,
    down: str,
    future_return: float,
) -> bool:
    """Return True when *signal* (case-insensitive) agrees with the return's sign.

    ``up`` matches a strictly positive return and ``down`` a strictly negative
    one. A missing signal, any other label, or a zero return yields False.
    """
    if signal is None:
        # A NULL column would otherwise raise AttributeError on .lower().
        return False
    normalized = signal.lower()
    if normalized == up:
        return future_return > 0.0
    if normalized == down:
        return future_return < 0.0
    return False


def _is_direction_correct(direction: str | None, future_return: float) -> bool:
    """Determine if the predicted direction matches the realized return.

    bullish + positive return = True
    bearish + negative return = True
    All other combinations (including a missing direction) = False
    """
    return _signal_matches_return(
        direction, up="bullish", down="bearish", future_return=future_return
    )


def _is_profitable(action: str | None, future_return: float) -> bool:
    """Determine if the predicted action would have been profitable.

    buy + positive return = True
    sell + negative return = True
    All other combinations (including a missing action) = False
    """
    return _signal_matches_return(
        action, up="buy", down="sell", future_return=future_return
    )


# ---------------------------------------------------------------------------
# Single prediction evaluation (Requirements 4.2–4.7)
# ---------------------------------------------------------------------------


async def _fetch_benchmark_outcome(
    pool: asyncpg.Pool,
    benchmark_ticker: str | None,
    price_at_prediction: float | None,
    target_time: datetime,
) -> tuple[float | None, float | None]:
    """Fetch a benchmark's future price and compute its return.

    Returns ``(future_price, return)``, or ``(None, None)`` when the benchmark
    ticker is unknown, the baseline price is missing/zero, or no future price
    is available.
    """
    if benchmark_ticker is None or not _has_usable_price(price_at_prediction):
        return None, None
    future_price = await _fetch_close_at_time(pool, benchmark_ticker, target_time)
    if future_price is None:
        return None, None
    return future_price, _compute_return(price_at_prediction, future_price)


async def evaluate_single_prediction(
    pool: asyncpg.Pool,
    snapshot: dict,
    horizon: str,
) -> PredictionOutcome | None:
    """Evaluate a single prediction at a specific horizon.

    Fetches the future price at generated_at + horizon_duration for the
    ticker, SPY, and sector ETF. Computes returns, excess returns, direction
    correctness, and profitability.

    Args:
        pool: Connection pool used for all price and sector lookups.
        snapshot: Prediction snapshot row as a dict; the keys read are
            ``id``, ``generated_at``, ``ticker``, ``direction``, ``action``,
            ``price_at_prediction``, ``spy_price_at_prediction`` and
            ``sector_etf_price_at_prediction``.
        horizon: Key of ``HORIZON_DURATIONS`` to evaluate.

    Returns:
        The computed outcome, or None if the ticker's future price is
        unavailable (Requirement 4.10) or the snapshot's price at prediction
        is NULL or zero.

    Raises:
        KeyError: If ``horizon`` is not a key of ``HORIZON_DURATIONS``.
    """
    target_time = snapshot["generated_at"] + HORIZON_DURATIONS[horizon]
    ticker = snapshot["ticker"]

    # Fetch future price for the ticker — required (skip if unavailable)
    future_price = await _fetch_close_at_time(pool, ticker, target_time)
    if future_price is None:
        logger.debug(
            "Future price unavailable for %s at horizon %s (target %s), skipping",
            ticker,
            horizon,
            target_time,
        )
        return None

    raw_price_at_prediction = snapshot["price_at_prediction"]
    if not _has_usable_price(raw_price_at_prediction):
        logger.warning(
            "Price at prediction is NULL or zero for snapshot %s, skipping horizon %s",
            snapshot["id"],
            horizon,
        )
        return None
    # Normalize once so the value is also JSON-serializable in the metadata
    # (a Decimal would make json.dumps fail when the outcome is stored).
    price_at_prediction = float(raw_price_at_prediction)

    # Compute ticker future return (Requirement 4.2)
    future_return = _compute_return(price_at_prediction, future_price)

    # Fetch SPY future price and compute SPY return (Requirement 4.3)
    spy_future_price, spy_return = await _fetch_benchmark_outcome(
        pool, "SPY", snapshot["spy_price_at_prediction"], target_time
    )

    # Fetch sector ETF future price and compute sector return (Requirement 4.4)
    sector_baseline = snapshot["sector_etf_price_at_prediction"]
    sector_etf_ticker: str | None = None
    if _has_usable_price(sector_baseline):
        # Only resolve the sector ETF when there is a baseline to compare to.
        sector_etf_ticker = await _fetch_sector_etf_ticker(pool, ticker)
    sector_etf_future_price, sector_etf_return = await _fetch_benchmark_outcome(
        pool, sector_etf_ticker, sector_baseline, target_time
    )

    # Compute excess returns (Requirement 4.5)
    excess_return_vs_spy = (
        future_return - spy_return if spy_return is not None else None
    )
    excess_return_vs_sector = (
        future_return - sector_etf_return if sector_etf_return is not None else None
    )

    return PredictionOutcome(
        id=str(uuid.uuid4()),
        prediction_id=str(snapshot["id"]),
        evaluated_at=datetime.now().astimezone(),  # timezone-aware local time
        horizon=horizon,
        future_price=future_price,
        future_return=future_return,
        spy_future_price=spy_future_price,
        spy_return=spy_return,
        sector_etf_future_price=sector_etf_future_price,
        sector_etf_return=sector_etf_return,
        excess_return_vs_spy=excess_return_vs_spy,
        excess_return_vs_sector=excess_return_vs_sector,
        # Determine direction correctness (Requirement 4.6)
        direction_correct=_is_direction_correct(snapshot["direction"], future_return),
        # Determine profitability (Requirement 4.7)
        profitable=_is_profitable(snapshot["action"], future_return),
        metadata={
            "ticker": ticker,
            "horizon": horizon,
            "price_at_prediction": price_at_prediction,
            "future_price": future_price,
        },
    )


# ---------------------------------------------------------------------------
# Store outcome (Requirement 4.9)
# ---------------------------------------------------------------------------


async def _store_outcome(
    conn: asyncpg.Connection,
    outcome: PredictionOutcome,
) -> None:
    """Persist a single prediction outcome to the database.

    The argument order below must match the $1..$15 placeholders of
    ``_INSERT_OUTCOME_SQL``; metadata is sent as a JSON string and cast to
    jsonb by the statement.
    """
    await conn.execute(
        _INSERT_OUTCOME_SQL,
        outcome.id,
        outcome.prediction_id,
        outcome.evaluated_at,
        outcome.horizon,
        outcome.future_price,
        outcome.future_return,
        outcome.spy_future_price,
        outcome.spy_return,
        outcome.sector_etf_future_price,
        outcome.sector_etf_return,
        outcome.excess_return_vs_spy,
        outcome.excess_return_vs_sector,
        outcome.direction_correct,
        outcome.profitable,
        json.dumps(outcome.metadata),
    )


# ---------------------------------------------------------------------------
# Main entry point (Requirements 4.1, 4.8, 4.9, 4.10)
# ---------------------------------------------------------------------------


async def evaluate_matured_predictions(
    pool: asyncpg.Pool,
) -> int:
    """Evaluate all matured prediction snapshots across all horizons.

    For each of the 5 horizons (1h, 6h, 1d, 7d, 30d), finds prediction
    snapshots where generated_at + horizon_duration <= NOW() and no outcome
    has been recorded for that (prediction_id, horizon) pair.

    For each matured snapshot-horizon pair, fetches future prices and
    computes returns. Skips horizons where the future price is unavailable —
    those will be retried on the next run (Requirement 4.10). A failure on
    one snapshot is logged and does not stop the remaining evaluations.

    Returns the total count of outcomes recorded.
    """
    total_recorded = 0

    for horizon, duration in HORIZON_DURATIONS.items():
        # Find snapshots matured for this horizon
        rows = await pool.fetch(_MATURED_PREDICTIONS_SQL, duration, horizon)
        if not rows:
            continue

        logger.info(
            "Found %d matured predictions for horizon %s", len(rows), horizon
        )

        for row in rows:
            snapshot = dict(row)
            # Top-level boundary: isolate per-snapshot failures so one bad
            # row cannot abort the whole evaluation run.
            try:
                outcome = await evaluate_single_prediction(pool, snapshot, horizon)
                if outcome is None:
                    # Future price unavailable — skip, retry next run
                    continue

                async with pool.acquire() as conn:
                    async with conn.transaction():
                        await _store_outcome(conn, outcome)
                total_recorded += 1
            except Exception:
                logger.exception(
                    "Failed to evaluate snapshot %s at horizon %s",
                    snapshot["id"],
                    horizon,
                )
                continue

    logger.info("Outcome evaluation complete: %d outcomes recorded", total_recorded)
    return total_recorded