"""Calibration Engine — Bayesian shrinkage source reliability and weight adjustment. Computes source reliability scores using Bayesian shrinkage from historical prediction outcomes, and adjusts evidence weights based on source performance. Updates the existing source_accuracy table with reliability scores. Requirements: 8.1, 8.2, 8.3, 8.4, 8.5 """ from __future__ import annotations import logging import asyncpg logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Pure functions — testable without a database # --------------------------------------------------------------------------- def compute_source_reliability( observed_win_rate: float, sample_count: int, prior_strength: int = 30, ) -> float: """Bayesian shrinkage source reliability. reliability = 0.5 + (n / (n + prior_strength)) * (observed_win_rate - 0.5) Returns value in [0.0, 1.0]. When n=0, returns 0.5 (prior mean). As nā†’āˆž, approaches observed_win_rate. """ if sample_count <= 0: return 0.5 shrinkage = sample_count / (sample_count + prior_strength) reliability = 0.5 + shrinkage * (observed_win_rate - 0.5) # Clamp to [0.0, 1.0] for safety (should already be in range when # observed_win_rate is in [0.0, 1.0], but guard against edge cases). return max(0.0, min(1.0, reliability)) def compute_adjusted_evidence_weight( base_weight: float, reliability: float, ) -> float: """Adjusted weight = base_weight * (0.5 + reliability), clamped to [0.1, 2.0].""" adjusted = base_weight * (0.5 + reliability) return max(0.1, min(2.0, adjusted)) # --------------------------------------------------------------------------- # SQL queries # --------------------------------------------------------------------------- # Query v_source_performance to get per-source win rates and sample counts. # Groups by source, counting total predictions and directional wins. 
# Rows with a NULL source are excluded: they cannot be keyed into
# source_accuracy. ON CONFLICT never matches NULL by default, so such a row
# would either violate a NOT NULL constraint on every run or insert a fresh
# duplicate row on every run (which one depends on the table definition).
_SOURCE_PERFORMANCE_SQL = """
SELECT
    source,
    COUNT(*) AS sample_count,
    COUNT(*) FILTER (WHERE direction_correct = TRUE) AS win_count
FROM v_source_performance
WHERE direction_correct IS NOT NULL
  AND source IS NOT NULL
GROUP BY source
"""

# Upsert into source_accuracy: update accuracy_ratio and sample_count
# for existing sources, insert new ones.
_UPSERT_SOURCE_ACCURACY_SQL = """
INSERT INTO source_accuracy (source_id, accuracy_ratio, sample_count, last_updated)
VALUES ($1, $2, $3, NOW())
ON CONFLICT (source_id) DO UPDATE SET
    accuracy_ratio = EXCLUDED.accuracy_ratio,
    sample_count = EXCLUDED.sample_count,
    last_updated = NOW()
"""


# ---------------------------------------------------------------------------
# Database-backed function
# ---------------------------------------------------------------------------


async def update_source_reliabilities(
    pool: asyncpg.Pool,
    *,
    prior_strength: int = 30,
) -> int:
    """Recompute and store source reliability scores from latest outcomes.

    1. Queries v_source_performance to get per-source win and sample counts.
    2. Computes Bayesian shrinkage reliability for each source.
    3. Upserts into the source_accuracy table (accuracy_ratio = reliability).

    Errors are logged rather than raised: if the performance query fails the
    function returns 0, and if the upsert for one source fails that source is
    skipped while the remaining sources are still written.

    Args:
        pool: Connection pool used for the read and for each upsert.
        prior_strength: Pseudo-observation weight forwarded to
            ``compute_source_reliability``. The default (30) matches that
            function's own default.

    Returns:
        Number of sources successfully upserted.
    """
    try:
        rows = await pool.fetch(_SOURCE_PERFORMANCE_SQL)
    except Exception:
        logger.exception("Failed to query source performance for reliability update")
        return 0

    if not rows:
        logger.info("No source performance data available for reliability update")
        return 0

    updated = 0
    for row in rows:
        source = row["source"]
        sample_count = row["sample_count"]
        win_count = row["win_count"]

        # Each grouped row has COUNT(*) >= 1; the fallback is purely defensive.
        observed_win_rate = win_count / sample_count if sample_count > 0 else 0.5
        reliability = compute_source_reliability(
            observed_win_rate, sample_count, prior_strength=prior_strength
        )

        try:
            await pool.execute(
                _UPSERT_SOURCE_ACCURACY_SQL,
                source,
                reliability,
                sample_count,
            )
        except Exception:
            # Best-effort per source: one bad row must not block the others.
            logger.exception(
                "Failed to upsert source reliability for source=%s", source
            )
        else:
            updated += 1

    logger.info("Updated source reliabilities for %d sources", updated)
    return updated