feat: model validation, calibration, and signal quality layer
ci/woodpecker/push/test Pipeline failed
ci/woodpecker/push/build-1 unknown status
ci/woodpecker/push/build-3 unknown status
ci/woodpecker/push/build-2 unknown status
ci/woodpecker/push/finalize unknown status
Build and Push / lint-and-test (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.adapters.broker_adapter name:broker-adapter]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.aggregation.worker name:aggregation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.extractor.worker name:extractor]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.ingestion.worker name:ingestion]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.lake_publisher.worker name:lake-publisher]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.parser.worker name:parser]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.recommendation.worker name:recommendation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.scheduler.app name:scheduler]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.api.app:app --host 0.0.0.0 --port 8000 name:query-api]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.risk.app:app --host 0.0.0.0 --port 8000 name:risk]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.symbol_registry.app:app --host 0.0.0.0 --port 8000 name:symbol-registry]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.trading.app:app --host 0.0.0.0 --port 8000 name:trading-engine]) (push) Has been cancelled
Build and Push / build-dashboard (push) Has been cancelled
Build and Push / build-superset (push) Has been cancelled
Build and Push / integration-test (push) Has been cancelled
Build and Push / beta-gate (push) Has been cancelled
ci/woodpecker/push/test Pipeline failed
ci/woodpecker/push/build-1 unknown status
ci/woodpecker/push/build-3 unknown status
ci/woodpecker/push/build-2 unknown status
ci/woodpecker/push/finalize unknown status
Build and Push / lint-and-test (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.adapters.broker_adapter name:broker-adapter]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.aggregation.worker name:aggregation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.extractor.worker name:extractor]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.ingestion.worker name:ingestion]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.lake_publisher.worker name:lake-publisher]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.parser.worker name:parser]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.recommendation.worker name:recommendation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.scheduler.app name:scheduler]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.api.app:app --host 0.0.0.0 --port 8000 name:query-api]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.risk.app:app --host 0.0.0.0 --port 8000 name:risk]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.symbol_registry.app:app --host 0.0.0.0 --port 8000 name:symbol-registry]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.trading.app:app --host 0.0.0.0 --port 8000 name:trading-engine]) (push) Has been cancelled
Build and Push / build-dashboard (push) Has been cancelled
Build and Push / build-superset (push) Has been cancelled
Build and Push / integration-test (push) Has been cancelled
Build and Push / beta-gate (push) Has been cancelled
- Migration 035: prediction_snapshots, prediction_outcomes, signal_evidence_links, model_metric_snapshots tables + SQL views - Prediction snapshot writer with canonical evidence keys, duplicate detection, contribution scores - Outcome evaluator across 5 horizons (1h, 6h, 1d, 7d, 30d) - Metrics engine: ECE, Brier score, IC, Rank IC, benchmark comparison - Attribution engine: per-source, per-catalyst, per-layer performance - Calibration engine: Bayesian shrinkage source reliability - Quality gate for live trading eligibility with configurable thresholds - 7 new /api/validation/* endpoints - Upgraded OpsModel dashboard with validation tab - Enhanced recommendation display with calibration context - Backtest replay validation mode - 86 Python tests (unit + property-based), 179 frontend tests passing
This commit is contained in:
@@ -64,6 +64,7 @@ from services.shared.metrics import (
|
||||
AGGREGATION_WINDOWS_COMPUTED,
|
||||
)
|
||||
from services.shared.schemas import TrendDirection, TrendSummary, TrendWindow
|
||||
from services.trading.model_quality_gate import QualityGateResult, evaluate_quality_gate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1576,10 +1577,34 @@ async def aggregate_company(
|
||||
# Mid-cycle changes take effect on the next cycle.
|
||||
probabilistic = await fetch_probabilistic_scoring_enabled(pool)
|
||||
pipeline_mode = "probabilistic" if probabilistic else "heuristic"
|
||||
|
||||
# --- Quality gate evaluation (Req 11.2, 11.3) ---
|
||||
# Evaluate model quality gate at the start of each aggregation cycle.
|
||||
# When the gate fails, all recommendations are forced to paper mode.
|
||||
# Gate evaluation failure defaults to paper-only (fail-safe).
|
||||
quality_gate_passed = False
|
||||
try:
|
||||
gate_result: QualityGateResult = await evaluate_quality_gate(pool)
|
||||
quality_gate_passed = gate_result.passed
|
||||
logger.info(
|
||||
"Quality gate for %s cycle: %s — %s",
|
||||
ticker,
|
||||
"PASS" if gate_result.passed else "FAIL",
|
||||
gate_result.reason,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Quality gate evaluation failed for %s cycle — "
|
||||
"defaulting to paper-only mode (fail-safe)",
|
||||
ticker,
|
||||
)
|
||||
quality_gate_passed = False
|
||||
|
||||
logger.info(
|
||||
"Aggregation cycle for %s: pipeline_mode=%s",
|
||||
"Aggregation cycle for %s: pipeline_mode=%s quality_gate=%s",
|
||||
ticker,
|
||||
pipeline_mode,
|
||||
"passed" if quality_gate_passed else "failed",
|
||||
)
|
||||
|
||||
# --- Regime detection (Req 7.1, 7.2, 7.3, 7.8, 7.9) ---
|
||||
@@ -1647,6 +1672,20 @@ async def aggregate_company(
|
||||
ticker_returns=ticker_returns,
|
||||
ticker_volumes=ticker_volumes,
|
||||
)
|
||||
|
||||
# When quality gate fails, annotate the trend summary so the
|
||||
# recommendation engine forces paper mode (Req 11.2, 11.3).
|
||||
if not quality_gate_passed:
|
||||
ctx = summary.market_context
|
||||
if isinstance(ctx, dict):
|
||||
ctx["quality_gate_passed"] = False
|
||||
elif ctx is not None and hasattr(ctx, "model_dump"):
|
||||
ctx_dict = ctx.model_dump()
|
||||
ctx_dict["quality_gate_passed"] = False
|
||||
summary.market_context = ctx_dict
|
||||
else:
|
||||
summary.market_context = {"quality_gate_passed": False}
|
||||
|
||||
summaries.append(summary)
|
||||
|
||||
return summaries
|
||||
|
||||
@@ -43,6 +43,11 @@ from services.shared.db import get_pg_pool, get_redis
|
||||
from services.shared.logging import new_trace_id, set_trace_context, setup_logging
|
||||
from services.shared.redis_keys import PIPELINE_ENABLED_KEY, QUEUE_BROKER, QUEUE_PREFIX, queue_key
|
||||
from services.shared.schemas import MAJOR_DECISION_CATALYSTS
|
||||
from services.validation.attribution import (
|
||||
compute_catalyst_attribution,
|
||||
compute_layer_attribution,
|
||||
compute_source_attribution,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("query_api")
|
||||
|
||||
@@ -3769,3 +3774,336 @@ async def get_variant_performance_history(
|
||||
agent_id, variant_id, hours,
|
||||
)
|
||||
return [_row_to_dict(r) for r in rows]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model Validation Dashboard (Requirements 12.1, 12.2, 12.3, 12.7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_VALID_LOOKBACKS = {"7d", "30d", "90d", "all"}
|
||||
_VALID_HORIZONS = {"1h", "6h", "1d", "7d", "30d"}
|
||||
|
||||
|
||||
@app.get("/api/validation/summary")
|
||||
async def get_validation_summary(
|
||||
lookback: str = Query(default="30d"),
|
||||
horizon: str = Query(default="7d"),
|
||||
):
|
||||
"""Latest model metric snapshot plus quality gate status.
|
||||
|
||||
Returns the most recent model_metric_snapshot for the given
|
||||
lookback/horizon combination, along with the current gate status
|
||||
from risk_configs.
|
||||
|
||||
Requirement 12.1
|
||||
"""
|
||||
if lookback not in _VALID_LOOKBACKS:
|
||||
raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")
|
||||
if horizon not in _VALID_HORIZONS:
|
||||
raise HTTPException(400, f"Invalid horizon: {horizon}. Must be one of {sorted(_VALID_HORIZONS)}")
|
||||
|
||||
# Latest metric snapshot for the requested lookback/horizon
|
||||
snapshot_row = await pool.fetchrow(
|
||||
"""SELECT id, generated_at, lookback_window, horizon,
|
||||
prediction_count, win_rate, directional_accuracy,
|
||||
information_coefficient, rank_information_coefficient,
|
||||
avg_return, avg_excess_return_vs_spy, avg_excess_return_vs_sector,
|
||||
calibration_error, brier_score,
|
||||
buy_win_rate, sell_win_rate, hold_win_rate,
|
||||
metadata
|
||||
FROM model_metric_snapshots
|
||||
WHERE lookback_window = $1 AND horizon = $2
|
||||
ORDER BY generated_at DESC
|
||||
LIMIT 1""",
|
||||
lookback, horizon,
|
||||
)
|
||||
|
||||
snapshot = None
|
||||
if snapshot_row:
|
||||
snapshot = _row_to_dict(snapshot_row)
|
||||
snapshot["metadata"] = _parse_jsonb(snapshot.get("metadata"))
|
||||
|
||||
# Gate status from risk_configs
|
||||
gate_row = await pool.fetchrow(
|
||||
"SELECT config, updated_at FROM risk_configs WHERE name = 'model_quality_gate'",
|
||||
)
|
||||
gate_status = None
|
||||
if gate_row:
|
||||
gate_status = _parse_jsonb(gate_row["config"])
|
||||
|
||||
return {
|
||||
"snapshot": snapshot,
|
||||
"gate_status": gate_status,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/validation/calibration")
|
||||
async def get_validation_calibration(
|
||||
lookback: str = Query(default="30d"),
|
||||
horizon: str = Query(default="7d"),
|
||||
):
|
||||
"""Calibration table with confidence buckets.
|
||||
|
||||
Queries v_prediction_performance for the given lookback/horizon,
|
||||
groups by confidence buckets, and computes avg_confidence,
|
||||
observed_win_rate, count, and miscalibrated flag per bucket.
|
||||
|
||||
Requirement 12.2
|
||||
"""
|
||||
if lookback not in _VALID_LOOKBACKS:
|
||||
raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")
|
||||
if horizon not in _VALID_HORIZONS:
|
||||
raise HTTPException(400, f"Invalid horizon: {horizon}. Must be one of {sorted(_VALID_HORIZONS)}")
|
||||
|
||||
# Build lookback filter
|
||||
lookback_condition = ""
|
||||
params: list[Any] = [horizon]
|
||||
idx = 2
|
||||
|
||||
if lookback != "all":
|
||||
lookback_days = {"7d": 7, "30d": 30, "90d": 90}[lookback]
|
||||
lookback_condition = f"AND generated_at >= NOW() - make_interval(days => ${idx})"
|
||||
params.append(lookback_days)
|
||||
idx += 1
|
||||
|
||||
rows = await pool.fetch(
|
||||
f"""SELECT confidence, direction_correct
|
||||
FROM v_prediction_performance
|
||||
WHERE horizon = $1
|
||||
{lookback_condition}
|
||||
AND confidence IS NOT NULL""",
|
||||
*params,
|
||||
)
|
||||
|
||||
# Group into calibration buckets
|
||||
buckets_def = [
|
||||
(0.50, 0.60),
|
||||
(0.60, 0.70),
|
||||
(0.70, 0.80),
|
||||
(0.80, 0.90),
|
||||
(0.90, 1.00),
|
||||
]
|
||||
|
||||
buckets = []
|
||||
for low, high in buckets_def:
|
||||
bucket_rows = []
|
||||
for r in rows:
|
||||
conf = float(r["confidence"])
|
||||
if high == 1.00:
|
||||
in_bucket = low <= conf <= high
|
||||
else:
|
||||
in_bucket = low <= conf < high
|
||||
if in_bucket:
|
||||
bucket_rows.append(r)
|
||||
|
||||
count = len(bucket_rows)
|
||||
if count == 0:
|
||||
buckets.append({
|
||||
"bucket_low": low,
|
||||
"bucket_high": high,
|
||||
"avg_confidence": 0.0,
|
||||
"observed_win_rate": 0.0,
|
||||
"prediction_count": 0,
|
||||
"miscalibrated": False,
|
||||
})
|
||||
continue
|
||||
|
||||
avg_conf = sum(float(r["confidence"]) for r in bucket_rows) / count
|
||||
win_count = sum(1 for r in bucket_rows if r["direction_correct"] is True)
|
||||
win_rate = win_count / count
|
||||
diff = abs(avg_conf - win_rate)
|
||||
|
||||
buckets.append({
|
||||
"bucket_low": low,
|
||||
"bucket_high": high,
|
||||
"avg_confidence": round(avg_conf, 4),
|
||||
"observed_win_rate": round(win_rate, 4),
|
||||
"prediction_count": count,
|
||||
"miscalibrated": diff > 0.15,
|
||||
})
|
||||
|
||||
return {"buckets": buckets, "lookback": lookback, "horizon": horizon}
|
||||
|
||||
|
||||
@app.get("/api/validation/ic-by-horizon")
|
||||
async def get_validation_ic_by_horizon(
|
||||
lookback: str = Query(default="30d"),
|
||||
):
|
||||
"""IC and Rank IC per prediction horizon.
|
||||
|
||||
Queries the most recent model_metric_snapshot for the given lookback
|
||||
across all 5 horizons, returning IC and Rank IC for each.
|
||||
|
||||
Requirement 12.3
|
||||
"""
|
||||
if lookback not in _VALID_LOOKBACKS:
|
||||
raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")
|
||||
|
||||
rows = await pool.fetch(
|
||||
"""SELECT DISTINCT ON (horizon)
|
||||
horizon,
|
||||
information_coefficient,
|
||||
rank_information_coefficient,
|
||||
prediction_count,
|
||||
generated_at
|
||||
FROM model_metric_snapshots
|
||||
WHERE lookback_window = $1
|
||||
ORDER BY horizon, generated_at DESC""",
|
||||
lookback,
|
||||
)
|
||||
|
||||
horizons = []
|
||||
for r in rows:
|
||||
horizons.append({
|
||||
"horizon": r["horizon"],
|
||||
"information_coefficient": float(r["information_coefficient"]) if r["information_coefficient"] is not None else None,
|
||||
"rank_information_coefficient": float(r["rank_information_coefficient"]) if r["rank_information_coefficient"] is not None else None,
|
||||
"prediction_count": r["prediction_count"],
|
||||
"generated_at": r["generated_at"].isoformat() if r["generated_at"] else None,
|
||||
})
|
||||
|
||||
# Sort by canonical horizon order
|
||||
horizon_order = {"1h": 0, "6h": 1, "1d": 2, "7d": 3, "30d": 4}
|
||||
horizons.sort(key=lambda h: horizon_order.get(h["horizon"], 99))
|
||||
|
||||
return {"horizons": horizons, "lookback": lookback}
|
||||
|
||||
|
||||
@app.get("/api/validation/gate-status")
|
||||
async def get_validation_gate_status():
|
||||
"""Quality gate evaluation detail.
|
||||
|
||||
Returns the stored gate evaluation result from risk_configs
|
||||
where key = 'model_quality_gate'.
|
||||
|
||||
Requirement 12.7
|
||||
"""
|
||||
gate_row = await pool.fetchrow(
|
||||
"SELECT config, updated_at FROM risk_configs WHERE name = 'model_quality_gate'",
|
||||
)
|
||||
|
||||
if not gate_row:
|
||||
return {
|
||||
"gate_status": None,
|
||||
"message": "No gate evaluation found. Model metrics may not have been computed yet.",
|
||||
}
|
||||
|
||||
gate_data = _parse_jsonb(gate_row["config"])
|
||||
updated_at = gate_row["updated_at"].isoformat() if gate_row.get("updated_at") else None
|
||||
|
||||
return {
|
||||
"gate_status": gate_data,
|
||||
"updated_at": updated_at,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Attribution Endpoints (Requirements 12.4, 12.5, 12.6)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_LOOKBACK_TO_DAYS: dict[str, int] = {
|
||||
"7d": 7,
|
||||
"30d": 30,
|
||||
"90d": 90,
|
||||
"all": 3650,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/validation/attribution/sources")
|
||||
async def get_validation_attribution_sources(
|
||||
lookback: str = Query(default="30d"),
|
||||
horizon: str = Query(default="7d"),
|
||||
):
|
||||
"""Per-source performance metrics.
|
||||
|
||||
Returns win rate, IC, average return, duplicate rate, and other
|
||||
attribution metrics for each source, computed over the given
|
||||
lookback window and prediction horizon.
|
||||
|
||||
Requirement 12.4
|
||||
"""
|
||||
if lookback not in _VALID_LOOKBACKS:
|
||||
raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")
|
||||
if horizon not in _VALID_HORIZONS:
|
||||
raise HTTPException(400, f"Invalid horizon: {horizon}. Must be one of {sorted(_VALID_HORIZONS)}")
|
||||
|
||||
lookback_days = _LOOKBACK_TO_DAYS[lookback]
|
||||
|
||||
try:
|
||||
results = await compute_source_attribution(pool, lookback_days=lookback_days, horizon=horizon)
|
||||
except Exception:
|
||||
logger.exception("Failed to compute source attribution")
|
||||
raise HTTPException(500, "Failed to compute source attribution")
|
||||
|
||||
return {
|
||||
"sources": [asdict(r) for r in results],
|
||||
"lookback": lookback,
|
||||
"horizon": horizon,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/validation/attribution/catalysts")
|
||||
async def get_validation_attribution_catalysts(
|
||||
lookback: str = Query(default="30d"),
|
||||
horizon: str = Query(default="7d"),
|
||||
):
|
||||
"""Per-catalyst-type performance metrics.
|
||||
|
||||
Returns win rate, IC, average return, and other attribution metrics
|
||||
for each catalyst type, computed over the given lookback window
|
||||
and prediction horizon.
|
||||
|
||||
Requirement 12.5
|
||||
"""
|
||||
if lookback not in _VALID_LOOKBACKS:
|
||||
raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")
|
||||
if horizon not in _VALID_HORIZONS:
|
||||
raise HTTPException(400, f"Invalid horizon: {horizon}. Must be one of {sorted(_VALID_HORIZONS)}")
|
||||
|
||||
lookback_days = _LOOKBACK_TO_DAYS[lookback]
|
||||
|
||||
try:
|
||||
results = await compute_catalyst_attribution(pool, lookback_days=lookback_days, horizon=horizon)
|
||||
except Exception:
|
||||
logger.exception("Failed to compute catalyst attribution")
|
||||
raise HTTPException(500, "Failed to compute catalyst attribution")
|
||||
|
||||
return {
|
||||
"catalysts": [asdict(r) for r in results],
|
||||
"lookback": lookback,
|
||||
"horizon": horizon,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/validation/attribution/layers")
|
||||
async def get_validation_attribution_layers(
|
||||
lookback: str = Query(default="30d"),
|
||||
horizon: str = Query(default="7d"),
|
||||
):
|
||||
"""Per-signal-layer (company, macro, competitive) performance metrics.
|
||||
|
||||
Returns average contribution percentage, dominant win rate, and
|
||||
dominant IC for each of the three signal layers, computed over
|
||||
the given lookback window and prediction horizon.
|
||||
|
||||
Requirement 12.6
|
||||
"""
|
||||
if lookback not in _VALID_LOOKBACKS:
|
||||
raise HTTPException(400, f"Invalid lookback: {lookback}. Must be one of {sorted(_VALID_LOOKBACKS)}")
|
||||
if horizon not in _VALID_HORIZONS:
|
||||
raise HTTPException(400, f"Invalid horizon: {horizon}. Must be one of {sorted(_VALID_HORIZONS)}")
|
||||
|
||||
lookback_days = _LOOKBACK_TO_DAYS[lookback]
|
||||
|
||||
try:
|
||||
results = await compute_layer_attribution(pool, lookback_days=lookback_days, horizon=horizon)
|
||||
except Exception:
|
||||
logger.exception("Failed to compute layer attribution")
|
||||
raise HTTPException(500, "Failed to compute layer attribution")
|
||||
|
||||
return {
|
||||
"layers": [asdict(r) for r in results],
|
||||
"lookback": lookback,
|
||||
"horizon": horizon,
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ from services.shared.schemas import (
|
||||
TrendSummary,
|
||||
TrendWindow,
|
||||
)
|
||||
from services.validation.prediction_snapshot import create_prediction_snapshot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -741,6 +742,92 @@ def _map_time_horizon_prefix(window: str) -> str:
|
||||
return mapping.get(window, "window_")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch evidence signals and docs for prediction snapshot (Requirement 1.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EVIDENCE_SIGNALS_QUERY = """
|
||||
SELECT
|
||||
dir.document_id::text AS document_id,
|
||||
di.id::text AS signal_id,
|
||||
dir.ticker,
|
||||
d.source_type AS source,
|
||||
d.source_type,
|
||||
dir.catalyst_type,
|
||||
dir.sentiment,
|
||||
dir.impact_score AS impact,
|
||||
di.confidence AS extraction_confidence,
|
||||
di.source_credibility AS weight
|
||||
FROM document_impact_records dir
|
||||
JOIN document_intelligence di ON di.id = dir.intelligence_id
|
||||
JOIN documents d ON d.id = di.document_id
|
||||
WHERE dir.document_id = ANY($1::uuid[])
|
||||
AND di.validation_status = 'valid'
|
||||
"""
|
||||
|
||||
_EVIDENCE_DOCS_QUERY = """
|
||||
SELECT
|
||||
d.id::text AS document_id,
|
||||
COALESCE(d.title, '') AS title,
|
||||
COALESCE(d.url, '') AS url
|
||||
FROM documents d
|
||||
WHERE d.id = ANY($1::uuid[])
|
||||
"""
|
||||
|
||||
|
||||
async def _fetch_evidence_for_snapshot(
|
||||
pool: asyncpg.Pool,
|
||||
document_ids: list[str],
|
||||
) -> tuple[list[dict], list[dict]]:
|
||||
"""Fetch evidence signals and document metadata for prediction snapshot.
|
||||
|
||||
Filters out non-UUID document IDs (e.g. synthetic pattern IDs) since
|
||||
they cannot be looked up in the documents table.
|
||||
|
||||
Returns (evidence_signals, evidence_docs).
|
||||
"""
|
||||
# Filter to valid UUIDs only
|
||||
valid_ids: list[str] = []
|
||||
for doc_id in document_ids:
|
||||
try:
|
||||
_uuid.UUID(doc_id)
|
||||
valid_ids.append(doc_id)
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
|
||||
if not valid_ids:
|
||||
return [], []
|
||||
|
||||
signal_rows = await pool.fetch(_EVIDENCE_SIGNALS_QUERY, valid_ids)
|
||||
evidence_signals = [
|
||||
{
|
||||
"document_id": row["document_id"],
|
||||
"signal_id": row["signal_id"],
|
||||
"ticker": row["ticker"] or "",
|
||||
"source": row["source"] or "",
|
||||
"source_type": row["source_type"] or "",
|
||||
"catalyst_type": row["catalyst_type"] or "",
|
||||
"sentiment": row["sentiment"] or "",
|
||||
"impact": float(row["impact"] or 0.0),
|
||||
"extraction_confidence": float(row["extraction_confidence"] or 0.0),
|
||||
"weight": float(row["weight"] or 0.0),
|
||||
}
|
||||
for row in signal_rows
|
||||
]
|
||||
|
||||
doc_rows = await pool.fetch(_EVIDENCE_DOCS_QUERY, valid_ids)
|
||||
evidence_docs = [
|
||||
{
|
||||
"document_id": row["document_id"],
|
||||
"title": row["title"],
|
||||
"url": row["url"],
|
||||
}
|
||||
for row in doc_rows
|
||||
]
|
||||
|
||||
return evidence_signals, evidence_docs
|
||||
|
||||
|
||||
async def generate_recommendation(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
@@ -847,6 +934,22 @@ async def generate_recommendation(
|
||||
eligibility_result=result,
|
||||
)
|
||||
|
||||
# 7b. Capture prediction snapshot for model validation (Requirements 1.1, 1.6)
|
||||
try:
|
||||
all_doc_ids = list(summary.top_supporting_evidence) + list(summary.top_opposing_evidence)
|
||||
evidence_signals, evidence_docs = await _fetch_evidence_for_snapshot(
|
||||
pool, all_doc_ids,
|
||||
)
|
||||
await create_prediction_snapshot(
|
||||
pool, rec, summary, evidence_signals, evidence_docs,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to create prediction snapshot for %s/%s — recommendation "
|
||||
"persisted but snapshot creation failed",
|
||||
ticker, rec_id, exc_info=True,
|
||||
)
|
||||
|
||||
# 8. Publish prediction facts to analytical tables (Requirement 9.4)
|
||||
if minio_client is not None:
|
||||
try:
|
||||
|
||||
@@ -4,6 +4,10 @@ Task 32: Fetches historical recommendations from the database, simulates
|
||||
the decision logic chronologically using evaluate_recommendation(), tracks
|
||||
simulated positions and equity curve, and persists results to backtest_runs
|
||||
and backtest_trades tables.
|
||||
|
||||
Supports a validation mode (Requirements 15.1–15.5) that generates prediction
|
||||
snapshots and evaluates outcomes using only data available at each historical
|
||||
point in time, preventing future data leakage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -39,12 +43,22 @@ class BacktestReplay:
|
||||
self.pool = pool
|
||||
self._perf = PerformanceComputer()
|
||||
|
||||
async def run(self, config: BacktestConfig, backtest_id: str | None = None) -> BacktestResult:
|
||||
async def run(
|
||||
self,
|
||||
config: BacktestConfig,
|
||||
backtest_id: str | None = None,
|
||||
validation_mode: bool = False,
|
||||
) -> BacktestResult:
|
||||
"""Execute a full backtest replay.
|
||||
|
||||
Args:
|
||||
config: Backtest configuration (date range, capital, risk tier).
|
||||
backtest_id: Optional pre-generated ID. If not provided, one is generated.
|
||||
validation_mode: When True, creates prediction snapshots for each
|
||||
historical recommendation using only data available at that point
|
||||
in time, evaluates outcomes, and computes model metrics over the
|
||||
backtest period. Snapshots are tagged with the backtest_id.
|
||||
(Requirements 15.1–15.5)
|
||||
|
||||
Returns:
|
||||
BacktestResult with metrics, trade log, and equity curve.
|
||||
@@ -87,6 +101,7 @@ class BacktestReplay:
|
||||
daily_returns: list[float] = []
|
||||
prev_value = config.initial_capital
|
||||
trade_log: list[dict] = []
|
||||
validation_snapshot_ids: list[str] = [] # track snapshot IDs for validation mode
|
||||
|
||||
# Pre-load company sectors and latest prices for enrichment
|
||||
company_sectors: dict[str, str] = {}
|
||||
@@ -172,6 +187,25 @@ class BacktestReplay:
|
||||
now=sim_time,
|
||||
)
|
||||
|
||||
# --- Validation mode: create prediction snapshot (Req 15.1, 15.2, 15.4) ---
|
||||
if validation_mode and self.pool is not None:
|
||||
try:
|
||||
snapshot_id = await self._create_validation_snapshot(
|
||||
rec=rec,
|
||||
sim_time=sim_time,
|
||||
backtest_id=backtest_id,
|
||||
company_sectors=company_sectors,
|
||||
)
|
||||
if snapshot_id is not None:
|
||||
validation_snapshot_ids.append(snapshot_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Validation snapshot failed for %s at %s, continuing backtest",
|
||||
rec.get("ticker", "?"),
|
||||
sim_time,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if decision.decision == "act":
|
||||
act_count += 1
|
||||
ticker = decision.ticker
|
||||
@@ -348,6 +382,10 @@ class BacktestReplay:
|
||||
# Persist results
|
||||
await self._persist_results(result, closed_trades)
|
||||
|
||||
# --- Validation mode: evaluate outcomes and compute metrics (Req 15.3, 15.5) ---
|
||||
if validation_mode and self.pool is not None and validation_snapshot_ids:
|
||||
await self._run_validation_evaluation(backtest_id)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as exc:
|
||||
@@ -356,6 +394,210 @@ class BacktestReplay:
|
||||
await self._persist_failed_run(backtest_id, config, str(exc))
|
||||
raise
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validation mode helpers (Requirements 15.1–15.5)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# SQL to fetch the close price at or before a specific time — prevents
|
||||
# future data leakage by only returning data available at that point.
|
||||
_CLOSE_AT_TIME_SQL = """
|
||||
SELECT (data->>'c')::float AS close
|
||||
FROM market_snapshots
|
||||
WHERE ticker = $1
|
||||
AND snapshot_type = 'bar'
|
||||
AND data->>'c' IS NOT NULL
|
||||
AND captured_at <= $2
|
||||
ORDER BY captured_at DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
_COMPANY_SECTOR_SQL = """
|
||||
SELECT sector FROM companies WHERE ticker = $1 AND active = TRUE LIMIT 1
|
||||
"""
|
||||
|
||||
_SECTOR_ETF_MAP: dict[str, str] = {
|
||||
"Technology": "XLK",
|
||||
"Consumer Cyclical": "XLY",
|
||||
"Financial Services": "XLF",
|
||||
"Healthcare": "XLV",
|
||||
"Energy": "XLE",
|
||||
"Communication Services": "XLC",
|
||||
"Industrials": "XLI",
|
||||
"Consumer Defensive": "XLP",
|
||||
"Real Estate": "XLRE",
|
||||
"Utilities": "XLU",
|
||||
}
|
||||
|
||||
async def _fetch_close_at_time(
|
||||
self,
|
||||
ticker: str,
|
||||
target_time: datetime,
|
||||
) -> float | None:
|
||||
"""Fetch the close price for *ticker* at or before *target_time*.
|
||||
|
||||
Ensures no future data leakage — only market data with
|
||||
``captured_at <= target_time`` is considered (Requirement 15.4).
|
||||
"""
|
||||
if self.pool is None:
|
||||
return None
|
||||
row = await self.pool.fetchrow(self._CLOSE_AT_TIME_SQL, ticker, target_time)
|
||||
if row is None:
|
||||
return None
|
||||
return row["close"]
|
||||
|
||||
async def _create_validation_snapshot(
|
||||
self,
|
||||
rec: dict,
|
||||
sim_time: datetime,
|
||||
backtest_id: str,
|
||||
company_sectors: dict[str, str],
|
||||
) -> str | None:
|
||||
"""Create a prediction snapshot using only data available at *sim_time*.
|
||||
|
||||
Fetches ticker, SPY, and sector ETF prices as of *sim_time* to prevent
|
||||
future data leakage (Requirements 15.1, 15.2, 15.4). The snapshot is
|
||||
tagged with *backtest_id* in its metadata field (Requirement 15.5).
|
||||
|
||||
Returns the snapshot UUID on success, or ``None`` on failure.
|
||||
"""
|
||||
from services.validation.prediction_snapshot import (
|
||||
SECTOR_ETF_MAP,
|
||||
)
|
||||
|
||||
ticker = rec.get("ticker", "")
|
||||
if not ticker:
|
||||
return None
|
||||
|
||||
# Fetch prices using only data available at sim_time (Req 15.4)
|
||||
ticker_price = await self._fetch_close_at_time(ticker, sim_time)
|
||||
spy_price = await self._fetch_close_at_time("SPY", sim_time)
|
||||
|
||||
# Sector ETF price
|
||||
sector = company_sectors.get(ticker)
|
||||
sector_etf_ticker = SECTOR_ETF_MAP.get(sector) if sector else None
|
||||
sector_etf_price: float | None = None
|
||||
if sector_etf_ticker is not None:
|
||||
sector_etf_price = await self._fetch_close_at_time(
|
||||
sector_etf_ticker, sim_time
|
||||
)
|
||||
|
||||
snapshot_id = str(uuid.uuid4())
|
||||
|
||||
# Build metadata tagged with backtest_id (Req 15.5)
|
||||
metadata: dict = {
|
||||
"backtest_id": backtest_id,
|
||||
"source": "backtest_validation",
|
||||
}
|
||||
|
||||
# Map recommendation fields to snapshot columns
|
||||
direction = rec.get("direction", rec.get("trend_direction", "neutral"))
|
||||
action = rec.get("action", "watch")
|
||||
mode = rec.get("mode", "informational")
|
||||
confidence = float(rec.get("confidence", 0.5))
|
||||
strength = float(rec.get("strength", rec.get("trend_strength", 0.5)))
|
||||
contradiction = float(rec.get("contradiction", rec.get("contradiction_score", 0.0)))
|
||||
p_bull = rec.get("p_bull")
|
||||
if p_bull is not None:
|
||||
p_bull = float(p_bull)
|
||||
p_bear = (1.0 - p_bull) if p_bull is not None else None
|
||||
window = rec.get("window", rec.get("trend_window", "7d"))
|
||||
horizon = rec.get("time_horizon", rec.get("horizon", "7d"))
|
||||
|
||||
# Insert the snapshot directly — we bypass create_prediction_snapshot()
|
||||
# because that function fetches *latest* prices (not point-in-time).
|
||||
insert_sql = """
|
||||
INSERT INTO prediction_snapshots (
|
||||
id, generated_at, ticker, window, horizon, direction, action, mode,
|
||||
strength, confidence, contradiction, p_bull, p_bear,
|
||||
score_company, score_macro, score_competitive,
|
||||
evidence_count, unique_source_count, duplicate_evidence_count,
|
||||
price_at_prediction, spy_price_at_prediction,
|
||||
sector_etf_price_at_prediction, metadata
|
||||
) VALUES (
|
||||
$1::uuid, $2, $3, $4, $5, $6, $7, $8,
|
||||
$9, $10, $11, $12, $13,
|
||||
$14, $15, $16,
|
||||
$17, $18, $19,
|
||||
$20, $21, $22,
|
||||
$23::jsonb
|
||||
)
|
||||
"""
|
||||
await self.pool.execute(
|
||||
insert_sql,
|
||||
snapshot_id,
|
||||
sim_time,
|
||||
ticker,
|
||||
str(window),
|
||||
str(horizon),
|
||||
str(direction),
|
||||
str(action),
|
||||
str(mode),
|
||||
strength,
|
||||
confidence,
|
||||
contradiction,
|
||||
p_bull,
|
||||
p_bear,
|
||||
float(rec.get("score_company", 0.0)),
|
||||
float(rec.get("score_macro", 0.0)),
|
||||
float(rec.get("score_competitive", 0.0)),
|
||||
int(rec.get("evidence_count", 0)),
|
||||
int(rec.get("unique_source_count", 0)),
|
||||
int(rec.get("duplicate_evidence_count", 0)),
|
||||
ticker_price,
|
||||
spy_price,
|
||||
sector_etf_price,
|
||||
json.dumps(metadata),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Validation snapshot %s created for %s at %s (backtest %s)",
|
||||
snapshot_id,
|
||||
ticker,
|
||||
sim_time,
|
||||
backtest_id,
|
||||
)
|
||||
return snapshot_id
|
||||
|
||||
async def _run_validation_evaluation(self, backtest_id: str) -> None:
|
||||
"""Evaluate prediction outcomes and compute metrics for the backtest.
|
||||
|
||||
Calls the outcome evaluator and metrics engine after the backtest
|
||||
completes (Requirements 15.3, 15.5). Failures are logged but do
|
||||
not block the backtest result.
|
||||
"""
|
||||
from services.validation.metrics import compute_and_store_metric_snapshots
|
||||
from services.validation.outcome_evaluator import evaluate_matured_predictions
|
||||
|
||||
# Step 1: Evaluate matured predictions (Req 15.3)
|
||||
try:
|
||||
outcomes_count = await evaluate_matured_predictions(self.pool)
|
||||
logger.info(
|
||||
"Backtest %s validation: %d prediction outcomes evaluated",
|
||||
backtest_id,
|
||||
outcomes_count,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Backtest %s: outcome evaluation failed, continuing",
|
||||
backtest_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Step 2: Compute and store metric snapshots (Req 15.5)
|
||||
try:
|
||||
snapshots = await compute_and_store_metric_snapshots(self.pool)
|
||||
logger.info(
|
||||
"Backtest %s validation: %d metric snapshots computed",
|
||||
backtest_id,
|
||||
len(snapshots),
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Backtest %s: metric snapshot computation failed, continuing",
|
||||
backtest_id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Database helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,329 @@
|
||||
"""Quality gate for live trading eligibility.
|
||||
|
||||
Evaluates aggregate model metrics against configurable thresholds and
|
||||
determines whether the system meets minimum quality standards for live
|
||||
trading. When any threshold is not met, the gate forces all
|
||||
recommendations to paper mode (fail-safe).
|
||||
|
||||
Requirements: 11.1, 11.2, 11.3, 11.4, 11.5, 11.6, 11.7
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
|
||||
logger = logging.getLogger("trading_engine.quality_gate")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class QualityGateConfig:
|
||||
"""Configurable thresholds for live trading eligibility."""
|
||||
|
||||
min_prediction_count: int = 100
|
||||
min_ic: float = 0.03
|
||||
min_win_rate: float = 0.53
|
||||
max_ece: float = 0.15
|
||||
min_excess_return_vs_spy: float = 0.0
|
||||
max_snapshot_age_hours: int = 24
|
||||
|
||||
|
||||
@dataclass
|
||||
class GateThresholdResult:
|
||||
"""Result for a single threshold check."""
|
||||
|
||||
name: str
|
||||
threshold: float
|
||||
actual: float
|
||||
passed: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class QualityGateResult:
|
||||
"""Full gate evaluation result."""
|
||||
|
||||
passed: bool
|
||||
evaluated_at: datetime
|
||||
threshold_results: list[GateThresholdResult] = field(default_factory=list)
|
||||
reason: str = ""
|
||||
snapshot_id: str | None = None
|
||||
config: QualityGateConfig = field(default_factory=QualityGateConfig)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Threshold evaluation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _evaluate_thresholds(
|
||||
snapshot: dict,
|
||||
config: QualityGateConfig,
|
||||
) -> list[GateThresholdResult]:
|
||||
"""Evaluate each threshold against snapshot metric values."""
|
||||
results: list[GateThresholdResult] = []
|
||||
|
||||
# min_prediction_count
|
||||
actual_count = snapshot.get("prediction_count") or 0
|
||||
results.append(
|
||||
GateThresholdResult(
|
||||
name="min_prediction_count",
|
||||
threshold=float(config.min_prediction_count),
|
||||
actual=float(actual_count),
|
||||
passed=actual_count >= config.min_prediction_count,
|
||||
)
|
||||
)
|
||||
|
||||
# min_ic
|
||||
actual_ic = snapshot.get("information_coefficient")
|
||||
if actual_ic is None:
|
||||
actual_ic = 0.0
|
||||
results.append(
|
||||
GateThresholdResult(
|
||||
name="min_ic",
|
||||
threshold=config.min_ic,
|
||||
actual=float(actual_ic),
|
||||
passed=float(actual_ic) >= config.min_ic,
|
||||
)
|
||||
)
|
||||
|
||||
# min_win_rate
|
||||
actual_wr = snapshot.get("win_rate")
|
||||
if actual_wr is None:
|
||||
actual_wr = 0.0
|
||||
results.append(
|
||||
GateThresholdResult(
|
||||
name="min_win_rate",
|
||||
threshold=config.min_win_rate,
|
||||
actual=float(actual_wr),
|
||||
passed=float(actual_wr) >= config.min_win_rate,
|
||||
)
|
||||
)
|
||||
|
||||
# max_ece (calibration_error)
|
||||
actual_ece = snapshot.get("calibration_error")
|
||||
if actual_ece is None:
|
||||
actual_ece = 1.0 # worst-case when missing
|
||||
results.append(
|
||||
GateThresholdResult(
|
||||
name="max_ece",
|
||||
threshold=config.max_ece,
|
||||
actual=float(actual_ece),
|
||||
passed=float(actual_ece) <= config.max_ece,
|
||||
)
|
||||
)
|
||||
|
||||
# min_excess_return_vs_spy
|
||||
actual_excess = snapshot.get("avg_excess_return_vs_spy")
|
||||
if actual_excess is None:
|
||||
actual_excess = 0.0
|
||||
results.append(
|
||||
GateThresholdResult(
|
||||
name="min_excess_return_vs_spy",
|
||||
threshold=config.min_excess_return_vs_spy,
|
||||
actual=float(actual_excess),
|
||||
passed=float(actual_excess) >= config.min_excess_return_vs_spy,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def evaluate_quality_gate(
|
||||
pool: asyncpg.Pool,
|
||||
config: QualityGateConfig | None = None,
|
||||
) -> QualityGateResult:
|
||||
"""Evaluate model quality gate from latest metric snapshot.
|
||||
|
||||
Reads the most recent ``model_metric_snapshot`` for the 30d lookback
|
||||
and 7d horizon (the primary evaluation window).
|
||||
|
||||
If no snapshot exists or snapshot is stale (>max_snapshot_age_hours),
|
||||
defaults to paper-only mode (fail-safe).
|
||||
|
||||
Stores result in ``risk_configs`` under ``'model_quality_gate'`` key.
|
||||
"""
|
||||
if config is None:
|
||||
config = await load_gate_config_from_db(pool)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
|
||||
# Fetch the most recent metric snapshot for 30d lookback / 7d horizon
|
||||
try:
|
||||
row = await pool.fetchrow(
|
||||
"""SELECT id, generated_at, prediction_count, win_rate,
|
||||
directional_accuracy, information_coefficient,
|
||||
rank_information_coefficient, avg_return,
|
||||
avg_excess_return_vs_spy, avg_excess_return_vs_sector,
|
||||
calibration_error, brier_score,
|
||||
buy_win_rate, sell_win_rate, hold_win_rate
|
||||
FROM model_metric_snapshots
|
||||
WHERE lookback_window = '30d' AND horizon = '7d'
|
||||
ORDER BY generated_at DESC
|
||||
LIMIT 1""",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to query model_metric_snapshots")
|
||||
row = None
|
||||
|
||||
# Fail-safe: no snapshot exists
|
||||
if row is None:
|
||||
result = QualityGateResult(
|
||||
passed=False,
|
||||
evaluated_at=now,
|
||||
threshold_results=[],
|
||||
reason="no model metric snapshot available — defaulting to paper-only",
|
||||
snapshot_id=None,
|
||||
config=config,
|
||||
)
|
||||
logger.warning("Quality gate: %s", result.reason)
|
||||
await _store_gate_result(pool, result)
|
||||
return result
|
||||
|
||||
snapshot = dict(row)
|
||||
snapshot_id = str(snapshot["id"])
|
||||
generated_at = snapshot["generated_at"]
|
||||
|
||||
# Fail-safe: stale snapshot
|
||||
age_hours = (now - generated_at).total_seconds() / 3600.0
|
||||
if age_hours > config.max_snapshot_age_hours:
|
||||
result = QualityGateResult(
|
||||
passed=False,
|
||||
evaluated_at=now,
|
||||
threshold_results=[],
|
||||
reason=(
|
||||
f"most recent snapshot is {age_hours:.1f}h old "
|
||||
f"(max {config.max_snapshot_age_hours}h) — defaulting to paper-only"
|
||||
),
|
||||
snapshot_id=snapshot_id,
|
||||
config=config,
|
||||
)
|
||||
logger.warning("Quality gate: %s", result.reason)
|
||||
await _store_gate_result(pool, result)
|
||||
return result
|
||||
|
||||
# Evaluate thresholds
|
||||
threshold_results = _evaluate_thresholds(snapshot, config)
|
||||
failed = [r for r in threshold_results if not r.passed]
|
||||
|
||||
if failed:
|
||||
failed_names = ", ".join(
|
||||
f"{r.name}(actual={r.actual:.4f}, threshold={r.threshold:.4f})"
|
||||
for r in failed
|
||||
)
|
||||
reason = f"failed: {failed_names}"
|
||||
passed = False
|
||||
else:
|
||||
reason = "all thresholds met"
|
||||
passed = True
|
||||
|
||||
result = QualityGateResult(
|
||||
passed=passed,
|
||||
evaluated_at=now,
|
||||
threshold_results=threshold_results,
|
||||
reason=reason,
|
||||
snapshot_id=snapshot_id,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Log details
|
||||
for tr in threshold_results:
|
||||
logger.info(
|
||||
"Quality gate threshold %s: actual=%.4f threshold=%.4f %s",
|
||||
tr.name,
|
||||
tr.actual,
|
||||
tr.threshold,
|
||||
"PASS" if tr.passed else "FAIL",
|
||||
)
|
||||
logger.info("Quality gate result: %s — %s", "PASS" if passed else "FAIL", reason)
|
||||
|
||||
await _store_gate_result(pool, result)
|
||||
return result
|
||||
|
||||
|
||||
async def load_gate_config_from_db(
|
||||
pool: asyncpg.Pool,
|
||||
) -> QualityGateConfig:
|
||||
"""Load gate thresholds from risk_configs, with defaults.
|
||||
|
||||
Looks for a ``risk_configs`` row with ``name = 'model_quality_gate_config'``.
|
||||
If found, merges stored thresholds over the defaults. If not found or
|
||||
the stored JSON is invalid, returns the default config.
|
||||
"""
|
||||
defaults = QualityGateConfig()
|
||||
try:
|
||||
row = await pool.fetchrow(
|
||||
"SELECT config FROM risk_configs WHERE name = 'model_quality_gate_config'",
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to load gate config from risk_configs — using defaults")
|
||||
return defaults
|
||||
|
||||
if row is None:
|
||||
return defaults
|
||||
|
||||
try:
|
||||
raw = row["config"]
|
||||
cfg = raw if isinstance(raw, dict) else json.loads(raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning("Invalid gate config JSON in risk_configs — using defaults")
|
||||
return defaults
|
||||
|
||||
return QualityGateConfig(
|
||||
min_prediction_count=int(cfg.get("min_prediction_count", defaults.min_prediction_count)),
|
||||
min_ic=float(cfg.get("min_ic", defaults.min_ic)),
|
||||
min_win_rate=float(cfg.get("min_win_rate", defaults.min_win_rate)),
|
||||
max_ece=float(cfg.get("max_ece", defaults.max_ece)),
|
||||
min_excess_return_vs_spy=float(
|
||||
cfg.get("min_excess_return_vs_spy", defaults.min_excess_return_vs_spy)
|
||||
),
|
||||
max_snapshot_age_hours=int(
|
||||
cfg.get("max_snapshot_age_hours", defaults.max_snapshot_age_hours)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _gate_result_to_json(result: QualityGateResult) -> str:
|
||||
"""Serialize a QualityGateResult to JSON for storage in risk_configs."""
|
||||
payload = {
|
||||
"passed": result.passed,
|
||||
"evaluated_at": result.evaluated_at.isoformat(),
|
||||
"reason": result.reason,
|
||||
"snapshot_id": result.snapshot_id,
|
||||
"config": asdict(result.config),
|
||||
"threshold_results": [asdict(tr) for tr in result.threshold_results],
|
||||
}
|
||||
return json.dumps(payload, default=str)
|
||||
|
||||
|
||||
async def _store_gate_result(pool: asyncpg.Pool, result: QualityGateResult) -> None:
|
||||
"""Upsert gate evaluation result into risk_configs."""
|
||||
payload = _gate_result_to_json(result)
|
||||
try:
|
||||
await pool.execute(
|
||||
"""INSERT INTO risk_configs (name, config, updated_at)
|
||||
VALUES ('model_quality_gate', $1::jsonb, NOW())
|
||||
ON CONFLICT (name) WHERE active = TRUE
|
||||
DO UPDATE SET config = $1::jsonb, updated_at = NOW()""",
|
||||
payload,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to store quality gate result in risk_configs")
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,591 @@
|
||||
"""Attribution Engine — per-source, per-catalyst, and per-layer performance.
|
||||
|
||||
Joins signal evidence links with prediction outcomes to compute attribution
|
||||
metrics that identify which sources, catalyst types, and signal layers
|
||||
contribute most to accurate predictions.
|
||||
|
||||
Requirements: 7.1, 7.2, 7.3, 7.4, 7.5, 7.6, 7.7
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import asyncpg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class SourceAttribution:
|
||||
"""Performance metrics for a single source."""
|
||||
|
||||
source: str
|
||||
source_type: str
|
||||
prediction_count: int
|
||||
avg_weight: float
|
||||
avg_contribution_score: float
|
||||
win_rate: float
|
||||
avg_future_return: float
|
||||
avg_excess_return_vs_spy: float
|
||||
information_coefficient: float | None
|
||||
duplicate_rate: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class CatalystAttribution:
|
||||
"""Performance metrics for a single catalyst type."""
|
||||
|
||||
catalyst_type: str
|
||||
prediction_count: int
|
||||
win_rate: float
|
||||
avg_future_return: float
|
||||
avg_excess_return_vs_spy: float
|
||||
information_coefficient: float | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerAttribution:
|
||||
"""Performance metrics for a signal layer."""
|
||||
|
||||
layer: str # company, macro, competitive
|
||||
avg_contribution_pct: float
|
||||
dominant_win_rate: float # win rate when this layer > 30% contribution
|
||||
dominant_ic: float | None # IC when this layer > 30% contribution
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure computation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _pearson_correlation(xs: list[float], ys: list[float]) -> float | None:
|
||||
"""Compute Pearson correlation coefficient between two lists.
|
||||
|
||||
Returns None if the lists have fewer than 2 elements or if either
|
||||
has zero variance. Guards against NaN/infinity.
|
||||
"""
|
||||
n = len(xs)
|
||||
if n < 2:
|
||||
return None
|
||||
|
||||
mean_x = sum(xs) / n
|
||||
mean_y = sum(ys) / n
|
||||
|
||||
cov = 0.0
|
||||
var_x = 0.0
|
||||
var_y = 0.0
|
||||
|
||||
for x, y in zip(xs, ys):
|
||||
dx = x - mean_x
|
||||
dy = y - mean_y
|
||||
cov += dx * dy
|
||||
var_x += dx * dx
|
||||
var_y += dy * dy
|
||||
|
||||
if var_x == 0.0 or var_y == 0.0:
|
||||
return None
|
||||
|
||||
r = cov / math.sqrt(var_x * var_y)
|
||||
|
||||
if math.isnan(r) or math.isinf(r):
|
||||
return None
|
||||
|
||||
return max(-1.0, min(1.0, r))
|
||||
|
||||
|
||||
def _compute_ic(
|
||||
contribution_scores: list[float],
|
||||
future_returns: list[float],
|
||||
) -> float | None:
|
||||
"""Compute IC (Pearson correlation) between contribution scores and returns.
|
||||
|
||||
Returns None when fewer than 30 data points.
|
||||
"""
|
||||
if len(contribution_scores) < 30 or len(future_returns) < 30:
|
||||
return None
|
||||
|
||||
n = min(len(contribution_scores), len(future_returns))
|
||||
return _pearson_correlation(contribution_scores[:n], future_returns[:n])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL queries — source attribution via v_source_performance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SOURCE_ATTRIBUTION_SQL = """
|
||||
SELECT
|
||||
source,
|
||||
source_type,
|
||||
weight,
|
||||
contribution_score,
|
||||
is_duplicate,
|
||||
direction_correct,
|
||||
future_return,
|
||||
excess_return_vs_spy
|
||||
FROM v_source_performance
|
||||
WHERE horizon = $1
|
||||
AND generated_at >= $2
|
||||
"""
|
||||
|
||||
_SOURCE_ATTRIBUTION_ALL_SQL = """
|
||||
SELECT
|
||||
source,
|
||||
source_type,
|
||||
weight,
|
||||
contribution_score,
|
||||
is_duplicate,
|
||||
direction_correct,
|
||||
future_return,
|
||||
excess_return_vs_spy
|
||||
FROM v_source_performance
|
||||
WHERE horizon = $1
|
||||
"""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL queries — catalyst attribution via v_source_performance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CATALYST_ATTRIBUTION_SQL = """
|
||||
SELECT
|
||||
catalyst_type,
|
||||
weight,
|
||||
contribution_score,
|
||||
direction_correct,
|
||||
future_return,
|
||||
excess_return_vs_spy
|
||||
FROM v_source_performance
|
||||
WHERE horizon = $1
|
||||
AND generated_at >= $2
|
||||
"""
|
||||
|
||||
_CATALYST_ATTRIBUTION_ALL_SQL = """
|
||||
SELECT
|
||||
catalyst_type,
|
||||
weight,
|
||||
contribution_score,
|
||||
direction_correct,
|
||||
future_return,
|
||||
excess_return_vs_spy
|
||||
FROM v_source_performance
|
||||
WHERE horizon = $1
|
||||
"""
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL queries — layer attribution via prediction_snapshots + outcomes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_LAYER_ATTRIBUTION_SQL = """
|
||||
SELECT
|
||||
ps.score_company,
|
||||
ps.score_macro,
|
||||
ps.score_competitive,
|
||||
po.direction_correct,
|
||||
po.future_return
|
||||
FROM prediction_snapshots ps
|
||||
JOIN prediction_outcomes po ON po.prediction_id = ps.id
|
||||
WHERE po.horizon = $1
|
||||
AND ps.generated_at >= $2
|
||||
"""
|
||||
|
||||
_LAYER_ATTRIBUTION_ALL_SQL = """
|
||||
SELECT
|
||||
ps.score_company,
|
||||
ps.score_macro,
|
||||
ps.score_competitive,
|
||||
po.direction_correct,
|
||||
po.future_return
|
||||
FROM prediction_snapshots ps
|
||||
JOIN prediction_outcomes po ON po.prediction_id = ps.id
|
||||
WHERE po.horizon = $1
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Source attribution (Requirements 7.1, 7.2, 7.7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def compute_source_attribution(
|
||||
pool: asyncpg.Pool,
|
||||
lookback_days: int = 30,
|
||||
horizon: str = "7d",
|
||||
) -> list[SourceAttribution]:
|
||||
"""Compute per-source performance metrics.
|
||||
|
||||
Queries v_source_performance, groups by source, and computes:
|
||||
prediction count, avg weight, avg contribution score, win rate,
|
||||
avg future return, avg excess return vs SPY, IC, and duplicate rate.
|
||||
|
||||
Returns a list of SourceAttribution sorted by prediction count descending.
|
||||
"""
|
||||
now = datetime.now().astimezone()
|
||||
cutoff = now - timedelta(days=lookback_days)
|
||||
|
||||
try:
|
||||
rows = await pool.fetch(_SOURCE_ATTRIBUTION_SQL, horizon, cutoff)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to query source attribution for horizon=%s lookback=%dd",
|
||||
horizon,
|
||||
lookback_days,
|
||||
)
|
||||
return []
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Group rows by source
|
||||
source_groups: dict[str, list[dict]] = {}
|
||||
for row in rows:
|
||||
r = dict(row)
|
||||
key = r.get("source") or "unknown"
|
||||
source_groups.setdefault(key, []).append(r)
|
||||
|
||||
results: list[SourceAttribution] = []
|
||||
|
||||
for source, group in source_groups.items():
|
||||
count = len(group)
|
||||
|
||||
# Source type — take the most common one
|
||||
source_type = group[0].get("source_type") or "unknown"
|
||||
|
||||
# Avg weight
|
||||
weights = [r["weight"] for r in group if r.get("weight") is not None]
|
||||
avg_weight = sum(weights) / len(weights) if weights else 0.0
|
||||
|
||||
# Avg contribution score
|
||||
contrib_scores = [
|
||||
r["contribution_score"]
|
||||
for r in group
|
||||
if r.get("contribution_score") is not None
|
||||
]
|
||||
avg_contribution_score = (
|
||||
sum(contrib_scores) / len(contrib_scores) if contrib_scores else 0.0
|
||||
)
|
||||
|
||||
# Win rate
|
||||
direction_rows = [r for r in group if r.get("direction_correct") is not None]
|
||||
win_count = sum(1 for r in direction_rows if r["direction_correct"] is True)
|
||||
win_rate = win_count / len(direction_rows) if direction_rows else 0.0
|
||||
|
||||
# Avg future return
|
||||
returns = [
|
||||
r["future_return"] for r in group if r.get("future_return") is not None
|
||||
]
|
||||
avg_future_return = sum(returns) / len(returns) if returns else 0.0
|
||||
|
||||
# Avg excess return vs SPY
|
||||
excess_returns = [
|
||||
r["excess_return_vs_spy"]
|
||||
for r in group
|
||||
if r.get("excess_return_vs_spy") is not None
|
||||
]
|
||||
avg_excess_return_vs_spy = (
|
||||
sum(excess_returns) / len(excess_returns) if excess_returns else 0.0
|
||||
)
|
||||
|
||||
# IC: correlation between contribution scores and future returns
|
||||
ic_scores = [
|
||||
r["contribution_score"]
|
||||
for r in group
|
||||
if r.get("contribution_score") is not None
|
||||
and r.get("future_return") is not None
|
||||
]
|
||||
ic_returns = [
|
||||
r["future_return"]
|
||||
for r in group
|
||||
if r.get("contribution_score") is not None
|
||||
and r.get("future_return") is not None
|
||||
]
|
||||
ic = _compute_ic(ic_scores, ic_returns)
|
||||
|
||||
# Duplicate rate: is_duplicate=true / total
|
||||
dup_count = sum(1 for r in group if r.get("is_duplicate") is True)
|
||||
duplicate_rate = dup_count / count
|
||||
|
||||
results.append(
|
||||
SourceAttribution(
|
||||
source=source,
|
||||
source_type=source_type,
|
||||
prediction_count=count,
|
||||
avg_weight=avg_weight,
|
||||
avg_contribution_score=avg_contribution_score,
|
||||
win_rate=win_rate,
|
||||
avg_future_return=avg_future_return,
|
||||
avg_excess_return_vs_spy=avg_excess_return_vs_spy,
|
||||
information_coefficient=ic,
|
||||
duplicate_rate=duplicate_rate,
|
||||
)
|
||||
)
|
||||
|
||||
# Sort by prediction count descending
|
||||
results.sort(key=lambda a: a.prediction_count, reverse=True)
|
||||
|
||||
logger.info(
|
||||
"Computed source attribution for %d sources (horizon=%s, lookback=%dd)",
|
||||
len(results),
|
||||
horizon,
|
||||
lookback_days,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Catalyst attribution (Requirements 7.3, 7.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def compute_catalyst_attribution(
|
||||
pool: asyncpg.Pool,
|
||||
lookback_days: int = 30,
|
||||
horizon: str = "7d",
|
||||
) -> list[CatalystAttribution]:
|
||||
"""Compute per-catalyst-type performance metrics.
|
||||
|
||||
Queries v_source_performance, groups by catalyst_type, and computes:
|
||||
prediction count, win rate, avg future return, avg excess return vs SPY,
|
||||
and IC.
|
||||
|
||||
Returns a list of CatalystAttribution sorted by prediction count descending.
|
||||
"""
|
||||
now = datetime.now().astimezone()
|
||||
cutoff = now - timedelta(days=lookback_days)
|
||||
|
||||
try:
|
||||
rows = await pool.fetch(_CATALYST_ATTRIBUTION_SQL, horizon, cutoff)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to query catalyst attribution for horizon=%s lookback=%dd",
|
||||
horizon,
|
||||
lookback_days,
|
||||
)
|
||||
return []
|
||||
|
||||
if not rows:
|
||||
return []
|
||||
|
||||
# Group rows by catalyst_type
|
||||
catalyst_groups: dict[str, list[dict]] = {}
|
||||
for row in rows:
|
||||
r = dict(row)
|
||||
key = r.get("catalyst_type") or "unknown"
|
||||
catalyst_groups.setdefault(key, []).append(r)
|
||||
|
||||
results: list[CatalystAttribution] = []
|
||||
|
||||
for catalyst_type, group in catalyst_groups.items():
|
||||
count = len(group)
|
||||
|
||||
# Win rate
|
||||
direction_rows = [r for r in group if r.get("direction_correct") is not None]
|
||||
win_count = sum(1 for r in direction_rows if r["direction_correct"] is True)
|
||||
win_rate = win_count / len(direction_rows) if direction_rows else 0.0
|
||||
|
||||
# Avg future return
|
||||
returns = [
|
||||
r["future_return"] for r in group if r.get("future_return") is not None
|
||||
]
|
||||
avg_future_return = sum(returns) / len(returns) if returns else 0.0
|
||||
|
||||
# Avg excess return vs SPY
|
||||
excess_returns = [
|
||||
r["excess_return_vs_spy"]
|
||||
for r in group
|
||||
if r.get("excess_return_vs_spy") is not None
|
||||
]
|
||||
avg_excess_return_vs_spy = (
|
||||
sum(excess_returns) / len(excess_returns) if excess_returns else 0.0
|
||||
)
|
||||
|
||||
# IC: correlation between contribution scores and future returns
|
||||
ic_scores = [
|
||||
r["contribution_score"]
|
||||
for r in group
|
||||
if r.get("contribution_score") is not None
|
||||
and r.get("future_return") is not None
|
||||
]
|
||||
ic_returns = [
|
||||
r["future_return"]
|
||||
for r in group
|
||||
if r.get("contribution_score") is not None
|
||||
and r.get("future_return") is not None
|
||||
]
|
||||
ic = _compute_ic(ic_scores, ic_returns)
|
||||
|
||||
results.append(
|
||||
CatalystAttribution(
|
||||
catalyst_type=catalyst_type,
|
||||
prediction_count=count,
|
||||
win_rate=win_rate,
|
||||
avg_future_return=avg_future_return,
|
||||
avg_excess_return_vs_spy=avg_excess_return_vs_spy,
|
||||
information_coefficient=ic,
|
||||
)
|
||||
)
|
||||
|
||||
# Sort by prediction count descending
|
||||
results.sort(key=lambda a: a.prediction_count, reverse=True)
|
||||
|
||||
logger.info(
|
||||
"Computed catalyst attribution for %d catalyst types "
|
||||
"(horizon=%s, lookback=%dd)",
|
||||
len(results),
|
||||
horizon,
|
||||
lookback_days,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer attribution (Requirements 7.5, 7.6)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def compute_layer_attribution(
|
||||
pool: asyncpg.Pool,
|
||||
lookback_days: int = 30,
|
||||
horizon: str = "7d",
|
||||
) -> list[LayerAttribution]:
|
||||
"""Compute per-layer (company, macro, competitive) performance metrics.
|
||||
|
||||
Queries prediction_snapshots joined with prediction_outcomes to get
|
||||
score_company, score_macro, score_competitive alongside outcomes.
|
||||
|
||||
For each layer computes:
|
||||
- avg_contribution_pct: average of layer_score / total_score across all
|
||||
predictions (where total_score > 0)
|
||||
- dominant_win_rate: win rate for predictions where the layer contributes
|
||||
more than 30% of the total score
|
||||
- dominant_ic: IC (Pearson correlation between layer score and future
|
||||
return) for predictions where the layer contributes > 30%
|
||||
|
||||
Returns a list of 3 LayerAttribution objects (company, macro, competitive).
|
||||
"""
|
||||
now = datetime.now().astimezone()
|
||||
cutoff = now - timedelta(days=lookback_days)
|
||||
|
||||
try:
|
||||
rows = await pool.fetch(_LAYER_ATTRIBUTION_SQL, horizon, cutoff)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to query layer attribution for horizon=%s lookback=%dd",
|
||||
horizon,
|
||||
lookback_days,
|
||||
)
|
||||
return []
|
||||
|
||||
if not rows:
|
||||
return [
|
||||
LayerAttribution(
|
||||
layer="company",
|
||||
avg_contribution_pct=0.0,
|
||||
dominant_win_rate=0.0,
|
||||
dominant_ic=None,
|
||||
),
|
||||
LayerAttribution(
|
||||
layer="macro",
|
||||
avg_contribution_pct=0.0,
|
||||
dominant_win_rate=0.0,
|
||||
dominant_ic=None,
|
||||
),
|
||||
LayerAttribution(
|
||||
layer="competitive",
|
||||
avg_contribution_pct=0.0,
|
||||
dominant_win_rate=0.0,
|
||||
dominant_ic=None,
|
||||
),
|
||||
]
|
||||
|
||||
row_dicts = [dict(r) for r in rows]
|
||||
|
||||
layers = [
|
||||
("company", "score_company"),
|
||||
("macro", "score_macro"),
|
||||
("competitive", "score_competitive"),
|
||||
]
|
||||
|
||||
results: list[LayerAttribution] = []
|
||||
|
||||
for layer_name, score_field in layers:
|
||||
# --- Average contribution percentage ---
|
||||
contribution_pcts: list[float] = []
|
||||
for r in row_dicts:
|
||||
total = (
|
||||
(r.get("score_company") or 0.0)
|
||||
+ (r.get("score_macro") or 0.0)
|
||||
+ (r.get("score_competitive") or 0.0)
|
||||
)
|
||||
if total > 0.0:
|
||||
layer_score = r.get(score_field) or 0.0
|
||||
contribution_pcts.append(layer_score / total)
|
||||
|
||||
avg_contribution_pct = (
|
||||
sum(contribution_pcts) / len(contribution_pcts)
|
||||
if contribution_pcts
|
||||
else 0.0
|
||||
)
|
||||
|
||||
# --- Dominant predictions: layer > 30% of total score ---
|
||||
dominant_rows: list[dict] = []
|
||||
for r in row_dicts:
|
||||
total = (
|
||||
(r.get("score_company") or 0.0)
|
||||
+ (r.get("score_macro") or 0.0)
|
||||
+ (r.get("score_competitive") or 0.0)
|
||||
)
|
||||
if total > 0.0:
|
||||
layer_score = r.get(score_field) or 0.0
|
||||
if layer_score / total > 0.30:
|
||||
dominant_rows.append(r)
|
||||
|
||||
# Dominant win rate
|
||||
dominant_direction_rows = [
|
||||
r for r in dominant_rows if r.get("direction_correct") is not None
|
||||
]
|
||||
dominant_win_count = sum(
|
||||
1 for r in dominant_direction_rows if r["direction_correct"] is True
|
||||
)
|
||||
dominant_win_rate = (
|
||||
dominant_win_count / len(dominant_direction_rows)
|
||||
if dominant_direction_rows
|
||||
else 0.0
|
||||
)
|
||||
|
||||
# Dominant IC: correlation between layer score and future return
|
||||
dom_scores = [
|
||||
r.get(score_field) or 0.0
|
||||
for r in dominant_rows
|
||||
if r.get("future_return") is not None
|
||||
]
|
||||
dom_returns = [
|
||||
r["future_return"]
|
||||
for r in dominant_rows
|
||||
if r.get("future_return") is not None
|
||||
]
|
||||
dominant_ic = _compute_ic(dom_scores, dom_returns)
|
||||
|
||||
results.append(
|
||||
LayerAttribution(
|
||||
layer=layer_name,
|
||||
avg_contribution_pct=avg_contribution_pct,
|
||||
dominant_win_rate=dominant_win_rate,
|
||||
dominant_ic=dominant_ic,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Computed layer attribution for 3 layers (horizon=%s, lookback=%dd)",
|
||||
horizon,
|
||||
lookback_days,
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Calibration Engine — Bayesian shrinkage source reliability and weight adjustment.
|
||||
|
||||
Computes source reliability scores using Bayesian shrinkage from historical
|
||||
prediction outcomes, and adjusts evidence weights based on source performance.
|
||||
Updates the existing source_accuracy table with reliability scores.
|
||||
|
||||
Requirements: 8.1, 8.2, 8.3, 8.4, 8.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import asyncpg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure functions — testable without a database
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_source_reliability(
|
||||
observed_win_rate: float,
|
||||
sample_count: int,
|
||||
prior_strength: int = 30,
|
||||
) -> float:
|
||||
"""Bayesian shrinkage source reliability.
|
||||
|
||||
reliability = 0.5 + (n / (n + prior_strength)) * (observed_win_rate - 0.5)
|
||||
|
||||
Returns value in [0.0, 1.0].
|
||||
When n=0, returns 0.5 (prior mean).
|
||||
As n→∞, approaches observed_win_rate.
|
||||
"""
|
||||
if sample_count <= 0:
|
||||
return 0.5
|
||||
|
||||
shrinkage = sample_count / (sample_count + prior_strength)
|
||||
reliability = 0.5 + shrinkage * (observed_win_rate - 0.5)
|
||||
|
||||
# Clamp to [0.0, 1.0] for safety (should already be in range when
|
||||
# observed_win_rate is in [0.0, 1.0], but guard against edge cases).
|
||||
return max(0.0, min(1.0, reliability))
|
||||
|
||||
|
||||
def compute_adjusted_evidence_weight(
|
||||
base_weight: float,
|
||||
reliability: float,
|
||||
) -> float:
|
||||
"""Adjusted weight = base_weight * (0.5 + reliability), clamped to [0.1, 2.0]."""
|
||||
adjusted = base_weight * (0.5 + reliability)
|
||||
return max(0.1, min(2.0, adjusted))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Query v_source_performance to get per-source win rates and sample counts.
|
||||
# Groups by source, counting total predictions and directional wins.
|
||||
_SOURCE_PERFORMANCE_SQL = """
|
||||
SELECT
|
||||
source,
|
||||
COUNT(*) AS sample_count,
|
||||
COUNT(*) FILTER (WHERE direction_correct = TRUE) AS win_count
|
||||
FROM v_source_performance
|
||||
WHERE direction_correct IS NOT NULL
|
||||
GROUP BY source
|
||||
"""
|
||||
|
||||
# Upsert into source_accuracy: update accuracy_ratio and sample_count
|
||||
# for existing sources, insert new ones.
|
||||
_UPSERT_SOURCE_ACCURACY_SQL = """
|
||||
INSERT INTO source_accuracy (source_id, accuracy_ratio, sample_count, last_updated)
|
||||
VALUES ($1, $2, $3, NOW())
|
||||
ON CONFLICT (source_id)
|
||||
DO UPDATE SET
|
||||
accuracy_ratio = EXCLUDED.accuracy_ratio,
|
||||
sample_count = EXCLUDED.sample_count,
|
||||
last_updated = NOW()
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Database-backed function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def update_source_reliabilities(
|
||||
pool: asyncpg.Pool,
|
||||
) -> int:
|
||||
"""Recompute and store source reliability scores from latest outcomes.
|
||||
|
||||
1. Queries v_source_performance to get per-source win rates and counts
|
||||
2. Computes Bayesian shrinkage reliability for each source
|
||||
3. Upserts into source_accuracy table (accuracy_ratio = reliability)
|
||||
|
||||
Returns count of sources updated.
|
||||
"""
|
||||
try:
|
||||
rows = await pool.fetch(_SOURCE_PERFORMANCE_SQL)
|
||||
except Exception:
|
||||
logger.exception("Failed to query source performance for reliability update")
|
||||
return 0
|
||||
|
||||
if not rows:
|
||||
logger.info("No source performance data available for reliability update")
|
||||
return 0
|
||||
|
||||
updated = 0
|
||||
|
||||
for row in rows:
|
||||
source = row["source"]
|
||||
sample_count = row["sample_count"]
|
||||
win_count = row["win_count"]
|
||||
|
||||
observed_win_rate = win_count / sample_count if sample_count > 0 else 0.5
|
||||
reliability = compute_source_reliability(observed_win_rate, sample_count)
|
||||
|
||||
try:
|
||||
await pool.execute(
|
||||
_UPSERT_SOURCE_ACCURACY_SQL,
|
||||
source,
|
||||
reliability,
|
||||
sample_count,
|
||||
)
|
||||
updated += 1
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to upsert source reliability for source=%s", source
|
||||
)
|
||||
|
||||
logger.info("Updated source reliabilities for %d sources", updated)
|
||||
return updated
|
||||
@@ -0,0 +1,637 @@
|
||||
"""Metrics Engine — computes calibration, IC, Brier, and benchmark metrics.
|
||||
|
||||
Aggregates model quality metrics across configurable lookback windows and
|
||||
prediction horizons. Stores periodic snapshots for time-series analysis
|
||||
of model performance trends.
|
||||
|
||||
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 6.1, 6.2, 6.3, 6.4, 6.5,
|
||||
9.1, 9.2, 9.3, 9.4, 10.1, 10.2, 10.3, 10.4, 10.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import asyncpg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CONFIDENCE_BUCKETS: list[tuple[float, float]] = [
|
||||
(0.50, 0.60),
|
||||
(0.60, 0.70),
|
||||
(0.70, 0.80),
|
||||
(0.80, 0.90),
|
||||
(0.90, 1.00),
|
||||
]
|
||||
|
||||
LOOKBACK_WINDOWS: list[str] = ["7d", "30d", "90d", "all"]
|
||||
|
||||
LOOKBACK_DURATIONS: dict[str, timedelta | None] = {
|
||||
"7d": timedelta(days=7),
|
||||
"30d": timedelta(days=30),
|
||||
"90d": timedelta(days=90),
|
||||
"all": None,
|
||||
}
|
||||
|
||||
EVALUATION_HORIZONS: list[str] = ["1h", "6h", "1d", "7d", "30d"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class CalibrationBucket:
|
||||
"""Calibration metrics for a single confidence bucket."""
|
||||
|
||||
bucket_low: float
|
||||
bucket_high: float
|
||||
avg_confidence: float
|
||||
observed_win_rate: float
|
||||
prediction_count: int
|
||||
miscalibrated: bool # |avg_confidence - win_rate| > 0.15
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelMetricSnapshot:
|
||||
"""Aggregate model quality metrics for a lookback/horizon combination."""
|
||||
|
||||
id: str
|
||||
generated_at: datetime
|
||||
lookback_window: str
|
||||
horizon: str
|
||||
prediction_count: int
|
||||
win_rate: float
|
||||
directional_accuracy: float
|
||||
information_coefficient: float | None
|
||||
rank_information_coefficient: float | None
|
||||
avg_return: float
|
||||
avg_excess_return_vs_spy: float
|
||||
avg_excess_return_vs_sector: float
|
||||
calibration_error: float # ECE
|
||||
brier_score: float
|
||||
buy_win_rate: float
|
||||
sell_win_rate: float
|
||||
hold_win_rate: float
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure computation functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_calibration_error(
|
||||
confidences: list[float],
|
||||
outcomes: list[bool],
|
||||
) -> tuple[float, list[CalibrationBucket]]:
|
||||
"""Compute ECE and calibration buckets.
|
||||
|
||||
ECE = Σ (n_b / N) * |avg_conf_b - win_rate_b|
|
||||
|
||||
Groups predictions into 5 confidence buckets and computes the weighted
|
||||
average of |avg_confidence - observed_win_rate| across all buckets.
|
||||
Flags buckets where |diff| > 0.15 as miscalibrated.
|
||||
|
||||
Returns (ece, buckets). Returns (0.0, []) when no data is provided.
|
||||
"""
|
||||
if not confidences or not outcomes:
|
||||
return 0.0, []
|
||||
|
||||
n = len(confidences)
|
||||
buckets: list[CalibrationBucket] = []
|
||||
ece = 0.0
|
||||
|
||||
for low, high in CONFIDENCE_BUCKETS:
|
||||
bucket_confs: list[float] = []
|
||||
bucket_outcomes: list[bool] = []
|
||||
|
||||
for conf, outcome in zip(confidences, outcomes):
|
||||
# Last bucket is inclusive on the right: [0.90, 1.00]
|
||||
if high == 1.00:
|
||||
in_bucket = low <= conf <= high
|
||||
else:
|
||||
in_bucket = low <= conf < high
|
||||
|
||||
if in_bucket:
|
||||
bucket_confs.append(conf)
|
||||
bucket_outcomes.append(outcome)
|
||||
|
||||
count = len(bucket_confs)
|
||||
if count == 0:
|
||||
# Empty bucket — exclude from ECE, still record it
|
||||
buckets.append(
|
||||
CalibrationBucket(
|
||||
bucket_low=low,
|
||||
bucket_high=high,
|
||||
avg_confidence=0.0,
|
||||
observed_win_rate=0.0,
|
||||
prediction_count=0,
|
||||
miscalibrated=False,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
avg_conf = sum(bucket_confs) / count
|
||||
win_rate = sum(1.0 for o in bucket_outcomes if o) / count
|
||||
diff = abs(avg_conf - win_rate)
|
||||
miscalibrated = diff > 0.15
|
||||
|
||||
buckets.append(
|
||||
CalibrationBucket(
|
||||
bucket_low=low,
|
||||
bucket_high=high,
|
||||
avg_confidence=avg_conf,
|
||||
observed_win_rate=win_rate,
|
||||
prediction_count=count,
|
||||
miscalibrated=miscalibrated,
|
||||
)
|
||||
)
|
||||
|
||||
ece += (count / n) * diff
|
||||
|
||||
return ece, buckets
|
||||
|
||||
|
||||
def compute_brier_score(
|
||||
p_bulls: list[float],
|
||||
outcomes: list[bool],
|
||||
) -> float:
|
||||
"""Brier score = mean((p_bull - outcome)^2).
|
||||
|
||||
outcome is 1.0 when price moved in predicted direction, 0.0 otherwise.
|
||||
Returns value in [0.0, 1.0]. Returns 0.0 for empty input.
|
||||
"""
|
||||
if not p_bulls or not outcomes:
|
||||
return 0.0
|
||||
|
||||
n = len(p_bulls)
|
||||
total = 0.0
|
||||
for p, o in zip(p_bulls, outcomes):
|
||||
actual = 1.0 if o else 0.0
|
||||
total += (p - actual) ** 2
|
||||
|
||||
return total / n
|
||||
|
||||
|
||||
def _pearson_correlation(xs: list[float], ys: list[float]) -> float | None:
|
||||
"""Compute Pearson correlation coefficient between two lists.
|
||||
|
||||
Returns None if the lists have fewer than 2 elements or if either
|
||||
has zero variance. Guards against NaN/infinity.
|
||||
"""
|
||||
n = len(xs)
|
||||
if n < 2:
|
||||
return None
|
||||
|
||||
mean_x = sum(xs) / n
|
||||
mean_y = sum(ys) / n
|
||||
|
||||
cov = 0.0
|
||||
var_x = 0.0
|
||||
var_y = 0.0
|
||||
|
||||
for x, y in zip(xs, ys):
|
||||
dx = x - mean_x
|
||||
dy = y - mean_y
|
||||
cov += dx * dy
|
||||
var_x += dx * dx
|
||||
var_y += dy * dy
|
||||
|
||||
if var_x == 0.0 or var_y == 0.0:
|
||||
return None
|
||||
|
||||
r = cov / math.sqrt(var_x * var_y)
|
||||
|
||||
# Guard against floating-point drift
|
||||
if math.isnan(r) or math.isinf(r):
|
||||
return None
|
||||
|
||||
# Clamp to [-1.0, 1.0]
|
||||
return max(-1.0, min(1.0, r))
|
||||
|
||||
|
||||
def _rank_data(values: list[float]) -> list[float]:
|
||||
"""Compute fractional ranks for a list of values (average tie-breaking)."""
|
||||
n = len(values)
|
||||
indexed = sorted(range(n), key=lambda i: values[i])
|
||||
|
||||
ranks = [0.0] * n
|
||||
i = 0
|
||||
while i < n:
|
||||
# Find the end of the tie group
|
||||
j = i + 1
|
||||
while j < n and values[indexed[j]] == values[indexed[i]]:
|
||||
j += 1
|
||||
|
||||
# Average rank for the tie group (1-based)
|
||||
avg_rank = (i + j + 1) / 2.0
|
||||
for k in range(i, j):
|
||||
ranks[indexed[k]] = avg_rank
|
||||
|
||||
i = j
|
||||
|
||||
return ranks
|
||||
|
||||
|
||||
def compute_information_coefficient(
|
||||
scores: list[float],
|
||||
returns: list[float],
|
||||
) -> float | None:
|
||||
"""Pearson correlation between prediction scores and future returns.
|
||||
|
||||
Returns None when fewer than 30 data points.
|
||||
Returns value in [-1.0, 1.0].
|
||||
"""
|
||||
if len(scores) < 30 or len(returns) < 30:
|
||||
return None
|
||||
|
||||
n = min(len(scores), len(returns))
|
||||
return _pearson_correlation(scores[:n], returns[:n])
|
||||
|
||||
|
||||
def compute_rank_information_coefficient(
|
||||
scores: list[float],
|
||||
returns: list[float],
|
||||
) -> float | None:
|
||||
"""Spearman rank correlation between prediction scores and future returns.
|
||||
|
||||
Ranks the data and computes Pearson correlation on the ranks.
|
||||
Returns None when fewer than 30 data points.
|
||||
Returns value in [-1.0, 1.0].
|
||||
"""
|
||||
if len(scores) < 30 or len(returns) < 30:
|
||||
return None
|
||||
|
||||
n = min(len(scores), len(returns))
|
||||
ranked_scores = _rank_data(scores[:n])
|
||||
ranked_returns = _rank_data(returns[:n])
|
||||
|
||||
return _pearson_correlation(ranked_scores, ranked_returns)
|
||||
|
||||
|
||||
def compute_contribution_scores(
|
||||
weights: list[float],
|
||||
) -> list[float]:
|
||||
"""Compute contribution scores from document weights.
|
||||
|
||||
Each score = weight_i / sum(weights). Sums to 1.0.
|
||||
Each score in [0.0, 1.0].
|
||||
Returns empty list for empty input.
|
||||
"""
|
||||
if not weights:
|
||||
return []
|
||||
|
||||
total = sum(weights)
|
||||
if total == 0.0:
|
||||
n = len(weights)
|
||||
return [1.0 / n] * n
|
||||
|
||||
return [w / total for w in weights]
|
||||
|
||||
|
||||
def compute_hit_rate_improvement(win_rate: float) -> float:
|
||||
"""Hit rate improvement over random 50/50 baseline.
|
||||
|
||||
Defined as (system_win_rate - 0.5) / 0.5.
|
||||
"""
|
||||
return (win_rate - 0.5) / 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL queries for v_prediction_performance view
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PERFORMANCE_DATA_SQL = """
|
||||
SELECT
|
||||
ticker,
|
||||
direction,
|
||||
action,
|
||||
confidence,
|
||||
strength,
|
||||
p_bull,
|
||||
score_company,
|
||||
score_macro,
|
||||
score_competitive,
|
||||
future_return,
|
||||
excess_return_vs_spy,
|
||||
excess_return_vs_sector,
|
||||
direction_correct,
|
||||
profitable,
|
||||
horizon,
|
||||
generated_at
|
||||
FROM v_prediction_performance
|
||||
WHERE horizon = $1
|
||||
"""
|
||||
|
||||
_PERFORMANCE_DATA_WITH_LOOKBACK_SQL = """
|
||||
SELECT
|
||||
ticker,
|
||||
direction,
|
||||
action,
|
||||
confidence,
|
||||
strength,
|
||||
p_bull,
|
||||
score_company,
|
||||
score_macro,
|
||||
score_competitive,
|
||||
future_return,
|
||||
excess_return_vs_spy,
|
||||
excess_return_vs_sector,
|
||||
direction_correct,
|
||||
profitable,
|
||||
horizon,
|
||||
generated_at
|
||||
FROM v_prediction_performance
|
||||
WHERE horizon = $1
|
||||
AND generated_at >= $2
|
||||
"""
|
||||
|
||||
_INSERT_METRIC_SNAPSHOT_SQL = """
|
||||
INSERT INTO model_metric_snapshots (
|
||||
id, generated_at, lookback_window, horizon,
|
||||
prediction_count, win_rate, directional_accuracy,
|
||||
information_coefficient, rank_information_coefficient,
|
||||
avg_return, avg_excess_return_vs_spy, avg_excess_return_vs_sector,
|
||||
calibration_error, brier_score,
|
||||
buy_win_rate, sell_win_rate, hold_win_rate,
|
||||
metadata
|
||||
) VALUES (
|
||||
$1::uuid, $2, $3, $4,
|
||||
$5, $6, $7,
|
||||
$8, $9,
|
||||
$10, $11, $12,
|
||||
$13, $14,
|
||||
$15, $16, $17,
|
||||
$18::jsonb
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metric computation from raw rows
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_metrics_from_rows(
|
||||
rows: list[dict],
|
||||
lookback_window: str,
|
||||
horizon: str,
|
||||
) -> ModelMetricSnapshot:
|
||||
"""Compute all metrics from a list of prediction performance rows.
|
||||
|
||||
Returns a ModelMetricSnapshot with all computed metrics.
|
||||
"""
|
||||
now = datetime.now().astimezone()
|
||||
snapshot_id = str(uuid.uuid4())
|
||||
|
||||
prediction_count = len(rows)
|
||||
|
||||
if prediction_count == 0:
|
||||
return ModelMetricSnapshot(
|
||||
id=snapshot_id,
|
||||
generated_at=now,
|
||||
lookback_window=lookback_window,
|
||||
horizon=horizon,
|
||||
prediction_count=0,
|
||||
win_rate=0.0,
|
||||
directional_accuracy=0.0,
|
||||
information_coefficient=None,
|
||||
rank_information_coefficient=None,
|
||||
avg_return=0.0,
|
||||
avg_excess_return_vs_spy=0.0,
|
||||
avg_excess_return_vs_sector=0.0,
|
||||
calibration_error=0.0,
|
||||
brier_score=0.0,
|
||||
buy_win_rate=0.0,
|
||||
sell_win_rate=0.0,
|
||||
hold_win_rate=0.0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
# --- Win rate and directional accuracy ---
|
||||
direction_correct_count = sum(
|
||||
1 for r in rows if r.get("direction_correct") is True
|
||||
)
|
||||
win_rate = direction_correct_count / prediction_count
|
||||
directional_accuracy = win_rate # Same metric, different name
|
||||
|
||||
# --- Per-action win rates ---
|
||||
buy_rows = [r for r in rows if (r.get("action") or "").lower() == "buy"]
|
||||
sell_rows = [r for r in rows if (r.get("action") or "").lower() == "sell"]
|
||||
hold_rows = [r for r in rows if (r.get("action") or "").lower() == "hold"]
|
||||
|
||||
buy_win_rate = (
|
||||
sum(1 for r in buy_rows if r.get("direction_correct") is True) / len(buy_rows)
|
||||
if buy_rows
|
||||
else 0.0
|
||||
)
|
||||
sell_win_rate = (
|
||||
sum(1 for r in sell_rows if r.get("direction_correct") is True)
|
||||
/ len(sell_rows)
|
||||
if sell_rows
|
||||
else 0.0
|
||||
)
|
||||
hold_win_rate = (
|
||||
sum(1 for r in hold_rows if r.get("direction_correct") is True)
|
||||
/ len(hold_rows)
|
||||
if hold_rows
|
||||
else 0.0
|
||||
)
|
||||
|
||||
# --- Average return ---
|
||||
returns_list = [
|
||||
r["future_return"] for r in rows if r.get("future_return") is not None
|
||||
]
|
||||
avg_return = sum(returns_list) / len(returns_list) if returns_list else 0.0
|
||||
|
||||
# --- Average excess return vs SPY (Requirement 9.1) ---
|
||||
excess_spy_list = [
|
||||
r["excess_return_vs_spy"]
|
||||
for r in rows
|
||||
if r.get("excess_return_vs_spy") is not None
|
||||
]
|
||||
avg_excess_return_vs_spy = (
|
||||
sum(excess_spy_list) / len(excess_spy_list) if excess_spy_list else 0.0
|
||||
)
|
||||
|
||||
# --- Average excess return vs sector ETF (Requirement 9.2) ---
|
||||
excess_sector_list = [
|
||||
r["excess_return_vs_sector"]
|
||||
for r in rows
|
||||
if r.get("excess_return_vs_sector") is not None
|
||||
]
|
||||
avg_excess_return_vs_sector = (
|
||||
sum(excess_sector_list) / len(excess_sector_list)
|
||||
if excess_sector_list
|
||||
else 0.0
|
||||
)
|
||||
|
||||
# --- Calibration error (ECE) (Requirements 5.1, 5.2, 5.3, 5.5) ---
|
||||
confidences = [
|
||||
r["confidence"] for r in rows if r.get("confidence") is not None
|
||||
]
|
||||
outcomes = [
|
||||
r.get("direction_correct") is True
|
||||
for r in rows
|
||||
if r.get("confidence") is not None
|
||||
]
|
||||
ece, _buckets = compute_calibration_error(confidences, outcomes)
|
||||
|
||||
# --- Brier score (Requirement 5.4) ---
|
||||
p_bulls = [r["p_bull"] for r in rows if r.get("p_bull") is not None]
|
||||
brier_outcomes = [
|
||||
r.get("direction_correct") is True
|
||||
for r in rows
|
||||
if r.get("p_bull") is not None
|
||||
]
|
||||
brier = compute_brier_score(p_bulls, brier_outcomes)
|
||||
|
||||
# --- Information Coefficient (Requirements 6.1, 6.5) ---
|
||||
ic_scores = [
|
||||
r["strength"] for r in rows if r.get("strength") is not None
|
||||
and r.get("future_return") is not None
|
||||
]
|
||||
ic_returns = [
|
||||
r["future_return"] for r in rows if r.get("strength") is not None
|
||||
and r.get("future_return") is not None
|
||||
]
|
||||
ic = compute_information_coefficient(ic_scores, ic_returns)
|
||||
|
||||
# --- Rank Information Coefficient (Requirements 6.2, 6.5) ---
|
||||
rank_ic = compute_rank_information_coefficient(ic_scores, ic_returns)
|
||||
|
||||
# --- Hit rate improvement (Requirement 9.4) ---
|
||||
hit_rate_improvement = compute_hit_rate_improvement(win_rate)
|
||||
|
||||
# --- Metadata (Requirement 10.5) ---
|
||||
metadata: dict = {
|
||||
"hit_rate_improvement": hit_rate_improvement,
|
||||
"buy_count": len(buy_rows),
|
||||
"sell_count": len(sell_rows),
|
||||
"hold_count": len(hold_rows),
|
||||
"returns_count": len(returns_list),
|
||||
"excess_spy_count": len(excess_spy_list),
|
||||
"excess_sector_count": len(excess_sector_list),
|
||||
}
|
||||
|
||||
return ModelMetricSnapshot(
|
||||
id=snapshot_id,
|
||||
generated_at=now,
|
||||
lookback_window=lookback_window,
|
||||
horizon=horizon,
|
||||
prediction_count=prediction_count,
|
||||
win_rate=win_rate,
|
||||
directional_accuracy=directional_accuracy,
|
||||
information_coefficient=ic,
|
||||
rank_information_coefficient=rank_ic,
|
||||
avg_return=avg_return,
|
||||
avg_excess_return_vs_spy=avg_excess_return_vs_spy,
|
||||
avg_excess_return_vs_sector=avg_excess_return_vs_sector,
|
||||
calibration_error=ece,
|
||||
brier_score=brier,
|
||||
buy_win_rate=buy_win_rate,
|
||||
sell_win_rate=sell_win_rate,
|
||||
hold_win_rate=hold_win_rate,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point (Requirements 10.1, 10.2, 10.3, 10.4, 10.5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def compute_and_store_metric_snapshots(
|
||||
pool: asyncpg.Pool,
|
||||
) -> list[ModelMetricSnapshot]:
|
||||
"""Compute metric snapshots for all lookback/horizon combinations.
|
||||
|
||||
Lookback windows: 7d, 30d, 90d, all-time.
|
||||
Horizons: 1h, 6h, 1d, 7d, 30d.
|
||||
|
||||
For each of the 4 lookbacks × 5 horizons = 20 combinations, queries the
|
||||
v_prediction_performance view, computes all metrics, and persists the
|
||||
result to model_metric_snapshots.
|
||||
|
||||
Returns the list of computed snapshots.
|
||||
"""
|
||||
snapshots: list[ModelMetricSnapshot] = []
|
||||
now = datetime.now().astimezone()
|
||||
|
||||
for lookback in LOOKBACK_WINDOWS:
|
||||
duration = LOOKBACK_DURATIONS[lookback]
|
||||
|
||||
for horizon in EVALUATION_HORIZONS:
|
||||
try:
|
||||
# Query performance data
|
||||
if duration is not None:
|
||||
cutoff = now - duration
|
||||
rows = await pool.fetch(
|
||||
_PERFORMANCE_DATA_WITH_LOOKBACK_SQL,
|
||||
horizon,
|
||||
cutoff,
|
||||
)
|
||||
else:
|
||||
rows = await pool.fetch(
|
||||
_PERFORMANCE_DATA_SQL,
|
||||
horizon,
|
||||
)
|
||||
|
||||
# Convert asyncpg Records to dicts
|
||||
row_dicts = [dict(r) for r in rows]
|
||||
|
||||
# Compute metrics
|
||||
snapshot = _compute_metrics_from_rows(
|
||||
row_dicts, lookback, horizon
|
||||
)
|
||||
|
||||
# Persist
|
||||
await pool.execute(
|
||||
_INSERT_METRIC_SNAPSHOT_SQL,
|
||||
snapshot.id,
|
||||
snapshot.generated_at,
|
||||
snapshot.lookback_window,
|
||||
snapshot.horizon,
|
||||
snapshot.prediction_count,
|
||||
snapshot.win_rate,
|
||||
snapshot.directional_accuracy,
|
||||
snapshot.information_coefficient,
|
||||
snapshot.rank_information_coefficient,
|
||||
snapshot.avg_return,
|
||||
snapshot.avg_excess_return_vs_spy,
|
||||
snapshot.avg_excess_return_vs_sector,
|
||||
snapshot.calibration_error,
|
||||
snapshot.brier_score,
|
||||
snapshot.buy_win_rate,
|
||||
snapshot.sell_win_rate,
|
||||
snapshot.hold_win_rate,
|
||||
json.dumps(snapshot.metadata),
|
||||
)
|
||||
|
||||
snapshots.append(snapshot)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to compute metrics for lookback=%s horizon=%s",
|
||||
lookback,
|
||||
horizon,
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
"Computed %d metric snapshots across %d lookback/horizon combinations",
|
||||
len(snapshots),
|
||||
len(LOOKBACK_WINDOWS) * len(EVALUATION_HORIZONS),
|
||||
)
|
||||
|
||||
return snapshots
|
||||
@@ -0,0 +1,414 @@
|
||||
"""Outcome Evaluator — matches predictions with realized market outcomes.
|
||||
|
||||
Runs periodically to evaluate prediction snapshots whose horizon has elapsed.
|
||||
For each snapshot, fetches future prices at the horizon endpoint and computes
|
||||
returns, excess returns, directional accuracy, and profitability across all
|
||||
five evaluation horizons (1h, 6h, 1d, 7d, 30d).
|
||||
|
||||
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 4.10
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import asyncpg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
HORIZON_DURATIONS: dict[str, timedelta] = {
|
||||
"1h": timedelta(hours=1),
|
||||
"6h": timedelta(hours=6),
|
||||
"1d": timedelta(days=1),
|
||||
"7d": timedelta(days=7),
|
||||
"30d": timedelta(days=30),
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class PredictionOutcome:
|
||||
"""Realized outcome for a prediction at a specific horizon."""
|
||||
|
||||
id: str # UUID
|
||||
prediction_id: str
|
||||
evaluated_at: datetime
|
||||
horizon: str # 1h, 6h, 1d, 7d, 30d
|
||||
future_price: float
|
||||
future_return: float
|
||||
spy_future_price: float | None
|
||||
spy_return: float | None
|
||||
sector_etf_future_price: float | None
|
||||
sector_etf_return: float | None
|
||||
excess_return_vs_spy: float | None
|
||||
excess_return_vs_sector: float | None
|
||||
direction_correct: bool
|
||||
profitable: bool
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL statements
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Find matured predictions: snapshots where generated_at + horizon_duration <= NOW()
|
||||
# and no outcome has been recorded yet for that (prediction_id, horizon) pair.
|
||||
# We evaluate ALL 5 horizons for each snapshot, not just the snapshot's own horizon.
|
||||
_MATURED_PREDICTIONS_SQL = """
|
||||
SELECT
|
||||
ps.id,
|
||||
ps.generated_at,
|
||||
ps.ticker,
|
||||
ps.horizon AS snapshot_horizon,
|
||||
ps.direction,
|
||||
ps.action,
|
||||
ps.price_at_prediction,
|
||||
ps.spy_price_at_prediction,
|
||||
ps.sector_etf_price_at_prediction
|
||||
FROM prediction_snapshots ps
|
||||
WHERE ps.generated_at + $1::interval <= NOW()
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM prediction_outcomes po
|
||||
WHERE po.prediction_id = ps.id AND po.horizon = $2
|
||||
)
|
||||
"""
|
||||
|
||||
# Fetch the close price for a ticker at or before a specific time.
|
||||
# Uses the closest bar before or at the target time.
|
||||
_CLOSE_AT_TIME_SQL = """
|
||||
SELECT (data->>'c')::float AS close
|
||||
FROM market_snapshots
|
||||
WHERE ticker = $1
|
||||
AND snapshot_type = 'bar'
|
||||
AND data->>'c' IS NOT NULL
|
||||
AND captured_at <= $2
|
||||
ORDER BY captured_at DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
_INSERT_OUTCOME_SQL = """
|
||||
INSERT INTO prediction_outcomes (
|
||||
id, prediction_id, evaluated_at, horizon,
|
||||
future_price, future_return,
|
||||
spy_future_price, spy_return,
|
||||
sector_etf_future_price, sector_etf_return,
|
||||
excess_return_vs_spy, excess_return_vs_sector,
|
||||
direction_correct, profitable,
|
||||
metadata
|
||||
) VALUES (
|
||||
$1::uuid, $2::uuid, $3, $4,
|
||||
$5, $6,
|
||||
$7, $8,
|
||||
$9, $10,
|
||||
$11, $12,
|
||||
$13, $14,
|
||||
$15::jsonb
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Price fetching at a specific time
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _fetch_close_at_time(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
target_time: datetime,
|
||||
) -> float | None:
|
||||
"""Fetch the close price for a ticker at or before a specific time.
|
||||
|
||||
Returns None if no market data is available before the target time.
|
||||
"""
|
||||
row = await pool.fetchrow(_CLOSE_AT_TIME_SQL, ticker, target_time)
|
||||
if row is None:
|
||||
return None
|
||||
return row["close"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sector ETF lookup (reuse pattern from prediction_snapshot)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SECTOR_ETF_MAP: dict[str, str] = {
|
||||
"Technology": "XLK",
|
||||
"Consumer Cyclical": "XLY",
|
||||
"Financial Services": "XLF",
|
||||
"Healthcare": "XLV",
|
||||
"Energy": "XLE",
|
||||
"Communication Services": "XLC",
|
||||
"Industrials": "XLI",
|
||||
"Consumer Defensive": "XLP",
|
||||
"Real Estate": "XLRE",
|
||||
"Utilities": "XLU",
|
||||
}
|
||||
|
||||
_COMPANY_SECTOR_SQL = """
|
||||
SELECT sector FROM companies WHERE ticker = $1 AND active = TRUE LIMIT 1
|
||||
"""
|
||||
|
||||
|
||||
async def _fetch_sector_etf_ticker(pool: asyncpg.Pool, ticker: str) -> str | None:
|
||||
"""Look up the sector ETF ticker for a company ticker."""
|
||||
row = await pool.fetchrow(_COMPANY_SECTOR_SQL, ticker)
|
||||
if row is None or row["sector"] is None:
|
||||
return None
|
||||
return _SECTOR_ETF_MAP.get(row["sector"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Return computation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_return(current_price: float, future_price: float) -> float:
|
||||
"""Compute simple return: (future - current) / current."""
|
||||
if current_price == 0.0:
|
||||
return 0.0
|
||||
return (future_price - current_price) / current_price
|
||||
|
||||
|
||||
def _is_direction_correct(direction: str, future_return: float) -> bool:
|
||||
"""Determine if the predicted direction matches the realized return.
|
||||
|
||||
bullish + positive return = True
|
||||
bearish + negative return = True
|
||||
All other combinations = False
|
||||
"""
|
||||
direction_lower = direction.lower()
|
||||
if direction_lower == "bullish" and future_return > 0.0:
|
||||
return True
|
||||
if direction_lower == "bearish" and future_return < 0.0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_profitable(action: str, future_return: float) -> bool:
|
||||
"""Determine if the predicted action would have been profitable.
|
||||
|
||||
buy + positive return = True
|
||||
sell + negative return = True
|
||||
All other combinations = False
|
||||
"""
|
||||
action_lower = action.lower()
|
||||
if action_lower == "buy" and future_return > 0.0:
|
||||
return True
|
||||
if action_lower == "sell" and future_return < 0.0:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single prediction evaluation (Requirements 4.2–4.7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def evaluate_single_prediction(
|
||||
pool: asyncpg.Pool,
|
||||
snapshot: dict,
|
||||
horizon: str,
|
||||
) -> PredictionOutcome | None:
|
||||
"""Evaluate a single prediction at a specific horizon.
|
||||
|
||||
Fetches the future price at generated_at + horizon_duration for the ticker,
|
||||
SPY, and sector ETF. Computes returns, excess returns, direction correctness,
|
||||
and profitability.
|
||||
|
||||
Returns None if the ticker's future price is unavailable (Requirement 4.10).
|
||||
"""
|
||||
duration = HORIZON_DURATIONS[horizon]
|
||||
target_time = snapshot["generated_at"] + duration
|
||||
ticker = snapshot["ticker"]
|
||||
|
||||
# Fetch future price for the ticker — required (skip if unavailable)
|
||||
future_price = await _fetch_close_at_time(pool, ticker, target_time)
|
||||
if future_price is None:
|
||||
logger.debug(
|
||||
"Future price unavailable for %s at horizon %s (target %s), skipping",
|
||||
ticker,
|
||||
horizon,
|
||||
target_time,
|
||||
)
|
||||
return None
|
||||
|
||||
price_at_prediction = snapshot["price_at_prediction"]
|
||||
if price_at_prediction is None or price_at_prediction == 0.0:
|
||||
logger.warning(
|
||||
"Price at prediction is NULL or zero for snapshot %s, skipping horizon %s",
|
||||
snapshot["id"],
|
||||
horizon,
|
||||
)
|
||||
return None
|
||||
|
||||
# Compute ticker future return (Requirement 4.2)
|
||||
future_return = _compute_return(price_at_prediction, future_price)
|
||||
|
||||
# Fetch SPY future price and compute SPY return (Requirement 4.3)
|
||||
spy_future_price: float | None = None
|
||||
spy_return: float | None = None
|
||||
spy_price_at_prediction = snapshot["spy_price_at_prediction"]
|
||||
|
||||
if spy_price_at_prediction is not None and spy_price_at_prediction != 0.0:
|
||||
spy_future_price = await _fetch_close_at_time(pool, "SPY", target_time)
|
||||
if spy_future_price is not None:
|
||||
spy_return = _compute_return(spy_price_at_prediction, spy_future_price)
|
||||
|
||||
# Fetch sector ETF future price and compute sector return (Requirement 4.4)
|
||||
sector_etf_future_price: float | None = None
|
||||
sector_etf_return: float | None = None
|
||||
sector_etf_price_at_prediction = snapshot["sector_etf_price_at_prediction"]
|
||||
|
||||
if (
|
||||
sector_etf_price_at_prediction is not None
|
||||
and sector_etf_price_at_prediction != 0.0
|
||||
):
|
||||
sector_etf_ticker = await _fetch_sector_etf_ticker(pool, ticker)
|
||||
if sector_etf_ticker is not None:
|
||||
sector_etf_future_price = await _fetch_close_at_time(
|
||||
pool, sector_etf_ticker, target_time
|
||||
)
|
||||
if sector_etf_future_price is not None:
|
||||
sector_etf_return = _compute_return(
|
||||
sector_etf_price_at_prediction, sector_etf_future_price
|
||||
)
|
||||
|
||||
# Compute excess returns (Requirement 4.5)
|
||||
excess_return_vs_spy: float | None = None
|
||||
if future_return is not None and spy_return is not None:
|
||||
excess_return_vs_spy = future_return - spy_return
|
||||
|
||||
excess_return_vs_sector: float | None = None
|
||||
if future_return is not None and sector_etf_return is not None:
|
||||
excess_return_vs_sector = future_return - sector_etf_return
|
||||
|
||||
# Determine direction correctness (Requirement 4.6)
|
||||
direction_correct = _is_direction_correct(snapshot["direction"], future_return)
|
||||
|
||||
# Determine profitability (Requirement 4.7)
|
||||
profitable = _is_profitable(snapshot["action"], future_return)
|
||||
|
||||
now = datetime.now().astimezone()
|
||||
|
||||
return PredictionOutcome(
|
||||
id=str(uuid.uuid4()),
|
||||
prediction_id=str(snapshot["id"]),
|
||||
evaluated_at=now,
|
||||
horizon=horizon,
|
||||
future_price=future_price,
|
||||
future_return=future_return,
|
||||
spy_future_price=spy_future_price,
|
||||
spy_return=spy_return,
|
||||
sector_etf_future_price=sector_etf_future_price,
|
||||
sector_etf_return=sector_etf_return,
|
||||
excess_return_vs_spy=excess_return_vs_spy,
|
||||
excess_return_vs_sector=excess_return_vs_sector,
|
||||
direction_correct=direction_correct,
|
||||
profitable=profitable,
|
||||
metadata={
|
||||
"ticker": ticker,
|
||||
"horizon": horizon,
|
||||
"price_at_prediction": price_at_prediction,
|
||||
"future_price": future_price,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Store outcome (Requirement 4.9)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _store_outcome(
|
||||
conn: asyncpg.Connection,
|
||||
outcome: PredictionOutcome,
|
||||
) -> None:
|
||||
"""Persist a single prediction outcome to the database."""
|
||||
await conn.execute(
|
||||
_INSERT_OUTCOME_SQL,
|
||||
outcome.id,
|
||||
outcome.prediction_id,
|
||||
outcome.evaluated_at,
|
||||
outcome.horizon,
|
||||
outcome.future_price,
|
||||
outcome.future_return,
|
||||
outcome.spy_future_price,
|
||||
outcome.spy_return,
|
||||
outcome.sector_etf_future_price,
|
||||
outcome.sector_etf_return,
|
||||
outcome.excess_return_vs_spy,
|
||||
outcome.excess_return_vs_sector,
|
||||
outcome.direction_correct,
|
||||
outcome.profitable,
|
||||
json.dumps(outcome.metadata),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point (Requirements 4.1, 4.8, 4.9, 4.10)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def evaluate_matured_predictions(
|
||||
pool: asyncpg.Pool,
|
||||
) -> int:
|
||||
"""Evaluate all matured prediction snapshots across all horizons.
|
||||
|
||||
For each of the 5 horizons (1h, 6h, 1d, 7d, 30d), finds prediction
|
||||
snapshots where generated_at + horizon_duration <= NOW() and no outcome
|
||||
has been recorded for that (prediction_id, horizon) pair.
|
||||
|
||||
For each matured snapshot-horizon pair, fetches future prices and computes
|
||||
returns. Skips horizons where the future price is unavailable — those will
|
||||
be retried on the next run (Requirement 4.10).
|
||||
|
||||
Returns the total count of outcomes recorded.
|
||||
"""
|
||||
total_recorded = 0
|
||||
|
||||
for horizon, duration in HORIZON_DURATIONS.items():
|
||||
# Find snapshots matured for this horizon
|
||||
rows = await pool.fetch(_MATURED_PREDICTIONS_SQL, duration, horizon)
|
||||
|
||||
if not rows:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
"Found %d matured predictions for horizon %s", len(rows), horizon
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
snapshot = dict(row)
|
||||
try:
|
||||
outcome = await evaluate_single_prediction(pool, snapshot, horizon)
|
||||
if outcome is None:
|
||||
# Future price unavailable — skip, retry next run
|
||||
continue
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
await _store_outcome(conn, outcome)
|
||||
|
||||
total_recorded += 1
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to evaluate snapshot %s at horizon %s",
|
||||
snapshot["id"],
|
||||
horizon,
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info("Outcome evaluation complete: %d outcomes recorded", total_recorded)
|
||||
return total_recorded
|
||||
@@ -0,0 +1,540 @@
|
||||
"""Prediction Snapshot Writer — captures immutable prediction state at generation time.
|
||||
|
||||
Creates frozen records of every recommendation with prices, evidence links,
|
||||
duplicate detection, and contribution scores so that predictions can be
|
||||
evaluated against future outcomes without hindsight bias.
|
||||
|
||||
Requirements: 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 3.1, 3.2, 3.3, 3.4
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.shared.schemas import Recommendation, TrendSummary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SECTOR_ETF_MAP: dict[str, str] = {
|
||||
"Technology": "XLK",
|
||||
"Consumer Cyclical": "XLY",
|
||||
"Financial Services": "XLF",
|
||||
"Healthcare": "XLV",
|
||||
"Energy": "XLE",
|
||||
"Communication Services": "XLC",
|
||||
"Industrials": "XLI",
|
||||
"Consumer Defensive": "XLP",
|
||||
"Real Estate": "XLRE",
|
||||
"Utilities": "XLU",
|
||||
}
|
||||
|
||||
EVALUATION_HORIZONS: list[str] = ["1h", "6h", "1d", "7d", "30d"]
|
||||
|
||||
MAX_SINGLE_DOCUMENT_WEIGHT: float = 1.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class PredictionSnapshot:
|
||||
"""Immutable snapshot of a prediction at generation time."""
|
||||
|
||||
id: str # UUID
|
||||
generated_at: datetime
|
||||
ticker: str
|
||||
window: str
|
||||
horizon: str
|
||||
direction: str # bullish/bearish/mixed/neutral
|
||||
action: str # buy/sell/hold/watch
|
||||
mode: str # informational/paper_eligible/live_eligible
|
||||
strength: float
|
||||
confidence: float
|
||||
contradiction: float
|
||||
p_bull: float | None
|
||||
p_bear: float | None
|
||||
score_company: float
|
||||
score_macro: float
|
||||
score_competitive: float
|
||||
evidence_count: int
|
||||
unique_source_count: int
|
||||
duplicate_evidence_count: int
|
||||
price_at_prediction: float | None
|
||||
spy_price_at_prediction: float | None
|
||||
sector_etf_price_at_prediction: float | None
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SignalEvidenceLink:
|
||||
"""Link between a prediction and a contributing evidence document."""
|
||||
|
||||
id: str # UUID
|
||||
prediction_id: str
|
||||
document_id: str
|
||||
signal_id: str
|
||||
ticker: str
|
||||
source: str
|
||||
source_type: str
|
||||
catalyst_type: str
|
||||
sentiment: str
|
||||
impact: float
|
||||
extraction_confidence: float
|
||||
weight: float # clamped to MAX_SINGLE_DOCUMENT_WEIGHT
|
||||
is_duplicate: bool
|
||||
canonical_evidence_key: str
|
||||
contribution_score: float # weight / total_weight, sums to 1.0
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Canonical evidence key computation (Requirements 2.3, 17.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_canonical_evidence_key(title: str, url: str) -> str:
|
||||
"""SHA256 of normalized(title) + normalized(url).
|
||||
|
||||
Normalization:
|
||||
- Title: lowercase, strip leading/trailing whitespace
|
||||
- URL: lowercase, strip query parameters (keep scheme, netloc, path)
|
||||
"""
|
||||
normalized_title = title.strip().lower()
|
||||
|
||||
parsed = urllib.parse.urlparse(url.lower())
|
||||
normalized_url = urllib.parse.urlunparse(
|
||||
(parsed.scheme, parsed.netloc, parsed.path, "", "", "")
|
||||
)
|
||||
|
||||
combined = normalized_title + normalized_url
|
||||
return hashlib.sha256(combined.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Contribution score computation (Requirements 2.5, 17.7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_contribution_scores(weights: list[float]) -> list[float]:
|
||||
"""Compute contribution scores: each score = weight_i / sum(weights).
|
||||
|
||||
All scores are in [0.0, 1.0] and sum to 1.0 (within floating-point tolerance).
|
||||
Returns an empty list for empty input.
|
||||
"""
|
||||
if not weights:
|
||||
return []
|
||||
|
||||
total = sum(weights)
|
||||
if total == 0.0:
|
||||
# All weights are zero — distribute equally
|
||||
n = len(weights)
|
||||
return [1.0 / n] * n
|
||||
|
||||
return [w / total for w in weights]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Price fetching (Requirements 1.2, 1.3, 1.4, 1.5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_LATEST_CLOSE_SQL = """
|
||||
SELECT (data->>'c')::float AS close
|
||||
FROM market_snapshots
|
||||
WHERE ticker = $1 AND snapshot_type = 'bar' AND data->>'c' IS NOT NULL
|
||||
ORDER BY captured_at DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
|
||||
async def fetch_latest_close_price(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
) -> float | None:
|
||||
"""Fetch most recent close price from market_snapshots for a ticker.
|
||||
|
||||
Returns None if no market data is available for the ticker.
|
||||
"""
|
||||
row = await pool.fetchrow(_LATEST_CLOSE_SQL, ticker)
|
||||
if row is None:
|
||||
return None
|
||||
return row["close"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sector ETF lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_COMPANY_SECTOR_SQL = """
|
||||
SELECT sector FROM companies WHERE ticker = $1 AND active = TRUE LIMIT 1
|
||||
"""
|
||||
|
||||
|
||||
async def _fetch_sector_etf_ticker(pool: asyncpg.Pool, ticker: str) -> str | None:
|
||||
"""Look up the sector ETF ticker for a company ticker."""
|
||||
row = await pool.fetchrow(_COMPANY_SECTOR_SQL, ticker)
|
||||
if row is None or row["sector"] is None:
|
||||
return None
|
||||
return SECTOR_ETF_MAP.get(row["sector"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer score computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_layer_scores(
|
||||
evidence_signals: list[dict],
|
||||
) -> tuple[float, float, float]:
|
||||
"""Compute company, macro, and competitive layer scores from evidence signals.
|
||||
|
||||
Each signal's source_type determines its layer:
|
||||
- company: news_api, filings_api, web_scrape
|
||||
- macro: macro events (source_type containing 'macro')
|
||||
- competitive: competitive signals (source_type containing 'competitive' or 'pattern')
|
||||
|
||||
Returns (score_company, score_macro, score_competitive) as fractions summing to 1.0.
|
||||
"""
|
||||
company_weight = 0.0
|
||||
macro_weight = 0.0
|
||||
competitive_weight = 0.0
|
||||
|
||||
for sig in evidence_signals:
|
||||
w = sig.get("weight", 0.0)
|
||||
source_type = sig.get("source_type", "").lower()
|
||||
catalyst_type = sig.get("catalyst_type", "").lower()
|
||||
|
||||
if "macro" in source_type or catalyst_type == "macro":
|
||||
macro_weight += w
|
||||
elif "competitive" in source_type or "pattern" in source_type:
|
||||
competitive_weight += w
|
||||
else:
|
||||
company_weight += w
|
||||
|
||||
total = company_weight + macro_weight + competitive_weight
|
||||
if total == 0.0:
|
||||
return (0.0, 0.0, 0.0)
|
||||
|
||||
return (
|
||||
round(company_weight / total, 6),
|
||||
round(macro_weight / total, 6),
|
||||
round(competitive_weight / total, 6),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL statements
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INSERT_SNAPSHOT_SQL = """
|
||||
INSERT INTO prediction_snapshots (
|
||||
id, generated_at, ticker, window, horizon, direction, action, mode,
|
||||
strength, confidence, contradiction, p_bull, p_bear,
|
||||
score_company, score_macro, score_competitive,
|
||||
evidence_count, unique_source_count, duplicate_evidence_count,
|
||||
price_at_prediction, spy_price_at_prediction, sector_etf_price_at_prediction,
|
||||
metadata
|
||||
) VALUES (
|
||||
$1::uuid, $2, $3, $4, $5, $6, $7, $8,
|
||||
$9, $10, $11, $12, $13,
|
||||
$14, $15, $16,
|
||||
$17, $18, $19,
|
||||
$20, $21, $22,
|
||||
$23::jsonb
|
||||
)
|
||||
"""
|
||||
|
||||
_INSERT_EVIDENCE_LINK_SQL = """
|
||||
INSERT INTO signal_evidence_links (
|
||||
id, prediction_id, document_id, signal_id, ticker,
|
||||
source, source_type, catalyst_type, sentiment,
|
||||
impact, extraction_confidence, weight,
|
||||
is_duplicate, canonical_evidence_key, contribution_score,
|
||||
metadata
|
||||
) VALUES (
|
||||
$1::uuid, $2::uuid, $3, $4, $5,
|
||||
$6, $7, $8, $9,
|
||||
$10, $11, $12,
|
||||
$13, $14, $15,
|
||||
$16::jsonb
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point (Requirements 1.1–1.7, 2.1–2.6, 3.1–3.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def create_prediction_snapshot(
|
||||
pool: asyncpg.Pool,
|
||||
recommendation: Recommendation,
|
||||
trend_summary: TrendSummary,
|
||||
evidence_signals: list[dict],
|
||||
evidence_docs: list[dict],
|
||||
) -> PredictionSnapshot:
|
||||
"""Create and persist a prediction snapshot with evidence links.
|
||||
|
||||
Steps:
|
||||
1. Fetch current prices (ticker, SPY, sector ETF) from market_snapshots
|
||||
2. Compute canonical evidence keys and detect duplicates
|
||||
3. Clamp individual document weights to MAX_SINGLE_DOCUMENT_WEIGHT
|
||||
4. Compute contribution scores (one-vote-per-canonical-key dedup)
|
||||
5. Persist snapshot and evidence links in a transaction
|
||||
|
||||
Args:
|
||||
pool: asyncpg connection pool.
|
||||
recommendation: The generated Recommendation object.
|
||||
trend_summary: The TrendSummary used to generate the recommendation.
|
||||
evidence_signals: List of dicts with signal fields (source, source_type,
|
||||
catalyst_type, sentiment, impact, extraction_confidence, weight,
|
||||
document_id, signal_id, ticker).
|
||||
evidence_docs: List of dicts with document metadata (title, url, document_id).
|
||||
|
||||
Returns:
|
||||
The persisted PredictionSnapshot.
|
||||
"""
|
||||
ticker = recommendation.ticker
|
||||
|
||||
# 1. Fetch prices — handle NULL gracefully (Requirement 1.5)
|
||||
ticker_price = await fetch_latest_close_price(pool, ticker)
|
||||
if ticker_price is None:
|
||||
logger.warning("No market price available for %s at snapshot time", ticker)
|
||||
|
||||
spy_price = await fetch_latest_close_price(pool, "SPY")
|
||||
if spy_price is None:
|
||||
logger.warning("No SPY price available at snapshot time")
|
||||
|
||||
sector_etf_ticker = await _fetch_sector_etf_ticker(pool, ticker)
|
||||
sector_etf_price: float | None = None
|
||||
if sector_etf_ticker is not None:
|
||||
sector_etf_price = await fetch_latest_close_price(pool, sector_etf_ticker)
|
||||
if sector_etf_price is None:
|
||||
logger.warning(
|
||||
"No sector ETF price available for %s (%s) at snapshot time",
|
||||
sector_etf_ticker,
|
||||
ticker,
|
||||
)
|
||||
else:
|
||||
logger.warning("No sector ETF mapping found for ticker %s", ticker)
|
||||
|
||||
# 2. Build a doc lookup for canonical key computation
|
||||
doc_lookup: dict[str, dict] = {}
|
||||
for doc in evidence_docs:
|
||||
doc_id = doc.get("document_id", "")
|
||||
doc_lookup[doc_id] = doc
|
||||
|
||||
# 3. Process evidence signals: compute canonical keys, detect duplicates,
|
||||
# clamp weights
|
||||
processed_links: list[dict] = []
|
||||
seen_canonical_keys: dict[str, int] = {} # canonical_key -> first index
|
||||
|
||||
for sig in evidence_signals:
|
||||
doc_id = sig.get("document_id", "")
|
||||
doc_meta = doc_lookup.get(doc_id, {})
|
||||
title = doc_meta.get("title", "")
|
||||
url = doc_meta.get("url", "")
|
||||
|
||||
canonical_key = compute_canonical_evidence_key(title, url)
|
||||
|
||||
# Detect duplicates: same canonical key for same ticker
|
||||
is_duplicate = canonical_key in seen_canonical_keys
|
||||
if not is_duplicate:
|
||||
seen_canonical_keys[canonical_key] = len(processed_links)
|
||||
|
||||
# Clamp weight to MAX_SINGLE_DOCUMENT_WEIGHT (Requirement 3.3)
|
||||
raw_weight = sig.get("weight", 0.0)
|
||||
clamped_weight = min(raw_weight, MAX_SINGLE_DOCUMENT_WEIGHT)
|
||||
|
||||
processed_links.append({
|
||||
"id": str(uuid.uuid4()),
|
||||
"document_id": doc_id,
|
||||
"signal_id": sig.get("signal_id", ""),
|
||||
"ticker": sig.get("ticker", ticker),
|
||||
"source": sig.get("source", ""),
|
||||
"source_type": sig.get("source_type", ""),
|
||||
"catalyst_type": sig.get("catalyst_type", ""),
|
||||
"sentiment": sig.get("sentiment", ""),
|
||||
"impact": sig.get("impact", 0.0),
|
||||
"extraction_confidence": sig.get("extraction_confidence", 0.0),
|
||||
"weight": clamped_weight,
|
||||
"is_duplicate": is_duplicate,
|
||||
"canonical_evidence_key": canonical_key,
|
||||
})
|
||||
|
||||
# 4. Compute contribution scores — one vote per canonical key (Requirement 3.4)
|
||||
# Only non-duplicate links contribute to the weight pool
|
||||
non_dup_weights = [
|
||||
link["weight"] for link in processed_links if not link["is_duplicate"]
|
||||
]
|
||||
non_dup_scores = compute_contribution_scores(non_dup_weights)
|
||||
|
||||
# Assign contribution scores: non-duplicates get their computed score,
|
||||
# duplicates get 0.0
|
||||
score_idx = 0
|
||||
for link in processed_links:
|
||||
if not link["is_duplicate"]:
|
||||
link["contribution_score"] = non_dup_scores[score_idx]
|
||||
score_idx += 1
|
||||
else:
|
||||
link["contribution_score"] = 0.0
|
||||
|
||||
# 5. Compute deduplication quality metrics (Requirements 3.1, 3.2)
|
||||
unique_sources = {
|
||||
link["source"]
|
||||
for link in processed_links
|
||||
if not link["is_duplicate"]
|
||||
}
|
||||
unique_source_count = len(unique_sources)
|
||||
duplicate_evidence_count = sum(
|
||||
1 for link in processed_links if link["is_duplicate"]
|
||||
)
|
||||
|
||||
# 6. Compute layer scores from evidence signals
|
||||
score_company, score_macro, score_competitive = _compute_layer_scores(
|
||||
evidence_signals
|
||||
)
|
||||
|
||||
# 7. Build metadata from trend summary context (Requirement 1.7)
|
||||
metadata: dict = {}
|
||||
if trend_summary.market_context is not None:
|
||||
metadata["market_context"] = {
|
||||
"ticker": trend_summary.market_context.ticker,
|
||||
"price_change_pct": trend_summary.market_context.price_change_pct,
|
||||
"avg_volume": trend_summary.market_context.avg_volume,
|
||||
"volume_change_pct": trend_summary.market_context.volume_change_pct,
|
||||
"volatility": trend_summary.market_context.volatility,
|
||||
"latest_close": trend_summary.market_context.latest_close,
|
||||
"bars_available": trend_summary.market_context.bars_available,
|
||||
}
|
||||
if sector_etf_ticker is not None:
|
||||
metadata["sector_etf_ticker"] = sector_etf_ticker
|
||||
|
||||
# 8. Build the snapshot
|
||||
snapshot_id = str(uuid.uuid4())
|
||||
snapshot = PredictionSnapshot(
|
||||
id=snapshot_id,
|
||||
generated_at=recommendation.generated_at,
|
||||
ticker=ticker,
|
||||
window=trend_summary.window.value,
|
||||
horizon=recommendation.time_horizon,
|
||||
direction=trend_summary.trend_direction.value,
|
||||
action=recommendation.action.value,
|
||||
mode=recommendation.mode.value,
|
||||
strength=trend_summary.trend_strength,
|
||||
confidence=recommendation.confidence,
|
||||
contradiction=trend_summary.contradiction_score,
|
||||
p_bull=trend_summary.p_bull,
|
||||
p_bear=1.0 - trend_summary.p_bull if trend_summary.p_bull is not None else None,
|
||||
score_company=score_company,
|
||||
score_macro=score_macro,
|
||||
score_competitive=score_competitive,
|
||||
evidence_count=len(processed_links),
|
||||
unique_source_count=unique_source_count,
|
||||
duplicate_evidence_count=duplicate_evidence_count,
|
||||
price_at_prediction=ticker_price,
|
||||
spy_price_at_prediction=spy_price,
|
||||
sector_etf_price_at_prediction=sector_etf_price,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# 9. Build evidence link objects
|
||||
evidence_link_objects: list[SignalEvidenceLink] = []
|
||||
for link in processed_links:
|
||||
evidence_link_objects.append(
|
||||
SignalEvidenceLink(
|
||||
id=link["id"],
|
||||
prediction_id=snapshot_id,
|
||||
document_id=link["document_id"],
|
||||
signal_id=link["signal_id"],
|
||||
ticker=link["ticker"],
|
||||
source=link["source"],
|
||||
source_type=link["source_type"],
|
||||
catalyst_type=link["catalyst_type"],
|
||||
sentiment=link["sentiment"],
|
||||
impact=link["impact"],
|
||||
extraction_confidence=link["extraction_confidence"],
|
||||
weight=link["weight"],
|
||||
is_duplicate=link["is_duplicate"],
|
||||
canonical_evidence_key=link["canonical_evidence_key"],
|
||||
contribution_score=link["contribution_score"],
|
||||
)
|
||||
)
|
||||
|
||||
# 10. Persist in a transaction (Requirements 1.6, 2.6)
|
||||
async with pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
await conn.execute(
|
||||
_INSERT_SNAPSHOT_SQL,
|
||||
snapshot.id,
|
||||
snapshot.generated_at,
|
||||
snapshot.ticker,
|
||||
snapshot.window,
|
||||
snapshot.horizon,
|
||||
snapshot.direction,
|
||||
snapshot.action,
|
||||
snapshot.mode,
|
||||
snapshot.strength,
|
||||
snapshot.confidence,
|
||||
snapshot.contradiction,
|
||||
snapshot.p_bull,
|
||||
snapshot.p_bear,
|
||||
snapshot.score_company,
|
||||
snapshot.score_macro,
|
||||
snapshot.score_competitive,
|
||||
snapshot.evidence_count,
|
||||
snapshot.unique_source_count,
|
||||
snapshot.duplicate_evidence_count,
|
||||
snapshot.price_at_prediction,
|
||||
snapshot.spy_price_at_prediction,
|
||||
snapshot.sector_etf_price_at_prediction,
|
||||
json.dumps(snapshot.metadata),
|
||||
)
|
||||
|
||||
for link in evidence_link_objects:
|
||||
await conn.execute(
|
||||
_INSERT_EVIDENCE_LINK_SQL,
|
||||
link.id,
|
||||
link.prediction_id,
|
||||
link.document_id,
|
||||
link.signal_id,
|
||||
link.ticker,
|
||||
link.source,
|
||||
link.source_type,
|
||||
link.catalyst_type,
|
||||
link.sentiment,
|
||||
link.impact,
|
||||
link.extraction_confidence,
|
||||
link.weight,
|
||||
link.is_duplicate,
|
||||
link.canonical_evidence_key,
|
||||
link.contribution_score,
|
||||
json.dumps(link.metadata),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Created prediction snapshot %s for %s: %d evidence links "
|
||||
"(%d unique sources, %d duplicates), prices: ticker=%s spy=%s sector_etf=%s",
|
||||
snapshot_id,
|
||||
ticker,
|
||||
len(evidence_link_objects),
|
||||
unique_source_count,
|
||||
duplicate_evidence_count,
|
||||
ticker_price,
|
||||
spy_price,
|
||||
sector_etf_price,
|
||||
)
|
||||
|
||||
return snapshot
|
||||
Reference in New Issue
Block a user