feat: model validation, calibration, and signal quality layer
ci/woodpecker/push/test Pipeline failed
ci/woodpecker/push/build-1 unknown status
ci/woodpecker/push/build-3 unknown status
ci/woodpecker/push/build-2 unknown status
ci/woodpecker/push/finalize unknown status
Build and Push / lint-and-test (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.adapters.broker_adapter name:broker-adapter]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.aggregation.worker name:aggregation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.extractor.worker name:extractor]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.ingestion.worker name:ingestion]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.lake_publisher.worker name:lake-publisher]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.parser.worker name:parser]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.recommendation.worker name:recommendation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.scheduler.app name:scheduler]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.api.app:app --host 0.0.0.0 --port 8000 name:query-api]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.risk.app:app --host 0.0.0.0 --port 8000 name:risk]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.symbol_registry.app:app --host 0.0.0.0 --port 8000 name:symbol-registry]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.trading.app:app --host 0.0.0.0 --port 8000 name:trading-engine]) (push) Has been cancelled
Build and Push / build-dashboard (push) Has been cancelled
Build and Push / build-superset (push) Has been cancelled
Build and Push / integration-test (push) Has been cancelled
Build and Push / beta-gate (push) Has been cancelled

- Migration 035: prediction_snapshots, prediction_outcomes, signal_evidence_links, model_metric_snapshots tables + SQL views
- Prediction snapshot writer with canonical evidence keys, duplicate detection, contribution scores
- Outcome evaluator across 5 horizons (1h, 6h, 1d, 7d, 30d)
- Metrics engine: ECE, Brier score, IC, Rank IC, benchmark comparison
- Attribution engine: per-source, per-catalyst, per-layer performance
- Calibration engine: Bayesian shrinkage source reliability
- Quality gate for live trading eligibility with configurable thresholds
- 7 new /api/validation/* endpoints
- Upgraded OpsModel dashboard with validation tab
- Enhanced recommendation display with calibration context
- Backtest replay validation mode
- 86 Python tests (unit + property-based), 179 frontend tests passing
This commit is contained in:
Celes Renata
2026-05-01 03:04:58 +00:00
parent 5d2ffd9163
commit 7fcc8a6c07
23 changed files with 7554 additions and 9 deletions
+40 -1
View File
@@ -64,6 +64,7 @@ from services.shared.metrics import (
AGGREGATION_WINDOWS_COMPUTED,
)
from services.shared.schemas import TrendDirection, TrendSummary, TrendWindow
from services.trading.model_quality_gate import QualityGateResult, evaluate_quality_gate
logger = logging.getLogger(__name__)
@@ -1576,10 +1577,34 @@ async def aggregate_company(
# Mid-cycle changes take effect on the next cycle.
probabilistic = await fetch_probabilistic_scoring_enabled(pool)
pipeline_mode = "probabilistic" if probabilistic else "heuristic"
# --- Quality gate evaluation (Req 11.2, 11.3) ---
# Evaluate model quality gate at the start of each aggregation cycle.
# When the gate fails, all recommendations are forced to paper mode.
# Gate evaluation failure defaults to paper-only (fail-safe).
quality_gate_passed = False
try:
gate_result: QualityGateResult = await evaluate_quality_gate(pool)
quality_gate_passed = gate_result.passed
logger.info(
"Quality gate for %s cycle: %s%s",
ticker,
"PASS" if gate_result.passed else "FAIL",
gate_result.reason,
)
except Exception:
logger.exception(
"Quality gate evaluation failed for %s cycle — "
"defaulting to paper-only mode (fail-safe)",
ticker,
)
quality_gate_passed = False
logger.info(
"Aggregation cycle for %s: pipeline_mode=%s",
"Aggregation cycle for %s: pipeline_mode=%s quality_gate=%s",
ticker,
pipeline_mode,
"passed" if quality_gate_passed else "failed",
)
# --- Regime detection (Req 7.1, 7.2, 7.3, 7.8, 7.9) ---
@@ -1647,6 +1672,20 @@ async def aggregate_company(
ticker_returns=ticker_returns,
ticker_volumes=ticker_volumes,
)
# When quality gate fails, annotate the trend summary so the
# recommendation engine forces paper mode (Req 11.2, 11.3).
if not quality_gate_passed:
ctx = summary.market_context
if isinstance(ctx, dict):
ctx["quality_gate_passed"] = False
elif ctx is not None and hasattr(ctx, "model_dump"):
ctx_dict = ctx.model_dump()
ctx_dict["quality_gate_passed"] = False
summary.market_context = ctx_dict
else:
summary.market_context = {"quality_gate_passed": False}
summaries.append(summary)
return summaries
+338
View File
@@ -43,6 +43,11 @@ from services.shared.db import get_pg_pool, get_redis
from services.shared.logging import new_trace_id, set_trace_context, setup_logging
from services.shared.redis_keys import PIPELINE_ENABLED_KEY, QUEUE_BROKER, QUEUE_PREFIX, queue_key
from services.shared.schemas import MAJOR_DECISION_CATALYSTS
from services.validation.attribution import (
compute_catalyst_attribution,
compute_layer_attribution,
compute_source_attribution,
)
logger = logging.getLogger("query_api")
@@ -3769,3 +3774,336 @@ async def get_variant_performance_history(
agent_id, variant_id, hours,
)
return [_row_to_dict(r) for r in rows]
# ---------------------------------------------------------------------------
# Model Validation Dashboard (Requirements 12.1, 12.2, 12.3, 12.7)
# ---------------------------------------------------------------------------
# Accepted values for the ``lookback`` query parameter of the
# /api/validation/* endpoints; anything else is rejected with HTTP 400.
_VALID_LOOKBACKS = {"7d", "30d", "90d", "all"}
# Accepted values for the ``horizon`` query parameter of the
# /api/validation/* endpoints; anything else is rejected with HTTP 400.
_VALID_HORIZONS = {"1h", "6h", "1d", "7d", "30d"}
@app.get("/api/validation/summary")
async def get_validation_summary(
    lookback: str = Query(default="30d"),
    horizon: str = Query(default="7d"),
):
    """Return the newest metric snapshot together with the stored gate status.

    Looks up the most recently generated row of ``model_metric_snapshots``
    for the requested lookback/horizon pair and the ``model_quality_gate``
    entry of ``risk_configs``.  Either part of the response is ``None``
    when the corresponding row does not exist.

    Raises:
        HTTPException: 400 when *lookback* or *horizon* is not an accepted value.

    Requirement 12.1
    """
    if lookback not in _VALID_LOOKBACKS:
        raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")
    if horizon not in _VALID_HORIZONS:
        raise HTTPException(400, f"Invalid horizon: {horizon}. Must be one of {sorted(_VALID_HORIZONS)}")

    # Newest snapshot for the requested lookback/horizon pair.
    latest = await pool.fetchrow(
        """SELECT id, generated_at, lookback_window, horizon,
prediction_count, win_rate, directional_accuracy,
information_coefficient, rank_information_coefficient,
avg_return, avg_excess_return_vs_spy, avg_excess_return_vs_sector,
calibration_error, brier_score,
buy_win_rate, sell_win_rate, hold_win_rate,
metadata
FROM model_metric_snapshots
WHERE lookback_window = $1 AND horizon = $2
ORDER BY generated_at DESC
LIMIT 1""",
        lookback, horizon,
    )
    if latest:
        snapshot = _row_to_dict(latest)
        # The metadata column is JSONB; decode it for the response.
        snapshot["metadata"] = _parse_jsonb(snapshot.get("metadata"))
    else:
        snapshot = None

    # Stored gate evaluation, if one has been written yet.
    gate_record = await pool.fetchrow(
        "SELECT config, updated_at FROM risk_configs WHERE name = 'model_quality_gate'",
    )
    gate_status = _parse_jsonb(gate_record["config"]) if gate_record else None

    return {"snapshot": snapshot, "gate_status": gate_status}
def _calibration_bucket_table(rows) -> list:
    """Summarise prediction rows into the fixed confidence buckets.

    Each row must expose ``confidence`` and ``direction_correct``.  Rows are
    assigned to their bucket in a single pass; confidences outside
    [0.50, 1.00] belong to no bucket and are ignored.  The top bucket is
    closed on both ends, all others are half-open ``[low, high)``.

    Returns one dict per bucket, in ascending order, with the average
    confidence, the observed win rate, the row count and a ``miscalibrated``
    flag set when the two rates differ by more than 0.15.
    """
    bucket_bounds = [
        (0.50, 0.60),
        (0.60, 0.70),
        (0.70, 0.80),
        (0.80, 0.90),
        (0.90, 1.00),
    ]
    members = [[] for _ in bucket_bounds]
    for row in rows:
        conf = float(row["confidence"])
        for slot, (low, high) in enumerate(bucket_bounds):
            below_top = conf <= high if high == 1.00 else conf < high
            if low <= conf and below_top:
                # Only an explicit True counts as a win (None/False do not).
                members[slot].append((conf, row["direction_correct"] is True))
                break

    table = []
    for (low, high), hits in zip(bucket_bounds, members):
        count = len(hits)
        if count == 0:
            table.append({
                "bucket_low": low,
                "bucket_high": high,
                "avg_confidence": 0.0,
                "observed_win_rate": 0.0,
                "prediction_count": 0,
                "miscalibrated": False,
            })
            continue
        avg_conf = sum(conf for conf, _ in hits) / count
        win_rate = sum(1 for _, won in hits if won) / count
        table.append({
            "bucket_low": low,
            "bucket_high": high,
            "avg_confidence": round(avg_conf, 4),
            "observed_win_rate": round(win_rate, 4),
            "prediction_count": count,
            "miscalibrated": abs(avg_conf - win_rate) > 0.15,
        })
    return table


@app.get("/api/validation/calibration")
async def get_validation_calibration(
    lookback: str = Query(default="30d"),
    horizon: str = Query(default="7d"),
):
    """Calibration table with confidence buckets.

    Queries v_prediction_performance for the given lookback/horizon,
    groups the rows by confidence bucket, and reports avg_confidence,
    observed_win_rate, prediction_count, and a miscalibrated flag per bucket.
    A lookback of "all" applies no time filter.

    Raises:
        HTTPException: 400 when *lookback* or *horizon* is not an accepted value.

    Requirement 12.2
    """
    if lookback not in _VALID_LOOKBACKS:
        raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")
    if horizon not in _VALID_HORIZONS:
        raise HTTPException(400, f"Invalid horizon: {horizon}. Must be one of {sorted(_VALID_HORIZONS)}")

    # Only a placeholder is ever interpolated into the SQL text; the values
    # themselves are always bound as query parameters.
    params: list[Any] = [horizon]
    lookback_condition = ""
    if lookback != "all":
        lookback_condition = "AND generated_at >= NOW() - make_interval(days => $2)"
        params.append({"7d": 7, "30d": 30, "90d": 90}[lookback])

    rows = await pool.fetch(
        f"""SELECT confidence, direction_correct
FROM v_prediction_performance
WHERE horizon = $1
{lookback_condition}
AND confidence IS NOT NULL""",
        *params,
    )

    return {
        "buckets": _calibration_bucket_table(rows),
        "lookback": lookback,
        "horizon": horizon,
    }
@app.get("/api/validation/ic-by-horizon")
async def get_validation_ic_by_horizon(
    lookback: str = Query(default="30d"),
):
    """Report IC and Rank IC for each prediction horizon.

    Takes the newest ``model_metric_snapshots`` row per horizon for the
    given lookback and returns the entries ordered 1h, 6h, 1d, 7d, 30d
    (unrecognised horizons sort last).  Missing coefficients and timestamps
    are returned as ``None``.

    Raises:
        HTTPException: 400 when *lookback* is not an accepted value.

    Requirement 12.3
    """
    if lookback not in _VALID_LOOKBACKS:
        raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")

    rows = await pool.fetch(
        """SELECT DISTINCT ON (horizon)
horizon,
information_coefficient,
rank_information_coefficient,
prediction_count,
generated_at
FROM model_metric_snapshots
WHERE lookback_window = $1
ORDER BY horizon, generated_at DESC""",
        lookback,
    )

    def _optional_float(value):
        return float(value) if value is not None else None

    # Canonical display order of the horizons.
    canonical_rank = {"1h": 0, "6h": 1, "1d": 2, "7d": 3, "30d": 4}
    horizons = sorted(
        (
            {
                "horizon": rec["horizon"],
                "information_coefficient": _optional_float(rec["information_coefficient"]),
                "rank_information_coefficient": _optional_float(rec["rank_information_coefficient"]),
                "prediction_count": rec["prediction_count"],
                "generated_at": rec["generated_at"].isoformat() if rec["generated_at"] else None,
            }
            for rec in rows
        ),
        key=lambda entry: canonical_rank.get(entry["horizon"], 99),
    )

    return {"horizons": horizons, "lookback": lookback}
@app.get("/api/validation/gate-status")
async def get_validation_gate_status():
    """Return the stored quality gate evaluation.

    Reads the ``risk_configs`` row whose ``name`` is ``'model_quality_gate'``.
    When the row exists, the decoded config and its ``updated_at`` timestamp
    (ISO 8601, or ``None``) are returned; otherwise ``gate_status`` is
    ``None`` and an explanatory message is included instead.

    Requirement 12.7
    """
    stored = await pool.fetchrow(
        "SELECT config, updated_at FROM risk_configs WHERE name = 'model_quality_gate'",
    )

    if stored:
        gate_data = _parse_jsonb(stored["config"])
        refreshed_at = stored["updated_at"].isoformat() if stored.get("updated_at") else None
        return {
            "gate_status": gate_data,
            "updated_at": refreshed_at,
        }

    return {
        "gate_status": None,
        "message": "No gate evaluation found. Model metrics may not have been computed yet.",
    }
# ---------------------------------------------------------------------------
# Attribution Endpoints (Requirements 12.4, 12.5, 12.6)
# ---------------------------------------------------------------------------
# Translates the ``lookback`` query value into the ``lookback_days`` argument
# passed to the attribution functions.
# NOTE(review): "all" is approximated as 3650 days (~10 years) here, whereas
# the calibration endpoint applies no time filter at all for "all" — confirm
# the difference is intended.
_LOOKBACK_TO_DAYS: dict[str, int] = {
    "7d": 7,
    "30d": 30,
    "90d": 90,
    "all": 3650,
}
@app.get("/api/validation/attribution/sources")
async def get_validation_attribution_sources(
    lookback: str = Query(default="30d"),
    horizon: str = Query(default="7d"),
):
    """Per-source performance metrics.

    Delegates to ``compute_source_attribution`` for the given lookback
    window and prediction horizon and returns each result as a plain dict
    under ``sources``, echoing the request parameters.

    Raises:
        HTTPException: 400 when *lookback* or *horizon* is not an accepted
            value; 500 when the attribution computation fails (the original
            error is logged and chained as the cause).

    Requirement 12.4
    """
    if lookback not in _VALID_LOOKBACKS:
        raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")
    if horizon not in _VALID_HORIZONS:
        raise HTTPException(400, f"Invalid horizon: {horizon}. Must be one of {sorted(_VALID_HORIZONS)}")

    lookback_days = _LOOKBACK_TO_DAYS[lookback]
    try:
        results = await compute_source_attribution(pool, lookback_days=lookback_days, horizon=horizon)
    except Exception as exc:
        logger.exception("Failed to compute source attribution")
        # Chain explicitly so the 500 is marked as caused by the real error.
        raise HTTPException(500, "Failed to compute source attribution") from exc

    return {
        "sources": [asdict(r) for r in results],
        "lookback": lookback,
        "horizon": horizon,
    }
@app.get("/api/validation/attribution/catalysts")
async def get_validation_attribution_catalysts(
    lookback: str = Query(default="30d"),
    horizon: str = Query(default="7d"),
):
    """Per-catalyst-type performance metrics.

    Delegates to ``compute_catalyst_attribution`` for the given lookback
    window and prediction horizon and returns each result as a plain dict
    under ``catalysts``, echoing the request parameters.

    Raises:
        HTTPException: 400 when *lookback* or *horizon* is not an accepted
            value; 500 when the attribution computation fails (the original
            error is logged and chained as the cause).

    Requirement 12.5
    """
    if lookback not in _VALID_LOOKBACKS:
        raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")
    if horizon not in _VALID_HORIZONS:
        raise HTTPException(400, f"Invalid horizon: {horizon}. Must be one of {sorted(_VALID_HORIZONS)}")

    lookback_days = _LOOKBACK_TO_DAYS[lookback]
    try:
        results = await compute_catalyst_attribution(pool, lookback_days=lookback_days, horizon=horizon)
    except Exception as exc:
        logger.exception("Failed to compute catalyst attribution")
        # Chain explicitly so the 500 is marked as caused by the real error.
        raise HTTPException(500, "Failed to compute catalyst attribution") from exc

    return {
        "catalysts": [asdict(r) for r in results],
        "lookback": lookback,
        "horizon": horizon,
    }
@app.get("/api/validation/attribution/layers")
async def get_validation_attribution_layers(
    lookback: str = Query(default="30d"),
    horizon: str = Query(default="7d"),
):
    """Per-signal-layer (company, macro, competitive) performance metrics.

    Delegates to ``compute_layer_attribution`` for the given lookback
    window and prediction horizon and returns each result as a plain dict
    under ``layers``, echoing the request parameters.

    Raises:
        HTTPException: 400 when *lookback* or *horizon* is not an accepted
            value; 500 when the attribution computation fails (the original
            error is logged and chained as the cause).

    Requirement 12.6
    """
    if lookback not in _VALID_LOOKBACKS:
        raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")
    if horizon not in _VALID_HORIZONS:
        raise HTTPException(400, f"Invalid horizon: {horizon}. Must be one of {sorted(_VALID_HORIZONS)}")

    lookback_days = _LOOKBACK_TO_DAYS[lookback]
    try:
        results = await compute_layer_attribution(pool, lookback_days=lookback_days, horizon=horizon)
    except Exception as exc:
        logger.exception("Failed to compute layer attribution")
        # Chain explicitly so the 500 is marked as caused by the real error.
        raise HTTPException(500, "Failed to compute layer attribution") from exc

    return {
        "layers": [asdict(r) for r in results],
        "lookback": lookback,
        "horizon": horizon,
    }
+103
View File
@@ -48,6 +48,7 @@ from services.shared.schemas import (
TrendSummary,
TrendWindow,
)
from services.validation.prediction_snapshot import create_prediction_snapshot
logger = logging.getLogger(__name__)
@@ -741,6 +742,92 @@ def _map_time_horizon_prefix(window: str) -> str:
return mapping.get(window, "window_")
# ---------------------------------------------------------------------------
# Fetch evidence signals and docs for prediction snapshot (Requirement 1.1)
# ---------------------------------------------------------------------------
# One row per impact record attached to the given document UUIDs ($1), joined
# to its intelligence record and source document.  Only intelligence rows with
# validation_status = 'valid' are returned.
# NOTE(review): ``source`` and ``source_type`` both select d.source_type, so
# the two output columns always carry the same value — confirm whether
# ``source`` was meant to come from a different column.
_EVIDENCE_SIGNALS_QUERY = """
SELECT
dir.document_id::text AS document_id,
di.id::text AS signal_id,
dir.ticker,
d.source_type AS source,
d.source_type,
dir.catalyst_type,
dir.sentiment,
dir.impact_score AS impact,
di.confidence AS extraction_confidence,
di.source_credibility AS weight
FROM document_impact_records dir
JOIN document_intelligence di ON di.id = dir.intelligence_id
JOIN documents d ON d.id = di.document_id
WHERE dir.document_id = ANY($1::uuid[])
AND di.validation_status = 'valid'
"""
# Title and URL (empty string when NULL) for the given document UUIDs ($1).
_EVIDENCE_DOCS_QUERY = """
SELECT
d.id::text AS document_id,
COALESCE(d.title, '') AS title,
COALESCE(d.url, '') AS url
FROM documents d
WHERE d.id = ANY($1::uuid[])
"""
async def _fetch_evidence_for_snapshot(
pool: asyncpg.Pool,
document_ids: list[str],
) -> tuple[list[dict], list[dict]]:
"""Fetch evidence signals and document metadata for prediction snapshot.
Filters out non-UUID document IDs (e.g. synthetic pattern IDs) since
they cannot be looked up in the documents table.
Returns (evidence_signals, evidence_docs).
"""
# Filter to valid UUIDs only
valid_ids: list[str] = []
for doc_id in document_ids:
try:
_uuid.UUID(doc_id)
valid_ids.append(doc_id)
except (ValueError, AttributeError):
continue
if not valid_ids:
return [], []
signal_rows = await pool.fetch(_EVIDENCE_SIGNALS_QUERY, valid_ids)
evidence_signals = [
{
"document_id": row["document_id"],
"signal_id": row["signal_id"],
"ticker": row["ticker"] or "",
"source": row["source"] or "",
"source_type": row["source_type"] or "",
"catalyst_type": row["catalyst_type"] or "",
"sentiment": row["sentiment"] or "",
"impact": float(row["impact"] or 0.0),
"extraction_confidence": float(row["extraction_confidence"] or 0.0),
"weight": float(row["weight"] or 0.0),
}
for row in signal_rows
]
doc_rows = await pool.fetch(_EVIDENCE_DOCS_QUERY, valid_ids)
evidence_docs = [
{
"document_id": row["document_id"],
"title": row["title"],
"url": row["url"],
}
for row in doc_rows
]
return evidence_signals, evidence_docs
async def generate_recommendation(
pool: asyncpg.Pool,
ticker: str,
@@ -847,6 +934,22 @@ async def generate_recommendation(
eligibility_result=result,
)
# 7b. Capture prediction snapshot for model validation (Requirements 1.1, 1.6)
try:
all_doc_ids = list(summary.top_supporting_evidence) + list(summary.top_opposing_evidence)
evidence_signals, evidence_docs = await _fetch_evidence_for_snapshot(
pool, all_doc_ids,
)
await create_prediction_snapshot(
pool, rec, summary, evidence_signals, evidence_docs,
)
except Exception:
logger.warning(
"Failed to create prediction snapshot for %s/%s — recommendation "
"persisted but snapshot creation failed",
ticker, rec_id, exc_info=True,
)
# 8. Publish prediction facts to analytical tables (Requirement 9.4)
if minio_client is not None:
try:
+243 -1
View File
@@ -4,6 +4,10 @@ Task 32: Fetches historical recommendations from the database, simulates
the decision logic chronologically using evaluate_recommendation(), tracks
simulated positions and equity curve, and persists results to backtest_runs
and backtest_trades tables.
Supports a validation mode (Requirements 15.1–15.5) that generates prediction
snapshots and evaluates outcomes using only data available at each historical
point in time, preventing future data leakage.
"""
from __future__ import annotations
@@ -39,12 +43,22 @@ class BacktestReplay:
self.pool = pool
self._perf = PerformanceComputer()
async def run(self, config: BacktestConfig, backtest_id: str | None = None) -> BacktestResult:
async def run(
self,
config: BacktestConfig,
backtest_id: str | None = None,
validation_mode: bool = False,
) -> BacktestResult:
"""Execute a full backtest replay.
Args:
config: Backtest configuration (date range, capital, risk tier).
backtest_id: Optional pre-generated ID. If not provided, one is generated.
validation_mode: When True, creates prediction snapshots for each
historical recommendation using only data available at that point
in time, evaluates outcomes, and computes model metrics over the
backtest period. Snapshots are tagged with the backtest_id.
(Requirements 15.1–15.5)
Returns:
BacktestResult with metrics, trade log, and equity curve.
@@ -87,6 +101,7 @@ class BacktestReplay:
daily_returns: list[float] = []
prev_value = config.initial_capital
trade_log: list[dict] = []
validation_snapshot_ids: list[str] = [] # track snapshot IDs for validation mode
# Pre-load company sectors and latest prices for enrichment
company_sectors: dict[str, str] = {}
@@ -172,6 +187,25 @@ class BacktestReplay:
now=sim_time,
)
# --- Validation mode: create prediction snapshot (Req 15.1, 15.2, 15.4) ---
if validation_mode and self.pool is not None:
try:
snapshot_id = await self._create_validation_snapshot(
rec=rec,
sim_time=sim_time,
backtest_id=backtest_id,
company_sectors=company_sectors,
)
if snapshot_id is not None:
validation_snapshot_ids.append(snapshot_id)
except Exception:
logger.warning(
"Validation snapshot failed for %s at %s, continuing backtest",
rec.get("ticker", "?"),
sim_time,
exc_info=True,
)
if decision.decision == "act":
act_count += 1
ticker = decision.ticker
@@ -348,6 +382,10 @@ class BacktestReplay:
# Persist results
await self._persist_results(result, closed_trades)
# --- Validation mode: evaluate outcomes and compute metrics (Req 15.3, 15.5) ---
if validation_mode and self.pool is not None and validation_snapshot_ids:
await self._run_validation_evaluation(backtest_id)
return result
except Exception as exc:
@@ -356,6 +394,210 @@ class BacktestReplay:
await self._persist_failed_run(backtest_id, config, str(exc))
raise
# ------------------------------------------------------------------
# Validation mode helpers (Requirements 15.1–15.5)
# ------------------------------------------------------------------
# SQL to fetch the close price at or before a specific time — prevents
# future data leakage by only returning data available at that point.
# Most recent 'bar' close for ticker $1 captured at or before time $2.
# Rows without a close value (data->>'c') are skipped.
_CLOSE_AT_TIME_SQL = """
SELECT (data->>'c')::float AS close
FROM market_snapshots
WHERE ticker = $1
AND snapshot_type = 'bar'
AND data->>'c' IS NOT NULL
AND captured_at <= $2
ORDER BY captured_at DESC
LIMIT 1
"""
# Sector of an active company by ticker.
# NOTE(review): not referenced by the validation helpers visible here —
# confirm whether it is still needed.
_COMPANY_SECTOR_SQL = """
SELECT sector FROM companies WHERE ticker = $1 AND active = TRUE LIMIT 1
"""
# Sector name -> sector ETF ticker.
# NOTE(review): _create_validation_snapshot imports SECTOR_ETF_MAP from
# services.validation.prediction_snapshot instead of using this mapping —
# confirm which one is authoritative.
_SECTOR_ETF_MAP: dict[str, str] = {
    "Technology": "XLK",
    "Consumer Cyclical": "XLY",
    "Financial Services": "XLF",
    "Healthcare": "XLV",
    "Energy": "XLE",
    "Communication Services": "XLC",
    "Industrials": "XLI",
    "Consumer Defensive": "XLP",
    "Real Estate": "XLRE",
    "Utilities": "XLU",
}
async def _fetch_close_at_time(
self,
ticker: str,
target_time: datetime,
) -> float | None:
"""Fetch the close price for *ticker* at or before *target_time*.
Ensures no future data leakage — only market data with
``captured_at <= target_time`` is considered (Requirement 15.4).
"""
if self.pool is None:
return None
row = await self.pool.fetchrow(self._CLOSE_AT_TIME_SQL, ticker, target_time)
if row is None:
return None
return row["close"]
async def _create_validation_snapshot(
    self,
    rec: dict,
    sim_time: datetime,
    backtest_id: str,
    company_sectors: dict[str, str],
) -> str | None:
    """Create a prediction snapshot using only data available at *sim_time*.

    Fetches ticker, SPY, and sector ETF prices as of *sim_time* to prevent
    future data leakage (Requirements 15.1, 15.2, 15.4). The snapshot is
    tagged with *backtest_id* in its metadata field (Requirement 15.5).

    Recommendation fields that are absent **or** ``None`` fall back to their
    defaults, so a NULL column neither aborts the snapshot nor gets stored
    as the literal string ``"None"``.

    Args:
        rec: Historical recommendation row as a dict.
        sim_time: Simulated "now"; used as ``generated_at`` and as the
            cut-off for every price lookup.
        backtest_id: Backtest run the snapshot belongs to.
        company_sectors: Ticker -> sector mapping used to pick the sector ETF.

    Returns:
        The new snapshot UUID as a string, or ``None`` when *rec* has no ticker.

    Raises:
        Exception: Any database or conversion error propagates; the caller
            is expected to log it and continue.
    """
    from services.validation.prediction_snapshot import (
        SECTOR_ETF_MAP,
    )

    def _pick(*keys: str, default):
        """Return the first non-None value stored under *keys*, else *default*."""
        for key in keys:
            value = rec.get(key)
            if value is not None:
                return value
        return default

    ticker = rec.get("ticker", "")
    if not ticker:
        return None

    # Fetch prices using only data available at sim_time (Req 15.4)
    ticker_price = await self._fetch_close_at_time(ticker, sim_time)
    spy_price = await self._fetch_close_at_time("SPY", sim_time)

    # Sector ETF price
    sector = company_sectors.get(ticker)
    sector_etf_ticker = SECTOR_ETF_MAP.get(sector) if sector else None
    sector_etf_price: float | None = None
    if sector_etf_ticker is not None:
        sector_etf_price = await self._fetch_close_at_time(
            sector_etf_ticker, sim_time
        )

    snapshot_id = str(uuid.uuid4())

    # Build metadata tagged with backtest_id (Req 15.5)
    metadata: dict = {
        "backtest_id": backtest_id,
        "source": "backtest_validation",
    }

    # Map recommendation fields to snapshot columns
    direction = _pick("direction", "trend_direction", default="neutral")
    action = _pick("action", default="watch")
    mode = _pick("mode", default="informational")
    confidence = float(_pick("confidence", default=0.5))
    strength = float(_pick("strength", "trend_strength", default=0.5))
    contradiction = float(_pick("contradiction", "contradiction_score", default=0.0))
    p_bull = rec.get("p_bull")
    if p_bull is not None:
        p_bull = float(p_bull)
    p_bear = (1.0 - p_bull) if p_bull is not None else None
    window = _pick("window", "trend_window", default="7d")
    horizon = _pick("time_horizon", "horizon", default="7d")

    # Insert the snapshot directly — we bypass create_prediction_snapshot()
    # because that function fetches *latest* prices (not point-in-time).
    # "window" is quoted because WINDOW is a reserved word in PostgreSQL.
    insert_sql = """
INSERT INTO prediction_snapshots (
id, generated_at, ticker, "window", horizon, direction, action, mode,
strength, confidence, contradiction, p_bull, p_bear,
score_company, score_macro, score_competitive,
evidence_count, unique_source_count, duplicate_evidence_count,
price_at_prediction, spy_price_at_prediction,
sector_etf_price_at_prediction, metadata
) VALUES (
$1::uuid, $2, $3, $4, $5, $6, $7, $8,
$9, $10, $11, $12, $13,
$14, $15, $16,
$17, $18, $19,
$20, $21, $22,
$23::jsonb
)
"""
    await self.pool.execute(
        insert_sql,
        snapshot_id,
        sim_time,
        ticker,
        str(window),
        str(horizon),
        str(direction),
        str(action),
        str(mode),
        strength,
        confidence,
        contradiction,
        p_bull,
        p_bear,
        float(_pick("score_company", default=0.0)),
        float(_pick("score_macro", default=0.0)),
        float(_pick("score_competitive", default=0.0)),
        int(_pick("evidence_count", default=0)),
        int(_pick("unique_source_count", default=0)),
        int(_pick("duplicate_evidence_count", default=0)),
        ticker_price,
        spy_price,
        sector_etf_price,
        json.dumps(metadata),
    )

    logger.debug(
        "Validation snapshot %s created for %s at %s (backtest %s)",
        snapshot_id,
        ticker,
        sim_time,
        backtest_id,
    )
    return snapshot_id
async def _run_validation_evaluation(self, backtest_id: str) -> None:
    """Run outcome evaluation and metric computation after a backtest.

    Invokes the outcome evaluator and then the metrics engine
    (Requirements 15.3, 15.5).  Each step is isolated: a failure is logged
    as a warning and the next step still runs, so the backtest result is
    never blocked.

    NOTE(review): both calls receive only the pool, not *backtest_id* —
    presumably they operate on every stored prediction rather than just
    this backtest's snapshots; confirm against the validation package.
    """
    from services.validation.outcome_evaluator import evaluate_matured_predictions
    from services.validation.metrics import compute_and_store_metric_snapshots

    # Step 1: Evaluate matured predictions (Req 15.3)
    try:
        evaluated = await evaluate_matured_predictions(self.pool)
        logger.info(
            "Backtest %s validation: %d prediction outcomes evaluated",
            backtest_id,
            evaluated,
        )
    except Exception:
        logger.warning(
            "Backtest %s: outcome evaluation failed, continuing",
            backtest_id,
            exc_info=True,
        )

    # Step 2: Compute and store metric snapshots (Req 15.5)
    try:
        metric_rows = await compute_and_store_metric_snapshots(self.pool)
        logger.info(
            "Backtest %s validation: %d metric snapshots computed",
            backtest_id,
            len(metric_rows),
        )
    except Exception:
        logger.warning(
            "Backtest %s: metric snapshot computation failed, continuing",
            backtest_id,
            exc_info=True,
        )
# ------------------------------------------------------------------
# Database helpers
# ------------------------------------------------------------------
+329
View File
@@ -0,0 +1,329 @@
"""Quality gate for live trading eligibility.
Evaluates aggregate model metrics against configurable thresholds and
determines whether the system meets minimum quality standards for live
trading. When any threshold is not met, the gate forces all
recommendations to paper mode (fail-safe).
Requirements: 11.1, 11.2, 11.3, 11.4, 11.5, 11.6, 11.7
"""
from __future__ import annotations
import json
import logging
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
import asyncpg
logger = logging.getLogger("trading_engine.quality_gate")
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
@dataclass
class QualityGateConfig:
    """Configurable thresholds for live trading eligibility."""

    # Lower bound for the snapshot's prediction_count.
    min_prediction_count: int = 100
    # Lower bound for the snapshot's information_coefficient.
    min_ic: float = 0.03
    # Lower bound for the snapshot's win_rate.
    min_win_rate: float = 0.53
    # Upper bound for the snapshot's calibration_error (ECE).
    max_ece: float = 0.15
    # Lower bound for the snapshot's avg_excess_return_vs_spy.
    min_excess_return_vs_spy: float = 0.0
    # Snapshots older than this many hours are treated as stale.
    max_snapshot_age_hours: int = 24
@dataclass
class GateThresholdResult:
    """Result for a single threshold check."""

    # Name of the threshold that was checked, e.g. "min_ic" or "max_ece".
    name: str
    # Configured limit the metric was compared against.
    threshold: float
    # Metric value taken from the snapshot (a substitute when it was missing).
    actual: float
    # Whether the metric satisfied the threshold.
    passed: bool
@dataclass
class QualityGateResult:
    """Full gate evaluation result."""

    # Overall verdict of the gate evaluation.
    passed: bool
    # Time at which the gate was evaluated.
    evaluated_at: datetime
    # Per-threshold outcomes; left empty when the gate fails before any
    # threshold is checked (no snapshot, or a stale snapshot).
    threshold_results: list[GateThresholdResult] = field(default_factory=list)
    # Human-readable explanation of the outcome.
    reason: str = ""
    # ID of the metric snapshot that was evaluated; None when none exists.
    snapshot_id: str | None = None
    # Thresholds the evaluation was run with.
    config: QualityGateConfig = field(default_factory=QualityGateConfig)
# ---------------------------------------------------------------------------
# Threshold evaluation helpers
# ---------------------------------------------------------------------------
def _evaluate_thresholds(
    snapshot: dict,
    config: QualityGateConfig,
) -> list[GateThresholdResult]:
    """Evaluate each threshold against snapshot metric values.

    Fail-safe: a metric that is missing (``None``) from *snapshot* always
    fails its check, regardless of the configured threshold.  The reported
    ``actual`` for a missing metric is a substitute value (0.0, or the
    worst-case 1.0 for calibration error).

    Args:
        snapshot: Metric snapshot row as a dict.
        config: Thresholds to check against.

    Returns:
        One ``GateThresholdResult`` per threshold, in the order
        min_prediction_count, min_ic, min_win_rate, max_ece,
        min_excess_return_vs_spy.
    """
    # (result name, snapshot key, threshold, value reported when missing,
    #  True when the metric must be >= threshold / False when <= threshold)
    checks = (
        ("min_prediction_count", "prediction_count", float(config.min_prediction_count), 0.0, True),
        ("min_ic", "information_coefficient", config.min_ic, 0.0, True),
        ("min_win_rate", "win_rate", config.min_win_rate, 0.0, True),
        ("max_ece", "calibration_error", config.max_ece, 1.0, False),
        (
            "min_excess_return_vs_spy",
            "avg_excess_return_vs_spy",
            config.min_excess_return_vs_spy,
            0.0,
            True,
        ),
    )

    results: list[GateThresholdResult] = []
    for name, key, threshold, missing_value, higher_is_better in checks:
        raw = snapshot.get(key)
        missing = raw is None
        actual = missing_value if missing else float(raw)
        within_limit = actual >= threshold if higher_is_better else actual <= threshold
        results.append(
            GateThresholdResult(
                name=name,
                threshold=threshold,
                actual=actual,
                # A missing metric can never satisfy the gate, even when the
                # substitute value happens to meet a lenient threshold.
                passed=within_limit and not missing,
            )
        )

    return results
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def evaluate_quality_gate(
    pool: asyncpg.Pool,
    config: QualityGateConfig | None = None,
) -> QualityGateResult:
    """Evaluate the model quality gate from the latest metric snapshot.

    Reads the most recent ``model_metric_snapshots`` row for the 30d lookback
    and 7d horizon (the primary evaluation window) and checks it against the
    configured thresholds.

    Fail-safe behaviour: when the query fails, no snapshot exists, or the
    newest snapshot is older than ``config.max_snapshot_age_hours``, the gate
    does not pass (paper-only) and no thresholds are evaluated.

    Every outcome is handed to ``_store_gate_result`` before returning.

    Args:
        pool: Connection pool used to read the snapshot and store the result.
        config: Thresholds to apply; loaded via ``load_gate_config_from_db``
            when omitted.

    Returns:
        The ``QualityGateResult`` describing this evaluation.
    """
    if config is None:
        config = await load_gate_config_from_db(pool)
    now = datetime.now(tz=timezone.utc)

    async def _paper_only(reason: str, snapshot_id: str | None) -> QualityGateResult:
        """Build, log and store a failing result that skips threshold checks."""
        fail_safe = QualityGateResult(
            passed=False,
            evaluated_at=now,
            threshold_results=[],
            reason=reason,
            snapshot_id=snapshot_id,
            config=config,
        )
        logger.warning("Quality gate: %s", fail_safe.reason)
        await _store_gate_result(pool, fail_safe)
        return fail_safe

    # Most recent metric snapshot for the 30d lookback / 7d horizon.
    # The SQL text is kept exactly as found, including flush-left continuation
    # lines, so the statement sent to the server is unchanged.
    latest_snapshot_sql = """SELECT id, generated_at, prediction_count, win_rate,
directional_accuracy, information_coefficient,
rank_information_coefficient, avg_return,
avg_excess_return_vs_spy, avg_excess_return_vs_sector,
calibration_error, brier_score,
buy_win_rate, sell_win_rate, hold_win_rate
FROM model_metric_snapshots
WHERE lookback_window = '30d' AND horizon = '7d'
ORDER BY generated_at DESC
LIMIT 1"""
    try:
        row = await pool.fetchrow(latest_snapshot_sql)
    except Exception:
        # A failed read is handled like a missing snapshot (fail-safe).
        logger.exception("Failed to query model_metric_snapshots")
        row = None

    # Fail-safe: no snapshot exists
    if row is None:
        return await _paper_only(
            "no model metric snapshot available — defaulting to paper-only",
            None,
        )

    snapshot = dict(row)
    snapshot_id = str(snapshot["id"])
    generated_at = snapshot["generated_at"]
    if generated_at.tzinfo is None:
        # Subtracting a naive datetime from the aware ``now`` raises TypeError.
        # NOTE(review): a naive value is assumed to be UTC here — confirm the
        # column type / writer convention.
        generated_at = generated_at.replace(tzinfo=timezone.utc)

    # Fail-safe: stale snapshot
    age_hours = (now - generated_at).total_seconds() / 3600.0
    if age_hours > config.max_snapshot_age_hours:
        return await _paper_only(
            f"most recent snapshot is {age_hours:.1f}h old "
            f"(max {config.max_snapshot_age_hours}h) — defaulting to paper-only",
            snapshot_id,
        )

    # Evaluate thresholds
    threshold_results = _evaluate_thresholds(snapshot, config)
    failed = [r for r in threshold_results if not r.passed]
    passed = not failed
    if failed:
        failed_names = ", ".join(
            f"{r.name}(actual={r.actual:.4f}, threshold={r.threshold:.4f})"
            for r in failed
        )
        reason = f"failed: {failed_names}"
    else:
        reason = "all thresholds met"

    result = QualityGateResult(
        passed=passed,
        evaluated_at=now,
        threshold_results=threshold_results,
        reason=reason,
        snapshot_id=snapshot_id,
        config=config,
    )

    # Log details
    for tr in threshold_results:
        logger.info(
            "Quality gate threshold %s: actual=%.4f threshold=%.4f %s",
            tr.name,
            tr.actual,
            tr.threshold,
            "PASS" if tr.passed else "FAIL",
        )
    # Separator added: the previous "%s%s" format ran the verdict straight
    # into the reason text (e.g. "PASSall thresholds met").
    logger.info(
        "Quality gate result: %s — %s", "PASS" if passed else "FAIL", reason
    )
    await _store_gate_result(pool, result)
    return result
async def load_gate_config_from_db(
    pool: asyncpg.Pool,
) -> QualityGateConfig:
    """Load gate thresholds from ``risk_configs``, falling back to defaults.

    Looks for a ``risk_configs`` row with ``name = 'model_quality_gate_config'``
    and merges its stored thresholds over the defaults, key by key.

    The default config is returned when the query fails, no row exists, the
    stored value is not a JSON object, or any stored threshold cannot be
    converted to its numeric type. A key that is absent or ``null`` falls
    back to its individual default.

    Args:
        pool: Connection pool used for the lookup.

    Returns:
        The effective ``QualityGateConfig``.
    """
    defaults = QualityGateConfig()
    try:
        # NOTE(review): the lookup does not filter on ``active`` while the
        # result upsert elsewhere in this module targets active rows only —
        # confirm at most one row can carry this name.
        row = await pool.fetchrow(
            "SELECT config FROM risk_configs WHERE name = 'model_quality_gate_config'",
        )
    except Exception:
        logger.warning(
            "Failed to load gate config from risk_configs — using defaults",
            exc_info=True,
        )
        return defaults
    if row is None:
        return defaults

    try:
        raw = row["config"]
        cfg = raw if isinstance(raw, dict) else json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        logger.warning("Invalid gate config JSON in risk_configs — using defaults")
        return defaults
    if not isinstance(cfg, dict):
        # Valid JSON that is not an object (list, number, string, null).
        logger.warning("Invalid gate config JSON in risk_configs — using defaults")
        return defaults

    def _pick(key: str, cast, fallback):
        """Return the stored value converted with ``cast``, or ``fallback``."""
        value = cfg.get(key)
        return fallback if value is None else cast(value)

    try:
        return QualityGateConfig(
            min_prediction_count=_pick(
                "min_prediction_count", int, defaults.min_prediction_count
            ),
            min_ic=_pick("min_ic", float, defaults.min_ic),
            min_win_rate=_pick("min_win_rate", float, defaults.min_win_rate),
            max_ece=_pick("max_ece", float, defaults.max_ece),
            min_excess_return_vs_spy=_pick(
                "min_excess_return_vs_spy", float, defaults.min_excess_return_vs_spy
            ),
            max_snapshot_age_hours=_pick(
                "max_snapshot_age_hours", int, defaults.max_snapshot_age_hours
            ),
        )
    except (TypeError, ValueError):
        # e.g. a threshold stored as "abc" or as a nested object.
        logger.warning("Invalid gate config JSON in risk_configs — using defaults")
        return defaults
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _gate_result_to_json(result: QualityGateResult) -> str:
    """Render a gate result as the JSON text stored in ``risk_configs``.

    The config and every threshold result are expanded with ``asdict``; any
    value ``json`` cannot encode natively is written via ``str``.
    """
    thresholds = []
    for item in result.threshold_results:
        thresholds.append(asdict(item))
    document = dict(
        passed=result.passed,
        evaluated_at=result.evaluated_at.isoformat(),
        reason=result.reason,
        snapshot_id=result.snapshot_id,
        config=asdict(result.config),
        threshold_results=thresholds,
    )
    return json.dumps(document, default=str)
async def _store_gate_result(pool: asyncpg.Pool, result: QualityGateResult) -> None:
    """Write the gate evaluation into ``risk_configs`` as 'model_quality_gate'.

    Storage is best effort: a database error is logged with its traceback and
    swallowed. Serialization runs before the guarded section, so an error
    raised while building the JSON still propagates to the caller.
    """
    serialized = _gate_result_to_json(result)
    # SQL text kept exactly as found (flush-left continuation lines included).
    upsert_sql = """INSERT INTO risk_configs (name, config, updated_at)
VALUES ('model_quality_gate', $1::jsonb, NOW())
ON CONFLICT (name) WHERE active = TRUE
DO UPDATE SET config = $1::jsonb, updated_at = NOW()"""
    try:
        await pool.execute(upsert_sql, serialized)
    except Exception:
        logger.exception("Failed to store quality gate result in risk_configs")
+1
View File
@@ -0,0 +1 @@
+591
View File
@@ -0,0 +1,591 @@
"""Attribution Engine — per-source, per-catalyst, and per-layer performance.
Joins signal evidence links with prediction outcomes to compute attribution
metrics that identify which sources, catalyst types, and signal layers
contribute most to accurate predictions.
Requirements: 7.1, 7.2, 7.3, 7.4, 7.5, 7.6, 7.7
"""
from __future__ import annotations
import logging
import math
from dataclasses import dataclass
from datetime import datetime, timedelta
import asyncpg
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------
@dataclass
class SourceAttribution:
    """Performance metrics for a single source.

    One instance is produced per distinct source by
    ``compute_source_attribution``; averages fall back to 0.0 when no row in
    the group carries the underlying value.
    """

    # Grouping key; rows without a source are grouped under "unknown".
    source: str
    # Source type reported for the group ("unknown" when absent).
    source_type: str
    # Number of v_source_performance rows in the group.
    prediction_count: int
    # Mean of the non-null evidence weights.
    avg_weight: float
    # Mean of the non-null contribution scores.
    avg_contribution_score: float
    # Share of rows with direction_correct = True among rows where it is set.
    win_rate: float
    # Mean of the non-null future returns.
    avg_future_return: float
    # Mean of the non-null excess returns versus SPY.
    avg_excess_return_vs_spy: float
    # Pearson correlation of contribution score vs. future return; None when
    # fewer than 30 paired points exist or the correlation is undefined.
    information_coefficient: float | None
    # Rows flagged is_duplicate = True divided by prediction_count.
    duplicate_rate: float
@dataclass
class CatalystAttribution:
    """Performance metrics for a single catalyst type.

    One instance is produced per distinct catalyst type by
    ``compute_catalyst_attribution``; averages fall back to 0.0 when no row
    in the group carries the underlying value.
    """

    # Grouping key; rows without a catalyst type are grouped under "unknown".
    catalyst_type: str
    # Number of v_source_performance rows in the group.
    prediction_count: int
    # Share of rows with direction_correct = True among rows where it is set.
    win_rate: float
    # Mean of the non-null future returns.
    avg_future_return: float
    # Mean of the non-null excess returns versus SPY.
    avg_excess_return_vs_spy: float
    # Pearson correlation of contribution score vs. future return; None when
    # fewer than 30 paired points exist or the correlation is undefined.
    information_coefficient: float | None
@dataclass
class LayerAttribution:
    """Performance metrics for a signal layer.

    Produced by ``compute_layer_attribution``. A layer is "dominant" for a
    prediction when its score is more than 30% of the summed layer scores.
    """

    layer: str  # company, macro, competitive
    # Mean of layer_score / total_score over predictions with total_score > 0.
    # Despite the name this is a fraction, not a value scaled to 0-100.
    avg_contribution_pct: float
    dominant_win_rate: float  # win rate when this layer > 30% contribution
    dominant_ic: float | None  # IC when this layer > 30% contribution
# ---------------------------------------------------------------------------
# Pure computation helpers
# ---------------------------------------------------------------------------
def _pearson_correlation(xs: list[float], ys: list[float]) -> float | None:
    """Compute the Pearson correlation coefficient of paired samples.

    Samples are paired positionally. If the lists differ in length only the
    common prefix is used, so the means are taken over exactly the values
    that enter the covariance.

    Args:
        xs: First series.
        ys: Second series.

    Returns:
        The coefficient clamped to [-1.0, 1.0], or None when fewer than two
        pairs exist, either series has zero variance, or the result is not a
        finite number.
    """
    n = min(len(xs), len(ys))
    if n < 2:
        return None
    xs = xs[:n]
    ys = ys[:n]
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = 0.0
    var_x = 0.0
    var_y = 0.0
    for x, y in zip(xs, ys):
        dx = x - mean_x
        dy = y - mean_y
        cov += dx * dy
        var_x += dx * dx
        var_y += dy * dy
    if var_x == 0.0 or var_y == 0.0:
        return None
    product = var_x * var_y
    if product == 0.0 or math.isinf(product):
        # The product under- or overflowed although both variances are
        # usable; taking the roots separately avoids a division by zero
        # (underflow) or a coefficient collapsed to 0 (overflow).
        denominator = math.sqrt(var_x) * math.sqrt(var_y)
    else:
        denominator = math.sqrt(product)
    if denominator == 0.0:
        return None
    r = cov / denominator
    if math.isnan(r) or math.isinf(r):
        return None
    # Clamp away floating-point drift just outside the valid range.
    return max(-1.0, min(1.0, r))
def _compute_ic(
    contribution_scores: list[float],
    future_returns: list[float],
) -> float | None:
    """Return the IC: Pearson correlation of contribution scores vs. returns.

    Yields None unless both series hold at least 30 values. When the lengths
    differ, only the common prefix is correlated.
    """
    minimum_points = 30
    usable = min(len(contribution_scores), len(future_returns))
    if usable < minimum_points:
        return None
    paired_scores = contribution_scores[:usable]
    paired_returns = future_returns[:usable]
    return _pearson_correlation(paired_scores, paired_returns)
# ---------------------------------------------------------------------------
# SQL queries — source attribution via v_source_performance
# ---------------------------------------------------------------------------
# Row-level evidence/outcome data from v_source_performance for one horizon,
# limited to rows generated on or after a cutoff.
# Bind parameters: $1 = horizon, $2 = cutoff timestamp (this is how
# compute_source_attribution passes them).
_SOURCE_ATTRIBUTION_SQL = """
SELECT
source,
source_type,
weight,
contribution_score,
is_duplicate,
direction_correct,
future_return,
excess_return_vs_spy
FROM v_source_performance
WHERE horizon = $1
AND generated_at >= $2
"""
# Same projection without the generated_at cutoff; $1 = horizon only.
# NOTE(review): no caller is visible in this part of the file — confirm use.
_SOURCE_ATTRIBUTION_ALL_SQL = """
SELECT
source,
source_type,
weight,
contribution_score,
is_duplicate,
direction_correct,
future_return,
excess_return_vs_spy
FROM v_source_performance
WHERE horizon = $1
"""
# ---------------------------------------------------------------------------
# SQL queries — catalyst attribution via v_source_performance
# ---------------------------------------------------------------------------
# Row-level catalyst/outcome data from v_source_performance for one horizon,
# limited to rows generated on or after a cutoff.
# Bind parameters: $1 = horizon, $2 = cutoff timestamp (this is how
# compute_catalyst_attribution passes them).
_CATALYST_ATTRIBUTION_SQL = """
SELECT
catalyst_type,
weight,
contribution_score,
direction_correct,
future_return,
excess_return_vs_spy
FROM v_source_performance
WHERE horizon = $1
AND generated_at >= $2
"""
# Same projection without the generated_at cutoff; $1 = horizon only.
# NOTE(review): no caller is visible in this part of the file — confirm use.
_CATALYST_ATTRIBUTION_ALL_SQL = """
SELECT
catalyst_type,
weight,
contribution_score,
direction_correct,
future_return,
excess_return_vs_spy
FROM v_source_performance
WHERE horizon = $1
"""
# ---------------------------------------------------------------------------
# SQL queries — layer attribution via prediction_snapshots + outcomes
# ---------------------------------------------------------------------------
# Per-prediction layer scores joined to their outcomes for one horizon,
# limited to predictions generated on or after a cutoff.
# Bind parameters: $1 = outcome horizon, $2 = cutoff timestamp (this is how
# compute_layer_attribution passes them).
_LAYER_ATTRIBUTION_SQL = """
SELECT
ps.score_company,
ps.score_macro,
ps.score_competitive,
po.direction_correct,
po.future_return
FROM prediction_snapshots ps
JOIN prediction_outcomes po ON po.prediction_id = ps.id
WHERE po.horizon = $1
AND ps.generated_at >= $2
"""
# Same join without the generated_at cutoff; $1 = outcome horizon only.
# NOTE(review): no caller is visible in this part of the file — confirm use.
_LAYER_ATTRIBUTION_ALL_SQL = """
SELECT
ps.score_company,
ps.score_macro,
ps.score_competitive,
po.direction_correct,
po.future_return
FROM prediction_snapshots ps
JOIN prediction_outcomes po ON po.prediction_id = ps.id
WHERE po.horizon = $1
"""
# ---------------------------------------------------------------------------
# Source attribution (Requirements 7.1, 7.2, 7.7)
# ---------------------------------------------------------------------------
async def compute_source_attribution(
    pool: asyncpg.Pool,
    lookback_days: int = 30,
    horizon: str = "7d",
) -> list[SourceAttribution]:
    """Compute per-source performance metrics.

    Queries v_source_performance for the given horizon and lookback window,
    groups the rows by source (rows without one go to "unknown"), and
    computes for each group: row count, average weight, average contribution
    score, win rate, average future return, average excess return vs SPY,
    IC, and duplicate rate. Averages ignore NULL values and are 0.0 when no
    value is present.

    Args:
        pool: Connection pool used for the query.
        lookback_days: Only rows generated within this many days are used.
        horizon: Outcome horizon to evaluate.

    Returns:
        A list of SourceAttribution sorted by prediction count descending.
        Empty when the query fails (the error is logged) or returns no rows.
    """
    # Local import: Counter is needed by this function only.
    from collections import Counter

    now = datetime.now().astimezone()
    cutoff = now - timedelta(days=lookback_days)
    try:
        rows = await pool.fetch(_SOURCE_ATTRIBUTION_SQL, horizon, cutoff)
    except Exception:
        logger.exception(
            "Failed to query source attribution for horizon=%s lookback=%dd",
            horizon,
            lookback_days,
        )
        return []
    if not rows:
        return []

    def _mean(values: list) -> float:
        """Arithmetic mean, or 0.0 for an empty list."""
        return sum(values) / len(values) if values else 0.0

    def _present(group: list[dict], key: str) -> list:
        """Non-null values of ``key`` across the group's rows."""
        return [r[key] for r in group if r.get(key) is not None]

    # Group rows by source
    source_groups: dict[str, list[dict]] = {}
    for row in rows:
        r = dict(row)
        source_groups.setdefault(r.get("source") or "unknown", []).append(r)

    results: list[SourceAttribution] = []
    for source, group in source_groups.items():
        count = len(group)

        # Source type: the most common one in the group. Previously the first
        # row's value was used, which depended on query row order. Ties keep
        # the type encountered first.
        known_types = [r["source_type"] for r in group if r.get("source_type")]
        source_type = (
            Counter(known_types).most_common(1)[0][0] if known_types else "unknown"
        )

        # Win rate over rows that have a directional verdict.
        verdicts = _present(group, "direction_correct")
        win_rate = (
            sum(1 for v in verdicts if v is True) / len(verdicts)
            if verdicts
            else 0.0
        )

        # IC: contribution score vs. future return, using only rows where
        # both values are present so the two series stay aligned.
        ic_pairs = [
            (r["contribution_score"], r["future_return"])
            for r in group
            if r.get("contribution_score") is not None
            and r.get("future_return") is not None
        ]
        ic = _compute_ic(
            [score for score, _ in ic_pairs],
            [ret for _, ret in ic_pairs],
        )

        # Duplicate rate: is_duplicate=true / total
        dup_count = sum(1 for r in group if r.get("is_duplicate") is True)

        results.append(
            SourceAttribution(
                source=source,
                source_type=source_type,
                prediction_count=count,
                avg_weight=_mean(_present(group, "weight")),
                avg_contribution_score=_mean(_present(group, "contribution_score")),
                win_rate=win_rate,
                avg_future_return=_mean(_present(group, "future_return")),
                avg_excess_return_vs_spy=_mean(
                    _present(group, "excess_return_vs_spy")
                ),
                information_coefficient=ic,
                duplicate_rate=dup_count / count,
            )
        )

    # Sort by prediction count descending
    results.sort(key=lambda a: a.prediction_count, reverse=True)
    logger.info(
        "Computed source attribution for %d sources (horizon=%s, lookback=%dd)",
        len(results),
        horizon,
        lookback_days,
    )
    return results
# ---------------------------------------------------------------------------
# Catalyst attribution (Requirements 7.3, 7.4)
# ---------------------------------------------------------------------------
async def compute_catalyst_attribution(
    pool: asyncpg.Pool,
    lookback_days: int = 30,
    horizon: str = "7d",
) -> list[CatalystAttribution]:
    """Compute performance metrics per catalyst type.

    Reads v_source_performance for the horizon and lookback window, buckets
    the rows by catalyst_type ("unknown" when absent), and derives for each
    bucket the row count, win rate, average future return, average excess
    return vs SPY, and IC.

    Returns the CatalystAttribution list ordered by prediction count, largest
    first; an empty list when the query fails or yields nothing.
    """
    window_start = datetime.now().astimezone() - timedelta(days=lookback_days)
    try:
        fetched = await pool.fetch(_CATALYST_ATTRIBUTION_SQL, horizon, window_start)
    except Exception:
        logger.exception(
            "Failed to query catalyst attribution for horizon=%s lookback=%dd",
            horizon,
            lookback_days,
        )
        return []
    if not fetched:
        return []

    # Bucket the rows by catalyst type, preserving first-seen order.
    by_catalyst: dict[str, list[dict]] = {}
    for raw in fetched:
        record = dict(raw)
        label = record.get("catalyst_type") or "unknown"
        if label not in by_catalyst:
            by_catalyst[label] = []
        by_catalyst[label].append(record)

    attributions: list[CatalystAttribution] = []
    for label, records in by_catalyst.items():
        # Win rate: only rows with a directional verdict take part.
        verdicts = [
            rec["direction_correct"]
            for rec in records
            if rec.get("direction_correct") is not None
        ]
        if verdicts:
            hit_rate = sum(1 for verdict in verdicts if verdict is True) / len(verdicts)
        else:
            hit_rate = 0.0

        # Mean future return over rows that have one.
        realised = [
            rec["future_return"]
            for rec in records
            if rec.get("future_return") is not None
        ]
        mean_return = sum(realised) / len(realised) if realised else 0.0

        # Mean excess return vs SPY over rows that have one.
        versus_spy = [
            rec["excess_return_vs_spy"]
            for rec in records
            if rec.get("excess_return_vs_spy") is not None
        ]
        mean_excess = sum(versus_spy) / len(versus_spy) if versus_spy else 0.0

        # IC input: rows carrying both a contribution score and a return.
        complete = [
            rec
            for rec in records
            if rec.get("contribution_score") is not None
            and rec.get("future_return") is not None
        ]
        coefficient = _compute_ic(
            [rec["contribution_score"] for rec in complete],
            [rec["future_return"] for rec in complete],
        )

        attributions.append(
            CatalystAttribution(
                catalyst_type=label,
                prediction_count=len(records),
                win_rate=hit_rate,
                avg_future_return=mean_return,
                avg_excess_return_vs_spy=mean_excess,
                information_coefficient=coefficient,
            )
        )

    # Largest groups first (stable for equal counts).
    attributions.sort(key=lambda item: item.prediction_count, reverse=True)
    logger.info(
        "Computed catalyst attribution for %d catalyst types "
        "(horizon=%s, lookback=%dd)",
        len(attributions),
        horizon,
        lookback_days,
    )
    return attributions
# ---------------------------------------------------------------------------
# Layer attribution (Requirements 7.5, 7.6)
# ---------------------------------------------------------------------------
async def compute_layer_attribution(
    pool: asyncpg.Pool,
    lookback_days: int = 30,
    horizon: str = "7d",
) -> list[LayerAttribution]:
    """Compute performance metrics per signal layer (company, macro, competitive).

    Uses prediction_snapshots joined with prediction_outcomes to obtain the
    three layer scores next to each outcome. Only predictions whose summed
    layer score is positive take part. Per layer:

    - avg_contribution_pct: mean of layer_score / total_score
    - dominant_win_rate: win rate over predictions where the layer supplies
      more than 30% of the total score
    - dominant_ic: IC (Pearson correlation of layer score vs. future return)
      over those same predictions

    Returns three LayerAttribution objects (company, macro, competitive);
    all-zero entries when the query yields no rows, and an empty list when
    the query fails.
    """
    layer_fields = (
        ("company", "score_company"),
        ("macro", "score_macro"),
        ("competitive", "score_competitive"),
    )
    window_start = datetime.now().astimezone() - timedelta(days=lookback_days)
    try:
        fetched = await pool.fetch(_LAYER_ATTRIBUTION_SQL, horizon, window_start)
    except Exception:
        logger.exception(
            "Failed to query layer attribution for horizon=%s lookback=%dd",
            horizon,
            lookback_days,
        )
        return []
    if not fetched:
        return [
            LayerAttribution(
                layer=layer_name,
                avg_contribution_pct=0.0,
                dominant_win_rate=0.0,
                dominant_ic=None,
            )
            for layer_name, _ in layer_fields
        ]

    # Pair every record with its summed layer score once; records whose total
    # is not positive are ignored by every metric below.
    scored: list[tuple[dict, float]] = []
    for raw in fetched:
        record = dict(raw)
        total = (
            (record.get("score_company") or 0.0)
            + (record.get("score_macro") or 0.0)
            + (record.get("score_competitive") or 0.0)
        )
        if total > 0.0:
            scored.append((record, total))

    attributions: list[LayerAttribution] = []
    for layer_name, score_field in layer_fields:
        # Share of the total score supplied by this layer, per prediction.
        shares = [
            (record.get(score_field) or 0.0) / total for record, total in scored
        ]
        mean_share = sum(shares) / len(shares) if shares else 0.0

        # Predictions where this layer exceeds 30% of the total score.
        dominant = [
            record for (record, _), share in zip(scored, shares) if share > 0.30
        ]

        # Win rate among dominant predictions that have a verdict.
        verdicts = [
            record["direction_correct"]
            for record in dominant
            if record.get("direction_correct") is not None
        ]
        if verdicts:
            dominant_hits = sum(1 for verdict in verdicts if verdict is True)
            dominant_hit_rate = dominant_hits / len(verdicts)
        else:
            dominant_hit_rate = 0.0

        # IC among dominant predictions that have a future return.
        with_return = [
            record for record in dominant if record.get("future_return") is not None
        ]
        coefficient = _compute_ic(
            [record.get(score_field) or 0.0 for record in with_return],
            [record["future_return"] for record in with_return],
        )

        attributions.append(
            LayerAttribution(
                layer=layer_name,
                avg_contribution_pct=mean_share,
                dominant_win_rate=dominant_hit_rate,
                dominant_ic=coefficient,
            )
        )

    logger.info(
        "Computed layer attribution for 3 layers (horizon=%s, lookback=%dd)",
        horizon,
        lookback_days,
    )
    return attributions
+135
View File
@@ -0,0 +1,135 @@
"""Calibration Engine — Bayesian shrinkage source reliability and weight adjustment.
Computes source reliability scores using Bayesian shrinkage from historical
prediction outcomes, and adjusts evidence weights based on source performance.
Updates the existing source_accuracy table with reliability scores.
Requirements: 8.1, 8.2, 8.3, 8.4, 8.5
"""
from __future__ import annotations

import logging
import math

import asyncpg
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Pure functions — testable without a database
# ---------------------------------------------------------------------------
def compute_source_reliability(
    observed_win_rate: float,
    sample_count: int,
    prior_strength: int = 30,
) -> float:
    """Bayesian shrinkage source reliability.

    reliability = 0.5 + (n / (n + prior_strength)) * (observed_win_rate - 0.5)

    Args:
        observed_win_rate: Observed share of correct predictions.
        sample_count: Number of observations (n) behind the win rate.
        prior_strength: Weight of the 0.5 prior, in pseudo-observations.
            Negative values are treated as 0 (no shrinkage).

    Returns:
        A value in [0.0, 1.0]. 0.5 (the prior mean) when there are no
        samples or the win rate is NaN; approaches the observed win rate as
        n grows.
    """
    # A NaN win rate would otherwise slip through the clamp below as 1.0
    # (``min(1.0, nan)`` returns 1.0), i.e. maximum reliability.
    if sample_count <= 0 or math.isnan(observed_win_rate):
        return 0.5
    # A negative prior makes the denominator zero or the factor exceed 1.
    prior = max(prior_strength, 0)
    shrinkage = sample_count / (sample_count + prior)
    reliability = 0.5 + shrinkage * (observed_win_rate - 0.5)
    # Clamp to [0.0, 1.0]; only needed when observed_win_rate is out of range.
    return max(0.0, min(1.0, reliability))
def compute_adjusted_evidence_weight(
    base_weight: float,
    reliability: float,
) -> float:
    """Scale an evidence weight by source reliability.

    Adjusted weight = base_weight * (0.5 + reliability), clamped to
    [0.1, 2.0].

    Args:
        base_weight: Unadjusted evidence weight.
        reliability: Source reliability; 0.5 leaves the weight unchanged.

    Returns:
        The adjusted weight in [0.1, 2.0]. When the product is NaN the
        minimum weight 0.1 is returned.
    """
    adjusted = base_weight * (0.5 + reliability)
    # ``min(2.0, nan)`` returns 2.0, so without this guard an undefined input
    # would be rewarded with the maximum weight.
    if math.isnan(adjusted):
        return 0.1
    return max(0.1, min(2.0, adjusted))
# ---------------------------------------------------------------------------
# SQL queries
# ---------------------------------------------------------------------------
# Per-source sample and win counts from v_source_performance.
# Rows without a directional verdict (direction_correct IS NULL) are excluded;
# sample_count counts the remaining rows and win_count those with
# direction_correct = TRUE. Takes no bind parameters.
_SOURCE_PERFORMANCE_SQL = """
SELECT
source,
COUNT(*) AS sample_count,
COUNT(*) FILTER (WHERE direction_correct = TRUE) AS win_count
FROM v_source_performance
WHERE direction_correct IS NOT NULL
GROUP BY source
"""
# Upsert into source_accuracy keyed on source_id: inserts a new row or
# overwrites accuracy_ratio, sample_count and last_updated of the existing one.
# Bind parameters: $1 = source_id, $2 = accuracy_ratio, $3 = sample_count.
# update_source_reliabilities passes the shrunk reliability score as $2.
_UPSERT_SOURCE_ACCURACY_SQL = """
INSERT INTO source_accuracy (source_id, accuracy_ratio, sample_count, last_updated)
VALUES ($1, $2, $3, NOW())
ON CONFLICT (source_id)
DO UPDATE SET
accuracy_ratio = EXCLUDED.accuracy_ratio,
sample_count = EXCLUDED.sample_count,
last_updated = NOW()
"""
# ---------------------------------------------------------------------------
# Database-backed function
# ---------------------------------------------------------------------------
async def update_source_reliabilities(
    pool: asyncpg.Pool,
) -> int:
    """Recompute source reliability scores and store them in source_accuracy.

    Steps:
    1. Read per-source sample and win counts from v_source_performance.
    2. Turn each observed win rate into a Bayesian-shrinkage reliability.
    3. Upsert the reliability as ``accuracy_ratio`` for that source.

    A failed upsert is logged and skipped; the remaining sources are still
    processed. Returns the number of sources stored successfully (0 when the
    read fails or yields nothing).
    """
    try:
        performance_rows = await pool.fetch(_SOURCE_PERFORMANCE_SQL)
    except Exception:
        logger.exception("Failed to query source performance for reliability update")
        return 0

    if not performance_rows:
        logger.info("No source performance data available for reliability update")
        return 0

    stored = 0
    for record in performance_rows:
        source = record["source"]
        samples = record["sample_count"]
        wins = record["win_count"]
        # Without samples there is nothing observed; use the neutral rate.
        if samples > 0:
            observed = wins / samples
        else:
            observed = 0.5
        score = compute_source_reliability(observed, samples)
        try:
            await pool.execute(_UPSERT_SOURCE_ACCURACY_SQL, source, score, samples)
        except Exception:
            logger.exception(
                "Failed to upsert source reliability for source=%s", source
            )
        else:
            stored += 1

    logger.info("Updated source reliabilities for %d sources", stored)
    return stored
+637
View File
@@ -0,0 +1,637 @@
"""Metrics Engine — computes calibration, IC, Brier, and benchmark metrics.
Aggregates model quality metrics across configurable lookback windows and
prediction horizons. Stores periodic snapshots for time-series analysis
of model performance trends.
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 6.1, 6.2, 6.3, 6.4, 6.5,
9.1, 9.2, 9.3, 9.4, 10.1, 10.2, 10.3, 10.4, 10.5
"""
from __future__ import annotations
import json
import logging
import math
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import asyncpg
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# (low, high) confidence ranges used for calibration bucketing.
# compute_calibration_error treats each range as low <= conf < high, except
# the range ending at 1.00, which also includes its upper bound.
# Confidences below 0.50 match no bucket.
CONFIDENCE_BUCKETS: list[tuple[float, float]] = [
(0.50, 0.60),
(0.60, 0.70),
(0.70, 0.80),
(0.80, 0.90),
(0.90, 1.00),
]
# Labels of the supported lookback windows.
LOOKBACK_WINDOWS: list[str] = ["7d", "30d", "90d", "all"]
# Duration for each lookback label. "all" maps to None — presumably meaning
# "no cutoff"; verify against the code that consumes this mapping.
LOOKBACK_DURATIONS: dict[str, timedelta | None] = {
"7d": timedelta(days=7),
"30d": timedelta(days=30),
"90d": timedelta(days=90),
"all": None,
}
# Labels of the prediction horizons that are evaluated.
EVALUATION_HORIZONS: list[str] = ["1h", "6h", "1d", "7d", "30d"]
# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------
@dataclass
class CalibrationBucket:
    """Calibration metrics for a single confidence bucket.

    Produced by ``compute_calibration_error``, one per entry of
    ``CONFIDENCE_BUCKETS``. An empty bucket is reported with zeros and
    ``miscalibrated=False``.
    """

    # Lower bound of the bucket (inclusive).
    bucket_low: float
    # Upper bound of the bucket (exclusive, except for the bucket ending at 1.00).
    bucket_high: float
    # Mean confidence of the predictions that fell into the bucket.
    avg_confidence: float
    # Share of those predictions whose outcome was truthy.
    observed_win_rate: float
    # Number of predictions in the bucket.
    prediction_count: int
    miscalibrated: bool  # |avg_confidence - win_rate| > 0.15
@dataclass
class ModelMetricSnapshot:
    """Aggregate model quality metrics for a lookback/horizon combination.

    The fields match, name for name, the columns listed in
    ``_INSERT_METRIC_SNAPSHOT_SQL``.
    """

    # Snapshot identifier; inserted with a ::uuid cast, so a UUID string.
    id: str
    # When the snapshot was computed.
    generated_at: datetime
    # Lookback label, e.g. one of LOOKBACK_WINDOWS.
    lookback_window: str
    # Prediction horizon label, e.g. one of EVALUATION_HORIZONS.
    horizon: str
    # Number of prediction rows the metrics were computed from.
    prediction_count: int
    # Share of rows with direction_correct = True.
    win_rate: float
    # Same value as win_rate under a different name.
    directional_accuracy: float
    # Pearson IC; None when it cannot be computed (e.g. too few data points).
    information_coefficient: float | None
    # Rank (Spearman) IC; None when it cannot be computed.
    rank_information_coefficient: float | None
    # Mean future return over rows that have one.
    avg_return: float
    # Mean excess return versus SPY.
    avg_excess_return_vs_spy: float
    # Mean excess return versus the sector benchmark.
    avg_excess_return_vs_sector: float
    calibration_error: float  # ECE
    # Brier score of the probability forecasts.
    brier_score: float
    # Win rates restricted to rows whose action is buy / sell / hold.
    buy_win_rate: float
    sell_win_rate: float
    hold_win_rate: float
    # Free-form extra data; inserted with a ::jsonb cast.
    metadata: dict = field(default_factory=dict)
# ---------------------------------------------------------------------------
# Pure computation functions
# ---------------------------------------------------------------------------
def compute_calibration_error(
    confidences: list[float],
    outcomes: list[bool],
) -> tuple[float, list[CalibrationBucket]]:
    """Compute the expected calibration error (ECE) and its buckets.

    ECE = Σ (n_b / N) * |avg_conf_b - win_rate_b|

    Predictions are paired positionally with their outcomes and sorted into
    the ranges of ``CONFIDENCE_BUCKETS`` (low <= conf < high; the last range
    also includes its upper bound). N is the number of predictions that
    landed in a bucket, so the bucket weights always sum to 1. Predictions
    outside every range (confidence below 0.50 or above 1.00) and values
    without a partner in the other list do not dilute the result.

    Args:
        confidences: Predicted confidences.
        outcomes: Whether each prediction turned out correct.

    Returns:
        ``(ece, buckets)`` with one CalibrationBucket per configured range.
        A bucket is flagged miscalibrated when |avg_conf - win_rate| > 0.15;
        empty buckets are reported with zeros and excluded from the ECE.
        ``(0.0, [])`` when either input is empty.
    """
    if not confidences or not outcomes:
        return 0.0, []

    last_index = len(CONFIDENCE_BUCKETS) - 1
    bucket_confs: list[list[float]] = [[] for _ in CONFIDENCE_BUCKETS]
    bucket_wins: list[int] = [0 for _ in CONFIDENCE_BUCKETS]

    # Single pass: drop each prediction into the first range that holds it.
    for conf, outcome in zip(confidences, outcomes):
        for index, (low, high) in enumerate(CONFIDENCE_BUCKETS):
            below_upper = conf <= high if index == last_index else conf < high
            if low <= conf and below_upper:
                bucket_confs[index].append(conf)
                if outcome:
                    bucket_wins[index] += 1
                break

    bucketed_total = sum(len(confs) for confs in bucket_confs)
    buckets: list[CalibrationBucket] = []
    ece = 0.0
    for index, (low, high) in enumerate(CONFIDENCE_BUCKETS):
        count = len(bucket_confs[index])
        if count == 0:
            # Empty bucket — exclude from ECE, still record it
            buckets.append(
                CalibrationBucket(
                    bucket_low=low,
                    bucket_high=high,
                    avg_confidence=0.0,
                    observed_win_rate=0.0,
                    prediction_count=0,
                    miscalibrated=False,
                )
            )
            continue
        avg_conf = sum(bucket_confs[index]) / count
        win_rate = bucket_wins[index] / count
        diff = abs(avg_conf - win_rate)
        buckets.append(
            CalibrationBucket(
                bucket_low=low,
                bucket_high=high,
                avg_confidence=avg_conf,
                observed_win_rate=win_rate,
                prediction_count=count,
                miscalibrated=diff > 0.15,
            )
        )
        ece += (count / bucketed_total) * diff
    return ece, buckets
def compute_brier_score(
    p_bulls: list[float],
    outcomes: list[bool],
) -> float:
    """Brier score = mean((p_bull - outcome)^2).

    Probabilities and outcomes are paired positionally; a truthy outcome
    counts as 1.0 and a falsy one as 0.0. If the lists differ in length only
    the paired prefix is scored, and the mean is taken over that many pairs
    (previously the sum was divided by ``len(p_bulls)`` regardless).

    NOTE(review): the score is only meaningful when ``outcome`` encodes the
    event that ``p_bull`` is the probability of — confirm against the caller.

    Args:
        p_bulls: Forecast probabilities.
        outcomes: Realised outcomes.

    Returns:
        The mean squared error, in [0.0, 1.0] for probabilities in
        [0.0, 1.0]; 0.0 when there is nothing to score.
    """
    pairs = list(zip(p_bulls, outcomes))
    if not pairs:
        return 0.0
    total = 0.0
    for p, o in pairs:
        actual = 1.0 if o else 0.0
        total += (p - actual) ** 2
    return total / len(pairs)
def _pearson_correlation(xs: list[float], ys: list[float]) -> float | None:
    """Compute the Pearson correlation coefficient of paired samples.

    Samples are paired positionally. If the lists differ in length only the
    common prefix is used, so the means are taken over exactly the values
    that enter the covariance.

    Args:
        xs: First series.
        ys: Second series.

    Returns:
        The coefficient clamped to [-1.0, 1.0], or None when fewer than two
        pairs exist, either series has zero variance, or the result is not a
        finite number.
    """
    n = min(len(xs), len(ys))
    if n < 2:
        return None
    xs = xs[:n]
    ys = ys[:n]
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = 0.0
    var_x = 0.0
    var_y = 0.0
    for x, y in zip(xs, ys):
        dx = x - mean_x
        dy = y - mean_y
        cov += dx * dy
        var_x += dx * dx
        var_y += dy * dy
    if var_x == 0.0 or var_y == 0.0:
        return None
    product = var_x * var_y
    if product == 0.0 or math.isinf(product):
        # The product under- or overflowed although both variances are
        # usable; taking the roots separately avoids a division by zero
        # (underflow) or a coefficient collapsed to 0 (overflow).
        denominator = math.sqrt(var_x) * math.sqrt(var_y)
    else:
        denominator = math.sqrt(product)
    if denominator == 0.0:
        return None
    r = cov / denominator
    # Guard against floating-point drift
    if math.isnan(r) or math.isinf(r):
        return None
    # Clamp to [-1.0, 1.0]
    return max(-1.0, min(1.0, r))
def _rank_data(values: list[float]) -> list[float]:
    """Return 1-based fractional ranks; tied values share their mean rank."""
    size = len(values)
    order = sorted(range(size), key=values.__getitem__)
    ranks = [0.0] * size
    start = 0
    while start < size:
        # Extend the run while the next sorted value equals the run's first.
        stop = start + 1
        while stop < size and values[order[stop]] == values[order[start]]:
            stop += 1
        # Sorted positions start..stop-1 hold ranks start+1..stop; their mean:
        shared_rank = (start + stop + 1) / 2.0
        for original_index in order[start:stop]:
            ranks[original_index] = shared_rank
        start = stop
    return ranks
def compute_information_coefficient(
    scores: list[float],
    returns: list[float],
) -> float | None:
    """Information coefficient: Pearson correlation of scores vs. returns.

    Needs at least 30 values in each series, otherwise None is returned.
    With unequal lengths only the common prefix is correlated. The result
    lies in [-1.0, 1.0].
    """
    minimum_points = 30
    usable = min(len(scores), len(returns))
    if usable < minimum_points:
        return None
    paired_scores = scores[:usable]
    paired_returns = returns[:usable]
    return _pearson_correlation(paired_scores, paired_returns)
def compute_rank_information_coefficient(
    scores: list[float],
    returns: list[float],
) -> float | None:
    """Rank IC: Spearman correlation of prediction scores vs. future returns.

    Both series are converted to fractional ranks and the Pearson
    correlation of those ranks is returned, a value in [-1.0, 1.0].
    None is returned unless each series holds at least 30 values; with
    unequal lengths only the common prefix is used.
    """
    minimum_points = 30
    usable = min(len(scores), len(returns))
    if usable < minimum_points:
        return None
    score_ranks = _rank_data(scores[:usable])
    return_ranks = _rank_data(returns[:usable])
    return _pearson_correlation(score_ranks, return_ranks)
def compute_contribution_scores(
    weights: list[float],
) -> list[float]:
    """Turn document weights into contribution scores.

    Each score is weight_i / sum(weights). When the weights sum to zero,
    every document receives the equal share 1 / len(weights) instead.
    An empty input yields an empty list.
    """
    if not weights:
        return []
    weight_sum = sum(weights)
    if weight_sum != 0.0:
        return [weight / weight_sum for weight in weights]
    # Nothing to normalise by: distribute evenly.
    equal_share = 1.0 / len(weights)
    return [equal_share for _ in weights]
def compute_hit_rate_improvement(win_rate: float) -> float:
    """Relative improvement of the win rate over a random 50/50 baseline.

    Computed as (win_rate - 0.5) / 0.5: 0.0 at the baseline, 1.0 for a
    perfect win rate, negative below the baseline.
    """
    baseline = 0.5
    edge = win_rate - baseline
    return edge / baseline
# ---------------------------------------------------------------------------
# SQL queries for v_prediction_performance view
# ---------------------------------------------------------------------------
# Prediction/outcome rows from v_prediction_performance for one horizon,
# with no restriction on generated_at. Bind parameter: $1 = horizon.
_PERFORMANCE_DATA_SQL = """
SELECT
ticker,
direction,
action,
confidence,
strength,
p_bull,
score_company,
score_macro,
score_competitive,
future_return,
excess_return_vs_spy,
excess_return_vs_sector,
direction_correct,
profitable,
horizon,
generated_at
FROM v_prediction_performance
WHERE horizon = $1
"""
# Same projection restricted to rows generated on or after a cutoff.
# Bind parameters: $1 = horizon, $2 = lower bound for generated_at.
_PERFORMANCE_DATA_WITH_LOOKBACK_SQL = """
SELECT
ticker,
direction,
action,
confidence,
strength,
p_bull,
score_company,
score_macro,
score_competitive,
future_return,
excess_return_vs_spy,
excess_return_vs_sector,
direction_correct,
profitable,
horizon,
generated_at
FROM v_prediction_performance
WHERE horizon = $1
AND generated_at >= $2
"""
# Inserts one row into model_metric_snapshots.
# The 18 bind parameters follow the column list in order: $1 = id (cast to
# uuid) ... $17 = hold_win_rate, $18 = metadata (cast to jsonb, so it must be
# passed in a form that cast accepts, e.g. JSON text).
_INSERT_METRIC_SNAPSHOT_SQL = """
INSERT INTO model_metric_snapshots (
id, generated_at, lookback_window, horizon,
prediction_count, win_rate, directional_accuracy,
information_coefficient, rank_information_coefficient,
avg_return, avg_excess_return_vs_spy, avg_excess_return_vs_sector,
calibration_error, brier_score,
buy_win_rate, sell_win_rate, hold_win_rate,
metadata
) VALUES (
$1::uuid, $2, $3, $4,
$5, $6, $7,
$8, $9,
$10, $11, $12,
$13, $14,
$15, $16, $17,
$18::jsonb
)
"""
# ---------------------------------------------------------------------------
# Metric computation from raw rows
# ---------------------------------------------------------------------------
def _share_direction_correct(rows: list[dict]) -> float:
    """Return the fraction of rows whose ``direction_correct`` is exactly True.

    Returns 0.0 for an empty list so callers never divide by zero.
    """
    if not rows:
        return 0.0
    hits = sum(1 for r in rows if r.get("direction_correct") is True)
    return hits / len(rows)


def _rows_with_action(rows: list[dict], action: str) -> list[dict]:
    """Select rows whose ``action`` equals *action*, ignoring case and NULLs."""
    return [r for r in rows if (r.get("action") or "").lower() == action]


def _non_null_column(rows: list[dict], column: str) -> list:
    """Collect the non-None values of one column, preserving row order."""
    return [r[column] for r in rows if r.get(column) is not None]


def _mean_or_zero(values: list) -> float:
    """Return the arithmetic mean of *values*, or 0.0 when there are none."""
    return sum(values) / len(values) if values else 0.0


def _compute_metrics_from_rows(
    rows: list[dict],
    lookback_window: str,
    horizon: str,
) -> ModelMetricSnapshot:
    """Compute all metrics from a list of prediction performance rows.

    Args:
        rows: Dicts keyed by the columns selected from
            v_prediction_performance (action, confidence, strength, p_bull,
            future_return, excess_return_vs_spy, excess_return_vs_sector,
            direction_correct, ...). Missing or None values are skipped for
            the metric that needs them.
        lookback_window: Label of the lookback window, stored verbatim.
        horizon: Label of the evaluation horizon, stored verbatim.

    Returns:
        A ModelMetricSnapshot with a fresh UUID and the current local time.
        For an empty ``rows`` list every numeric metric is 0.0, both
        information coefficients are None and the metadata is empty.
    """
    now = datetime.now().astimezone()
    snapshot_id = str(uuid.uuid4())
    prediction_count = len(rows)

    if prediction_count == 0:
        return ModelMetricSnapshot(
            id=snapshot_id,
            generated_at=now,
            lookback_window=lookback_window,
            horizon=horizon,
            prediction_count=0,
            win_rate=0.0,
            directional_accuracy=0.0,
            information_coefficient=None,
            rank_information_coefficient=None,
            avg_return=0.0,
            avg_excess_return_vs_spy=0.0,
            avg_excess_return_vs_sector=0.0,
            calibration_error=0.0,
            brier_score=0.0,
            buy_win_rate=0.0,
            sell_win_rate=0.0,
            hold_win_rate=0.0,
            metadata={},
        )

    # --- Win rate and directional accuracy ---
    win_rate = _share_direction_correct(rows)
    directional_accuracy = win_rate  # Same metric, different name

    # --- Per-action win rates ---
    buy_rows = _rows_with_action(rows, "buy")
    sell_rows = _rows_with_action(rows, "sell")
    hold_rows = _rows_with_action(rows, "hold")
    buy_win_rate = _share_direction_correct(buy_rows)
    sell_win_rate = _share_direction_correct(sell_rows)
    hold_win_rate = _share_direction_correct(hold_rows)

    # --- Average return ---
    returns_list = _non_null_column(rows, "future_return")
    avg_return = _mean_or_zero(returns_list)

    # --- Average excess return vs SPY (Requirement 9.1) ---
    excess_spy_list = _non_null_column(rows, "excess_return_vs_spy")
    avg_excess_return_vs_spy = _mean_or_zero(excess_spy_list)

    # --- Average excess return vs sector ETF (Requirement 9.2) ---
    excess_sector_list = _non_null_column(rows, "excess_return_vs_sector")
    avg_excess_return_vs_sector = _mean_or_zero(excess_sector_list)

    # --- Calibration error (ECE) (Requirements 5.1, 5.2, 5.3, 5.5) ---
    # Confidence is the model's belief in its own call, so the matching
    # outcome is whether that call was right.
    scored_rows = [r for r in rows if r.get("confidence") is not None]
    confidences = [r["confidence"] for r in scored_rows]
    outcomes = [r.get("direction_correct") is True for r in scored_rows]
    ece, _buckets = compute_calibration_error(confidences, outcomes)

    # --- Brier score (Requirement 5.4) ---
    # p_bull is the probability of an upward move, so it must be scored
    # against "did the price go up", not against "was the call right".
    # Scoring it against direction_correct penalises a correct bearish call
    # (low p_bull, price fell) as if it were badly wrong.
    brier_rows = [
        r
        for r in rows
        if r.get("p_bull") is not None and r.get("future_return") is not None
    ]
    p_bulls = [r["p_bull"] for r in brier_rows]
    bullish_outcomes = [r["future_return"] > 0 for r in brier_rows]
    brier = compute_brier_score(p_bulls, bullish_outcomes)

    # --- Information Coefficient (Requirements 6.1, 6.5) ---
    # Keep only rows where both values exist so the two series stay aligned.
    ic_pairs = [
        (r["strength"], r["future_return"])
        for r in rows
        if r.get("strength") is not None and r.get("future_return") is not None
    ]
    ic_scores = [score for score, _ in ic_pairs]
    ic_returns = [realized for _, realized in ic_pairs]
    ic = compute_information_coefficient(ic_scores, ic_returns)

    # --- Rank Information Coefficient (Requirements 6.2, 6.5) ---
    rank_ic = compute_rank_information_coefficient(ic_scores, ic_returns)

    # --- Hit rate improvement (Requirement 9.4) ---
    hit_rate_improvement = compute_hit_rate_improvement(win_rate)

    # --- Metadata (Requirement 10.5) ---
    metadata: dict = {
        "hit_rate_improvement": hit_rate_improvement,
        "buy_count": len(buy_rows),
        "sell_count": len(sell_rows),
        "hold_count": len(hold_rows),
        "returns_count": len(returns_list),
        "excess_spy_count": len(excess_spy_list),
        "excess_sector_count": len(excess_sector_list),
    }

    return ModelMetricSnapshot(
        id=snapshot_id,
        generated_at=now,
        lookback_window=lookback_window,
        horizon=horizon,
        prediction_count=prediction_count,
        win_rate=win_rate,
        directional_accuracy=directional_accuracy,
        information_coefficient=ic,
        rank_information_coefficient=rank_ic,
        avg_return=avg_return,
        avg_excess_return_vs_spy=avg_excess_return_vs_spy,
        avg_excess_return_vs_sector=avg_excess_return_vs_sector,
        calibration_error=ece,
        brier_score=brier,
        buy_win_rate=buy_win_rate,
        sell_win_rate=sell_win_rate,
        hold_win_rate=hold_win_rate,
        metadata=metadata,
    )
# ---------------------------------------------------------------------------
# Main entry point (Requirements 10.1, 10.2, 10.3, 10.4, 10.5)
# ---------------------------------------------------------------------------
async def compute_and_store_metric_snapshots(
    pool: asyncpg.Pool,
) -> list[ModelMetricSnapshot]:
    """Compute and persist one metric snapshot per lookback/horizon pair.

    Iterates every combination of LOOKBACK_WINDOWS and EVALUATION_HORIZONS.
    For each pair it reads v_prediction_performance (restricted to rows
    generated after ``now - duration`` unless the lookback's duration is
    None), derives the metrics and inserts them into model_metric_snapshots.

    A failure in one combination is logged and does not stop the others.

    Returns:
        The snapshots that were successfully computed and stored.
    """
    stored: list[ModelMetricSnapshot] = []
    run_started = datetime.now().astimezone()

    for lookback in LOOKBACK_WINDOWS:
        window_length = LOOKBACK_DURATIONS[lookback]
        for horizon in EVALUATION_HORIZONS:
            try:
                # A None duration means "all time": no generated_at cutoff.
                if window_length is None:
                    records = await pool.fetch(_PERFORMANCE_DATA_SQL, horizon)
                else:
                    records = await pool.fetch(
                        _PERFORMANCE_DATA_WITH_LOOKBACK_SQL,
                        horizon,
                        run_started - window_length,
                    )

                # asyncpg Records become plain dicts before metric computation.
                snapshot = _compute_metrics_from_rows(
                    [dict(record) for record in records], lookback, horizon
                )

                # Argument order mirrors the INSERT column list.
                insert_args = (
                    snapshot.id,
                    snapshot.generated_at,
                    snapshot.lookback_window,
                    snapshot.horizon,
                    snapshot.prediction_count,
                    snapshot.win_rate,
                    snapshot.directional_accuracy,
                    snapshot.information_coefficient,
                    snapshot.rank_information_coefficient,
                    snapshot.avg_return,
                    snapshot.avg_excess_return_vs_spy,
                    snapshot.avg_excess_return_vs_sector,
                    snapshot.calibration_error,
                    snapshot.brier_score,
                    snapshot.buy_win_rate,
                    snapshot.sell_win_rate,
                    snapshot.hold_win_rate,
                    json.dumps(snapshot.metadata),
                )
                await pool.execute(_INSERT_METRIC_SNAPSHOT_SQL, *insert_args)
            except Exception:
                logger.exception(
                    "Failed to compute metrics for lookback=%s horizon=%s",
                    lookback,
                    horizon,
                )
            else:
                stored.append(snapshot)

    logger.info(
        "Computed %d metric snapshots across %d lookback/horizon combinations",
        len(stored),
        len(LOOKBACK_WINDOWS) * len(EVALUATION_HORIZONS),
    )
    return stored
+414
View File
@@ -0,0 +1,414 @@
"""Outcome Evaluator — matches predictions with realized market outcomes.
Runs periodically to evaluate prediction snapshots whose horizon has elapsed.
For each snapshot, fetches future prices at the horizon endpoint and computes
returns, excess returns, directional accuracy, and profitability across all
five evaluation horizons (1h, 6h, 1d, 7d, 30d).
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 4.10
"""
from __future__ import annotations
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
import asyncpg
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Maps each evaluation horizon label to the time that must elapse after a
# prediction's generated_at before that horizon can be evaluated.
# Iteration order of this dict is the order horizons are processed in.
HORIZON_DURATIONS: dict[str, timedelta] = {
"1h": timedelta(hours=1),
"6h": timedelta(hours=6),
"1d": timedelta(days=1),
"7d": timedelta(days=7),
"30d": timedelta(days=30),
}
# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------
@dataclass
class PredictionOutcome:
    """What actually happened to one prediction at one evaluation horizon.

    Attributes:
        id: UUID of this outcome row, as a string.
        prediction_id: Identifier of the evaluated prediction snapshot.
        evaluated_at: When the evaluation was performed.
        horizon: Horizon label (1h, 6h, 1d, 7d, 30d).
        future_price: Ticker close used as the price at the horizon.
        future_return: Return of the ticker from prediction to horizon.
        spy_future_price: SPY close at the horizon, if available.
        spy_return: SPY return over the same span, if available.
        sector_etf_future_price: Sector ETF close at the horizon, if available.
        sector_etf_return: Sector ETF return over the span, if available.
        excess_return_vs_spy: Ticker return minus SPY return, if available.
        excess_return_vs_sector: Ticker return minus sector return, if available.
        direction_correct: Whether the predicted direction matched the move.
        profitable: Whether the recommended action would have made money.
        metadata: Free-form extra context; defaults to a fresh empty dict.
    """

    # Identity
    id: str
    prediction_id: str
    evaluated_at: datetime
    horizon: str

    # Ticker outcome
    future_price: float
    future_return: float

    # Benchmark outcomes (None when the benchmark price is unavailable)
    spy_future_price: float | None
    spy_return: float | None
    sector_etf_future_price: float | None
    sector_etf_return: float | None
    excess_return_vs_spy: float | None
    excess_return_vs_sector: float | None

    # Verdicts
    direction_correct: bool
    profitable: bool

    metadata: dict = field(default_factory=dict)
# ---------------------------------------------------------------------------
# SQL statements
# ---------------------------------------------------------------------------
# Finds matured predictions for ONE horizon: snapshots where
# generated_at + $1 (the horizon duration, cast to interval) <= NOW() and no
# row exists yet in prediction_outcomes for that (prediction_id, $2) pair,
# where $2 is the horizon label.
# The snapshot's own horizon is returned as snapshot_horizon but is not used
# as a filter, so every snapshot is evaluated at every horizon.
_MATURED_PREDICTIONS_SQL = """
SELECT
ps.id,
ps.generated_at,
ps.ticker,
ps.horizon AS snapshot_horizon,
ps.direction,
ps.action,
ps.price_at_prediction,
ps.spy_price_at_prediction,
ps.sector_etf_price_at_prediction
FROM prediction_snapshots ps
WHERE ps.generated_at + $1::interval <= NOW()
AND NOT EXISTS (
SELECT 1 FROM prediction_outcomes po
WHERE po.prediction_id = ps.id AND po.horizon = $2
)
"""
# Fetches the close price ('c' field of the JSON data column) of the most
# recent 'bar' snapshot for ticker $1 captured at or before time $2.
# NOTE(review): there is no lower bound on captured_at, so when recent bars
# are missing this returns an arbitrarily old bar — possibly one captured
# before the prediction itself — instead of no row. Confirm whether a
# staleness limit is wanted before relying on "no row" to mean "not yet
# available".
_CLOSE_AT_TIME_SQL = """
SELECT (data->>'c')::float AS close
FROM market_snapshots
WHERE ticker = $1
AND snapshot_type = 'bar'
AND data->>'c' IS NOT NULL
AND captured_at <= $2
ORDER BY captured_at DESC
LIMIT 1
"""
# Inserts one row into prediction_outcomes. The 15 positional parameters
# follow the column list order; $1 and $2 are cast to uuid and $15 (metadata)
# is cast to jsonb, so callers pass ids as text and metadata as a JSON string.
_INSERT_OUTCOME_SQL = """
INSERT INTO prediction_outcomes (
id, prediction_id, evaluated_at, horizon,
future_price, future_return,
spy_future_price, spy_return,
sector_etf_future_price, sector_etf_return,
excess_return_vs_spy, excess_return_vs_sector,
direction_correct, profitable,
metadata
) VALUES (
$1::uuid, $2::uuid, $3, $4,
$5, $6,
$7, $8,
$9, $10,
$11, $12,
$13, $14,
$15::jsonb
)
"""
# ---------------------------------------------------------------------------
# Price fetching at a specific time
# ---------------------------------------------------------------------------
async def _fetch_close_at_time(
    pool: asyncpg.Pool,
    ticker: str,
    target_time: datetime,
) -> float | None:
    """Look up the latest bar close for *ticker* captured no later than *target_time*.

    Returns:
        The close price, or None when the query yields no row.
    """
    record = await pool.fetchrow(_CLOSE_AT_TIME_SQL, ticker, target_time)
    return None if record is None else record["close"]
# ---------------------------------------------------------------------------
# Sector ETF lookup (reuse pattern from prediction_snapshot)
# ---------------------------------------------------------------------------
# Maps a company's sector name (as stored in companies.sector) to the ticker
# of the sector ETF used as its benchmark. Sectors not listed here have no
# sector benchmark.
_SECTOR_ETF_MAP: dict[str, str] = {
"Technology": "XLK",
"Consumer Cyclical": "XLY",
"Financial Services": "XLF",
"Healthcare": "XLV",
"Energy": "XLE",
"Communication Services": "XLC",
"Industrials": "XLI",
"Consumer Defensive": "XLP",
"Real Estate": "XLRE",
"Utilities": "XLU",
}
# Reads the sector of the active company row for ticker $1 (at most one row).
_COMPANY_SECTOR_SQL = """
SELECT sector FROM companies WHERE ticker = $1 AND active = TRUE LIMIT 1
"""
async def _fetch_sector_etf_ticker(pool: asyncpg.Pool, ticker: str) -> str | None:
    """Resolve the benchmark sector ETF for a company ticker.

    Returns None when there is no active company row, the row has no sector,
    or the sector has no entry in _SECTOR_ETF_MAP.
    """
    record = await pool.fetchrow(_COMPANY_SECTOR_SQL, ticker)
    sector = record["sector"] if record is not None else None
    if sector is None:
        return None
    return _SECTOR_ETF_MAP.get(sector)
# ---------------------------------------------------------------------------
# Return computation helpers
# ---------------------------------------------------------------------------
def _compute_return(current_price: float, future_price: float) -> float:
    """Return the simple return ``(future - current) / current``.

    A zero starting price yields 0.0 instead of dividing by zero.
    """
    if current_price != 0.0:
        return (future_price - current_price) / current_price
    return 0.0
def _is_direction_correct(direction: str, future_return: float) -> bool:
    """Report whether the realized return agrees with the predicted direction.

    Only two combinations count as correct (direction is case-insensitive):
    "bullish" with a strictly positive return, and "bearish" with a strictly
    negative return. Any other direction, or a zero return, is incorrect.
    """
    stance = direction.lower()
    moved_up = future_return > 0.0
    moved_down = future_return < 0.0
    return (stance == "bullish" and moved_up) or (
        stance == "bearish" and moved_down
    )
def _is_profitable(action: str, future_return: float) -> bool:
    """Report whether following the recommended action would have made money.

    "buy" is profitable on a strictly positive return and "sell" on a
    strictly negative one (action is case-insensitive). Every other action,
    and a zero return, is not profitable.
    """
    side = action.lower()
    if side == "buy":
        return future_return > 0.0
    if side == "sell":
        return future_return < 0.0
    return False
# ---------------------------------------------------------------------------
# Single prediction evaluation (Requirements 4.2–4.7)
# ---------------------------------------------------------------------------
async def evaluate_single_prediction(
    pool: asyncpg.Pool,
    snapshot: dict,
    horizon: str,
) -> PredictionOutcome | None:
    """Score one prediction snapshot against market data at one horizon.

    The evaluation point is ``snapshot["generated_at"]`` plus the horizon's
    duration. Closes are fetched at that point for the ticker, for SPY and
    for the ticker's sector ETF. Benchmarks are only looked up when the
    snapshot recorded a non-zero benchmark price at prediction time.

    Args:
        pool: asyncpg connection pool.
        snapshot: Row from the matured-predictions query. Reads the keys
            id, generated_at, ticker, direction, action, price_at_prediction,
            spy_price_at_prediction and sector_etf_price_at_prediction.
        horizon: Key into HORIZON_DURATIONS.

    Returns:
        The computed PredictionOutcome, or None when the ticker's future
        price is unavailable (Requirement 4.10) or the snapshot's
        price_at_prediction is NULL or zero.
    """
    horizon_length = HORIZON_DURATIONS[horizon]
    target_time = snapshot["generated_at"] + horizon_length
    ticker = snapshot["ticker"]

    # The ticker's own future price is mandatory; without it there is nothing
    # to score.
    future_price = await _fetch_close_at_time(pool, ticker, target_time)
    if future_price is None:
        logger.debug(
            "Future price unavailable for %s at horizon %s (target %s), skipping",
            ticker,
            horizon,
            target_time,
        )
        return None

    entry_price = snapshot["price_at_prediction"]
    if entry_price is None or entry_price == 0.0:
        logger.warning(
            "Price at prediction is NULL or zero for snapshot %s, skipping horizon %s",
            snapshot["id"],
            horizon,
        )
        return None

    # Ticker return (Requirement 4.2)
    future_return = _compute_return(entry_price, future_price)

    # SPY benchmark (Requirement 4.3)
    spy_future_price: float | None = None
    spy_return: float | None = None
    spy_entry_price = snapshot["spy_price_at_prediction"]
    if spy_entry_price is not None and spy_entry_price != 0.0:
        spy_future_price = await _fetch_close_at_time(pool, "SPY", target_time)
    if spy_future_price is not None:
        spy_return = _compute_return(spy_entry_price, spy_future_price)

    # Sector ETF benchmark (Requirement 4.4)
    sector_etf_future_price: float | None = None
    sector_etf_return: float | None = None
    sector_entry_price = snapshot["sector_etf_price_at_prediction"]
    if sector_entry_price is not None and sector_entry_price != 0.0:
        etf_ticker = await _fetch_sector_etf_ticker(pool, ticker)
        if etf_ticker is not None:
            sector_etf_future_price = await _fetch_close_at_time(
                pool, etf_ticker, target_time
            )
    if sector_etf_future_price is not None:
        sector_etf_return = _compute_return(
            sector_entry_price, sector_etf_future_price
        )

    # Excess returns (Requirement 4.5): only defined when the benchmark
    # return is; the ticker return always exists at this point.
    excess_return_vs_spy = (
        future_return - spy_return if spy_return is not None else None
    )
    excess_return_vs_sector = (
        future_return - sector_etf_return
        if sector_etf_return is not None
        else None
    )

    return PredictionOutcome(
        id=str(uuid.uuid4()),
        prediction_id=str(snapshot["id"]),
        evaluated_at=datetime.now().astimezone(),
        horizon=horizon,
        future_price=future_price,
        future_return=future_return,
        spy_future_price=spy_future_price,
        spy_return=spy_return,
        sector_etf_future_price=sector_etf_future_price,
        sector_etf_return=sector_etf_return,
        excess_return_vs_spy=excess_return_vs_spy,
        excess_return_vs_sector=excess_return_vs_sector,
        # Requirements 4.6 and 4.7
        direction_correct=_is_direction_correct(snapshot["direction"], future_return),
        profitable=_is_profitable(snapshot["action"], future_return),
        metadata={
            "ticker": ticker,
            "horizon": horizon,
            "price_at_prediction": entry_price,
            "future_price": future_price,
        },
    )
# ---------------------------------------------------------------------------
# Store outcome (Requirement 4.9)
# ---------------------------------------------------------------------------
async def _store_outcome(
    conn: asyncpg.Connection,
    outcome: PredictionOutcome,
) -> None:
    """Insert one PredictionOutcome row using the given connection.

    The metadata dict is serialized to a JSON string for the jsonb column.
    """
    # Order must match the column list of _INSERT_OUTCOME_SQL.
    insert_args = (
        outcome.id,
        outcome.prediction_id,
        outcome.evaluated_at,
        outcome.horizon,
        outcome.future_price,
        outcome.future_return,
        outcome.spy_future_price,
        outcome.spy_return,
        outcome.sector_etf_future_price,
        outcome.sector_etf_return,
        outcome.excess_return_vs_spy,
        outcome.excess_return_vs_sector,
        outcome.direction_correct,
        outcome.profitable,
        json.dumps(outcome.metadata),
    )
    await conn.execute(_INSERT_OUTCOME_SQL, *insert_args)
# ---------------------------------------------------------------------------
# Main entry point (Requirements 4.1, 4.8, 4.9, 4.10)
# ---------------------------------------------------------------------------
async def evaluate_matured_predictions(
    pool: asyncpg.Pool,
) -> int:
    """Evaluate every matured snapshot/horizon pair that has no outcome yet.

    For each entry of HORIZON_DURATIONS, queries the snapshots whose
    ``generated_at + duration`` has passed and that lack an outcome for that
    horizon, evaluates each one and stores the result in its own transaction.

    Pairs for which evaluation returns None are left without an outcome, so
    the same query picks them up again on the next run (Requirement 4.10).
    An exception for one snapshot is logged and does not stop the run.

    Returns:
        The number of outcomes stored.
    """
    recorded = 0

    for horizon, duration in HORIZON_DURATIONS.items():
        matured = await pool.fetch(_MATURED_PREDICTIONS_SQL, duration, horizon)
        if not matured:
            continue

        logger.info(
            "Found %d matured predictions for horizon %s", len(matured), horizon
        )

        for record in matured:
            snapshot = dict(record)
            try:
                outcome = await evaluate_single_prediction(pool, snapshot, horizon)
                if outcome is not None:
                    async with pool.acquire() as conn, conn.transaction():
                        await _store_outcome(conn, outcome)
                    recorded += 1
            except Exception:
                logger.exception(
                    "Failed to evaluate snapshot %s at horizon %s",
                    snapshot["id"],
                    horizon,
                )

    logger.info("Outcome evaluation complete: %d outcomes recorded", recorded)
    return recorded
+540
View File
@@ -0,0 +1,540 @@
"""Prediction Snapshot Writer — captures immutable prediction state at generation time.
Creates frozen records of every recommendation with prices, evidence links,
duplicate detection, and contribution scores so that predictions can be
evaluated against future outcomes without hindsight bias.
Requirements: 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 3.1, 3.2, 3.3, 3.4
"""
from __future__ import annotations
import hashlib
import json
import logging
import urllib.parse
import uuid
from dataclasses import dataclass, field
from datetime import datetime
import asyncpg
from services.shared.schemas import Recommendation, TrendSummary
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Maps a company's sector name (as stored in companies.sector) to the ticker
# of the sector ETF whose price is captured alongside the prediction.
SECTOR_ETF_MAP: dict[str, str] = {
"Technology": "XLK",
"Consumer Cyclical": "XLY",
"Financial Services": "XLF",
"Healthcare": "XLV",
"Energy": "XLE",
"Communication Services": "XLC",
"Industrials": "XLI",
"Consumer Defensive": "XLP",
"Real Estate": "XLRE",
"Utilities": "XLU",
}
# Horizon labels at which predictions are evaluated.
EVALUATION_HORIZONS: list[str] = ["1h", "6h", "1d", "7d", "30d"]
# Upper bound applied to each evidence signal's weight before contribution
# scores are computed.
MAX_SINGLE_DOCUMENT_WEIGHT: float = 1.0
# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------
@dataclass
class PredictionSnapshot:
    """Frozen-in-time record of one prediction as it was generated.

    Attributes:
        id: UUID of the snapshot, as a string.
        generated_at: When the underlying recommendation was generated.
        ticker: Ticker the prediction is about.
        window: Trend window label.
        horizon: Recommendation time horizon label.
        direction: bullish / bearish / mixed / neutral.
        action: buy / sell / hold / watch.
        mode: informational / paper_eligible / live_eligible.
        strength, confidence, contradiction: Trend and recommendation scores.
        p_bull, p_bear: Directional probabilities, or None when unavailable.
        score_company, score_macro, score_competitive: Layer weight shares.
        evidence_count: Number of evidence links, duplicates included.
        unique_source_count: Distinct sources among non-duplicate links.
        duplicate_evidence_count: Number of links flagged as duplicates.
        price_at_prediction: Ticker close at snapshot time, or None.
        spy_price_at_prediction: SPY close at snapshot time, or None.
        sector_etf_price_at_prediction: Sector ETF close at snapshot time, or None.
        metadata: Free-form extra context; defaults to a fresh empty dict.
    """

    # Identity and classification
    id: str
    generated_at: datetime
    ticker: str
    window: str
    horizon: str
    direction: str
    action: str
    mode: str

    # Model scores
    strength: float
    confidence: float
    contradiction: float
    p_bull: float | None
    p_bear: float | None
    score_company: float
    score_macro: float
    score_competitive: float

    # Evidence quality counters
    evidence_count: int
    unique_source_count: int
    duplicate_evidence_count: int

    # Prices captured at prediction time
    price_at_prediction: float | None
    spy_price_at_prediction: float | None
    sector_etf_price_at_prediction: float | None

    metadata: dict = field(default_factory=dict)
@dataclass
class SignalEvidenceLink:
    """One evidence signal attached to a prediction snapshot.

    Attributes:
        id: UUID of the link, as a string.
        prediction_id: Id of the owning prediction snapshot.
        document_id: Id of the source document.
        signal_id: Id of the extracted signal.
        ticker: Ticker the signal refers to.
        source, source_type, catalyst_type, sentiment: Signal descriptors.
        impact: Signal impact value.
        extraction_confidence: Confidence of the extraction step.
        weight: Signal weight, capped at MAX_SINGLE_DOCUMENT_WEIGHT.
        is_duplicate: True when an earlier link has the same canonical key.
        canonical_evidence_key: Dedup key derived from document title and URL.
        contribution_score: This link's share of the total weight.
        metadata: Free-form extra context; defaults to a fresh empty dict.
    """

    # Identity
    id: str
    prediction_id: str
    document_id: str
    signal_id: str
    ticker: str

    # Signal descriptors
    source: str
    source_type: str
    catalyst_type: str
    sentiment: str
    impact: float
    extraction_confidence: float

    # Weighting and deduplication
    weight: float
    is_duplicate: bool
    canonical_evidence_key: str
    contribution_score: float

    metadata: dict = field(default_factory=dict)
# ---------------------------------------------------------------------------
# Canonical evidence key computation (Requirements 2.3, 17.4)
# ---------------------------------------------------------------------------
def compute_canonical_evidence_key(title: str, url: str) -> str:
    """Derive a stable dedup key from a document's title and URL.

    The key is the hex SHA-256 of the normalized title immediately followed
    by the normalized URL:

    - Title: surrounding whitespace removed, then lowercased.
    - URL: lowercased, then reduced to scheme, host and path (params, query
      string and fragment are dropped).
    """
    url_parts = urllib.parse.urlparse(url.lower())
    bare_url = urllib.parse.urlunparse(
        (url_parts.scheme, url_parts.netloc, url_parts.path, "", "", "")
    )
    clean_title = title.strip().lower()

    digest = hashlib.sha256()
    digest.update(clean_title.encode("utf-8"))
    digest.update(bare_url.encode("utf-8"))
    return digest.hexdigest()
# ---------------------------------------------------------------------------
# Contribution score computation (Requirements 2.5, 17.7)
# ---------------------------------------------------------------------------
def compute_contribution_scores(weights: list[float]) -> list[float]:
    """Turn raw weights into shares of the total weight.

    Each score is ``weight_i / sum(weights)``. When the weights sum to zero
    the total is split evenly across all entries. An empty input yields an
    empty list.
    """
    count = len(weights)
    if count == 0:
        return []

    weight_sum = sum(weights)
    if weight_sum == 0.0:
        # Nothing to apportion by, so every entry gets an equal share.
        return [1.0 / count] * count

    return [weight / weight_sum for weight in weights]
# ---------------------------------------------------------------------------
# Price fetching (Requirements 1.2, 1.3, 1.4, 1.5)
# ---------------------------------------------------------------------------
# Fetches the close price ('c' field of the JSON data column) of the most
# recently captured 'bar' snapshot for ticker $1, ignoring bars with no close.
_LATEST_CLOSE_SQL = """
SELECT (data->>'c')::float AS close
FROM market_snapshots
WHERE ticker = $1 AND snapshot_type = 'bar' AND data->>'c' IS NOT NULL
ORDER BY captured_at DESC
LIMIT 1
"""
async def fetch_latest_close_price(
    pool: asyncpg.Pool,
    ticker: str,
) -> float | None:
    """Return the close of the newest bar stored in market_snapshots for *ticker*.

    Returns:
        The close price, or None when the ticker has no bar with a close.
    """
    record = await pool.fetchrow(_LATEST_CLOSE_SQL, ticker)
    return None if record is None else record["close"]
# ---------------------------------------------------------------------------
# Sector ETF lookup
# ---------------------------------------------------------------------------
# Reads the sector of the active company row for ticker $1 (at most one row).
_COMPANY_SECTOR_SQL = """
SELECT sector FROM companies WHERE ticker = $1 AND active = TRUE LIMIT 1
"""
async def _fetch_sector_etf_ticker(pool: asyncpg.Pool, ticker: str) -> str | None:
    """Resolve the benchmark sector ETF for a company ticker.

    Returns None when there is no active company row, the row has no sector,
    or the sector has no entry in SECTOR_ETF_MAP.
    """
    record = await pool.fetchrow(_COMPANY_SECTOR_SQL, ticker)
    sector = record["sector"] if record is not None else None
    if sector is None:
        return None
    return SECTOR_ETF_MAP.get(sector)
# ---------------------------------------------------------------------------
# Layer score computation
# ---------------------------------------------------------------------------
def _compute_layer_scores(
    evidence_signals: list[dict],
) -> tuple[float, float, float]:
    """Split total evidence weight into company, macro and competitive shares.

    Each signal is assigned to exactly one layer, checked in this order:

    - macro: ``source_type`` contains "macro", or ``catalyst_type`` is "macro"
    - competitive: ``source_type`` contains "competitive" or "pattern"
    - company: everything else

    Matching is case-insensitive. A missing or None ``weight`` counts as 0.0,
    and a missing or None ``source_type`` / ``catalyst_type`` counts as an
    empty string, so a signal with NULL descriptors lands in the company
    layer instead of raising.

    Returns:
        ``(score_company, score_macro, score_competitive)``, each rounded to
        6 decimal places. They sum to 1.0 up to that rounding, or are all 0.0
        when the total weight is zero.
    """
    company_weight = 0.0
    macro_weight = 0.0
    competitive_weight = 0.0

    for sig in evidence_signals:
        # `or` covers both an absent key and an explicit None value.
        w = sig.get("weight") or 0.0
        source_type = (sig.get("source_type") or "").lower()
        catalyst_type = (sig.get("catalyst_type") or "").lower()

        if "macro" in source_type or catalyst_type == "macro":
            macro_weight += w
        elif "competitive" in source_type or "pattern" in source_type:
            competitive_weight += w
        else:
            company_weight += w

    total = company_weight + macro_weight + competitive_weight
    if total == 0.0:
        return (0.0, 0.0, 0.0)

    return (
        round(company_weight / total, 6),
        round(macro_weight / total, 6),
        round(competitive_weight / total, 6),
    )
# ---------------------------------------------------------------------------
# SQL statements
# ---------------------------------------------------------------------------
# Inserts one row into prediction_snapshots. The 23 positional parameters
# follow the column list order; $1 is cast to uuid and $23 (metadata) is cast
# to jsonb, so callers pass the id as text and the metadata as a JSON string.
_INSERT_SNAPSHOT_SQL = """
INSERT INTO prediction_snapshots (
id, generated_at, ticker, window, horizon, direction, action, mode,
strength, confidence, contradiction, p_bull, p_bear,
score_company, score_macro, score_competitive,
evidence_count, unique_source_count, duplicate_evidence_count,
price_at_prediction, spy_price_at_prediction, sector_etf_price_at_prediction,
metadata
) VALUES (
$1::uuid, $2, $3, $4, $5, $6, $7, $8,
$9, $10, $11, $12, $13,
$14, $15, $16,
$17, $18, $19,
$20, $21, $22,
$23::jsonb
)
"""
# Inserts one row into signal_evidence_links. The 16 positional parameters
# follow the column list order; $1 and $2 are cast to uuid and $16 (metadata)
# is cast to jsonb.
_INSERT_EVIDENCE_LINK_SQL = """
INSERT INTO signal_evidence_links (
id, prediction_id, document_id, signal_id, ticker,
source, source_type, catalyst_type, sentiment,
impact, extraction_confidence, weight,
is_duplicate, canonical_evidence_key, contribution_score,
metadata
) VALUES (
$1::uuid, $2::uuid, $3, $4, $5,
$6, $7, $8, $9,
$10, $11, $12,
$13, $14, $15,
$16::jsonb
)
"""
# ---------------------------------------------------------------------------
# Main entry point (Requirements 1.1–1.7, 2.1–2.6, 3.1–3.4)
# ---------------------------------------------------------------------------
def _evidence_key_for_signal(
    title: str,
    url: str,
    document_id: object,
    signal_id: object,
    link_id: str,
) -> str:
    """Return the dedup key for one signal without merging unidentified documents.

    When the document has a title or URL the regular canonical key is used.
    When both are empty, every such signal would otherwise hash to the same
    key and be flagged as a duplicate of the first one, so the key is derived
    from the document id instead, falling back to the signal id and finally
    to the link's own id.
    """
    if title.strip() or url.strip():
        return compute_canonical_evidence_key(title, url)
    identity = document_id or signal_id or link_id
    return hashlib.sha256(f"unresolved:{identity}".encode("utf-8")).hexdigest()


async def create_prediction_snapshot(
    pool: asyncpg.Pool,
    recommendation: Recommendation,
    trend_summary: TrendSummary,
    evidence_signals: list[dict],
    evidence_docs: list[dict],
) -> PredictionSnapshot:
    """Create and persist a prediction snapshot with evidence links.

    Steps:
    1. Fetch current prices (ticker, SPY, sector ETF) from market_snapshots
    2. Compute canonical evidence keys and detect duplicates
    3. Clamp individual document weights to MAX_SINGLE_DOCUMENT_WEIGHT
    4. Compute contribution scores (one-vote-per-canonical-key dedup)
    5. Persist snapshot and evidence links in a transaction

    Args:
        pool: asyncpg connection pool.
        recommendation: The generated Recommendation object.
        trend_summary: The TrendSummary used to generate the recommendation.
        evidence_signals: List of dicts with signal fields (source, source_type,
            catalyst_type, sentiment, impact, extraction_confidence, weight,
            document_id, signal_id, ticker).
        evidence_docs: List of dicts with document metadata (title, url,
            document_id). A signal whose document is missing here, or has
            neither title nor URL, is keyed by its document/signal identity
            rather than being treated as a duplicate of other such signals.

    Returns:
        The persisted PredictionSnapshot.
    """
    ticker = recommendation.ticker

    # 1. Fetch prices — a missing price is logged and stored as NULL
    #    (Requirement 1.5)
    ticker_price = await fetch_latest_close_price(pool, ticker)
    if ticker_price is None:
        logger.warning("No market price available for %s at snapshot time", ticker)

    spy_price = await fetch_latest_close_price(pool, "SPY")
    if spy_price is None:
        logger.warning("No SPY price available at snapshot time")

    sector_etf_ticker = await _fetch_sector_etf_ticker(pool, ticker)
    sector_etf_price: float | None = None
    if sector_etf_ticker is not None:
        sector_etf_price = await fetch_latest_close_price(pool, sector_etf_ticker)
        if sector_etf_price is None:
            logger.warning(
                "No sector ETF price available for %s (%s) at snapshot time",
                sector_etf_ticker,
                ticker,
            )
    else:
        logger.warning("No sector ETF mapping found for ticker %s", ticker)

    # 2. Index document metadata by document_id for canonical key computation
    doc_lookup: dict[str, dict] = {
        doc.get("document_id", ""): doc for doc in evidence_docs
    }

    snapshot_id = str(uuid.uuid4())

    # 3. Process evidence signals: compute canonical keys, detect duplicates,
    #    clamp weights. Contribution scores are filled in by step 4.
    evidence_link_objects: list[SignalEvidenceLink] = []
    seen_canonical_keys: set[str] = set()
    for sig in evidence_signals:
        link_id = str(uuid.uuid4())
        doc_id = sig.get("document_id", "")
        doc_meta = doc_lookup.get(doc_id, {})

        # `or ""` guards against title/url keys that are present but None.
        canonical_key = _evidence_key_for_signal(
            doc_meta.get("title") or "",
            doc_meta.get("url") or "",
            doc_id,
            sig.get("signal_id", ""),
            link_id,
        )

        # A link is a duplicate when an earlier link has the same key.
        is_duplicate = canonical_key in seen_canonical_keys
        seen_canonical_keys.add(canonical_key)

        evidence_link_objects.append(
            SignalEvidenceLink(
                id=link_id,
                prediction_id=snapshot_id,
                document_id=doc_id,
                signal_id=sig.get("signal_id", ""),
                ticker=sig.get("ticker", ticker),
                source=sig.get("source", ""),
                source_type=sig.get("source_type", ""),
                catalyst_type=sig.get("catalyst_type", ""),
                sentiment=sig.get("sentiment", ""),
                impact=sig.get("impact", 0.0),
                extraction_confidence=sig.get("extraction_confidence", 0.0),
                # Clamp to MAX_SINGLE_DOCUMENT_WEIGHT (Requirement 3.3)
                weight=min(sig.get("weight", 0.0), MAX_SINGLE_DOCUMENT_WEIGHT),
                is_duplicate=is_duplicate,
                canonical_evidence_key=canonical_key,
                contribution_score=0.0,
            )
        )

    # 4. Contribution scores — one vote per canonical key (Requirement 3.4).
    #    Only non-duplicate links share the weight pool; duplicates keep 0.0.
    voting_links = [link for link in evidence_link_objects if not link.is_duplicate]
    voting_scores = compute_contribution_scores(
        [link.weight for link in voting_links]
    )
    for link, score in zip(voting_links, voting_scores):
        link.contribution_score = score

    # 5. Deduplication quality metrics (Requirements 3.1, 3.2)
    unique_source_count = len({link.source for link in voting_links})
    duplicate_evidence_count = len(evidence_link_objects) - len(voting_links)

    # 6. Layer scores from the evidence signals as supplied by the caller
    score_company, score_macro, score_competitive = _compute_layer_scores(
        evidence_signals
    )

    # 7. Metadata from trend summary context (Requirement 1.7)
    metadata: dict = {}
    market_context = trend_summary.market_context
    if market_context is not None:
        metadata["market_context"] = {
            "ticker": market_context.ticker,
            "price_change_pct": market_context.price_change_pct,
            "avg_volume": market_context.avg_volume,
            "volume_change_pct": market_context.volume_change_pct,
            "volatility": market_context.volatility,
            "latest_close": market_context.latest_close,
            "bars_available": market_context.bars_available,
        }
    if sector_etf_ticker is not None:
        metadata["sector_etf_ticker"] = sector_etf_ticker

    # 8. Build the snapshot
    p_bull = trend_summary.p_bull
    snapshot = PredictionSnapshot(
        id=snapshot_id,
        generated_at=recommendation.generated_at,
        ticker=ticker,
        window=trend_summary.window.value,
        horizon=recommendation.time_horizon,
        direction=trend_summary.trend_direction.value,
        action=recommendation.action.value,
        mode=recommendation.mode.value,
        strength=trend_summary.trend_strength,
        confidence=recommendation.confidence,
        contradiction=trend_summary.contradiction_score,
        p_bull=p_bull,
        p_bear=1.0 - p_bull if p_bull is not None else None,
        score_company=score_company,
        score_macro=score_macro,
        score_competitive=score_competitive,
        evidence_count=len(evidence_link_objects),
        unique_source_count=unique_source_count,
        duplicate_evidence_count=duplicate_evidence_count,
        price_at_prediction=ticker_price,
        spy_price_at_prediction=spy_price,
        sector_etf_price_at_prediction=sector_etf_price,
        metadata=metadata,
    )

    # 9. Persist snapshot and links atomically (Requirements 1.6, 2.6)
    async with pool.acquire() as conn:
        async with conn.transaction():
            await conn.execute(
                _INSERT_SNAPSHOT_SQL,
                snapshot.id,
                snapshot.generated_at,
                snapshot.ticker,
                snapshot.window,
                snapshot.horizon,
                snapshot.direction,
                snapshot.action,
                snapshot.mode,
                snapshot.strength,
                snapshot.confidence,
                snapshot.contradiction,
                snapshot.p_bull,
                snapshot.p_bear,
                snapshot.score_company,
                snapshot.score_macro,
                snapshot.score_competitive,
                snapshot.evidence_count,
                snapshot.unique_source_count,
                snapshot.duplicate_evidence_count,
                snapshot.price_at_prediction,
                snapshot.spy_price_at_prediction,
                snapshot.sector_etf_price_at_prediction,
                json.dumps(snapshot.metadata),
            )
            for link in evidence_link_objects:
                await conn.execute(
                    _INSERT_EVIDENCE_LINK_SQL,
                    link.id,
                    link.prediction_id,
                    link.document_id,
                    link.signal_id,
                    link.ticker,
                    link.source,
                    link.source_type,
                    link.catalyst_type,
                    link.sentiment,
                    link.impact,
                    link.extraction_confidence,
                    link.weight,
                    link.is_duplicate,
                    link.canonical_evidence_key,
                    link.contribution_score,
                    json.dumps(link.metadata),
                )

    logger.info(
        "Created prediction snapshot %s for %s: %d evidence links "
        "(%d unique sources, %d duplicates), prices: ticker=%s spy=%s sector_etf=%s",
        snapshot_id,
        ticker,
        len(evidence_link_objects),
        unique_source_count,
        duplicate_evidence_count,
        ticker_price,
        spy_price,
        sector_etf_price,
    )
    return snapshot