"""Source accuracy tracker for historical prediction accuracy per source. Tracks per-source accuracy metrics (fraction of correct directional calls) used by the probabilistic scoring pipeline to weight source credibility. Accuracy data is stored in the ``source_accuracy`` database table and fetched in batch at the start of each aggregation cycle. Requirements: 4.1, 4.2, 4.3, 4.4, 4.5 """ from __future__ import annotations import logging from dataclasses import dataclass from datetime import datetime, timezone import asyncpg logger = logging.getLogger(__name__) @dataclass class SourceAccuracy: """Per-source historical prediction accuracy. Attributes: source_id: Unique identifier for the signal source. accuracy_ratio: Fraction of correct directional calls, in [0, 1]. sample_count: Number of signals with known outcomes. last_updated: Timestamp of the most recent accuracy update. """ source_id: str accuracy_ratio: float sample_count: int last_updated: datetime @property def accuracy_factor(self) -> float: """Multiplicative factor for credibility weight. Returns 1.0 (neutral) when sample_count < 10. Otherwise scales linearly from 0.5 (0% accuracy) to 1.5 (100% accuracy). Corrupted accuracy_ratio values outside [0, 1] are clamped before computing the factor. """ if self.sample_count < 10: return 1.0 clamped = max(0.0, min(1.0, self.accuracy_ratio)) return 0.5 + clamped async def fetch_source_accuracy( pool: asyncpg.Pool, source_ids: list[str], ) -> dict[str, SourceAccuracy]: """Fetch accuracy metrics for a batch of sources. Queries the ``source_accuracy`` table for all requested *source_ids* in a single round-trip. Returns a mapping from source_id to its :class:`SourceAccuracy` record. When the database is unreachable or the query fails, returns an empty dict so that callers fall back to the neutral accuracy factor of 1.0. """ if not source_ids: return {} try: rows = await pool.fetch( """ SELECT source_id, accuracy_ratio, sample_count, last_updated FROM source_accuracy WHERE source_id = ANY($1::varchar[]) """, source_ids, ) except Exception: logger.warning( "Failed to fetch source accuracy; defaulting to neutral factor", exc_info=True, ) return {} result: dict[str, SourceAccuracy] = {} for row in rows: sid = row["source_id"] ratio = row["accuracy_ratio"] # Clamp corrupted accuracy_ratio to [0.0, 1.0] ratio = max(0.0, min(1.0, float(ratio))) result[sid] = SourceAccuracy( source_id=sid, accuracy_ratio=ratio, sample_count=int(row["sample_count"]), last_updated=row["last_updated"], ) return result async def update_source_accuracy( pool: asyncpg.Pool, source_id: str, realized_outcomes: list[tuple[str, float]], ) -> None: """Update accuracy metrics for a source from realized price outcomes. Each element of *realized_outcomes* is a ``(predicted_direction, actual_7d_return)`` pair. A prediction is considered correct when: * ``predicted_direction`` is ``"bullish"`` and ``actual_7d_return > 0`` * ``predicted_direction`` is ``"bearish"`` and ``actual_7d_return < 0`` Neutral predictions and zero returns are excluded from the accuracy calculation. The function upserts the ``source_accuracy`` row, merging the new outcomes with any existing sample count and accuracy ratio. """ if not realized_outcomes: return # Count correct directional calls from the new outcomes. correct = 0 total = 0 for predicted_direction, actual_return in realized_outcomes: direction = predicted_direction.lower() if direction not in ("bullish", "bearish"): continue if actual_return == 0.0: continue total += 1 if direction == "bullish" and actual_return > 0: correct += 1 elif direction == "bearish" and actual_return < 0: correct += 1 if total == 0: return now = datetime.now(timezone.utc) try: await pool.execute( """ INSERT INTO source_accuracy (source_id, accuracy_ratio, sample_count, last_updated) VALUES ($1, $2, $3, $4) ON CONFLICT (source_id) DO UPDATE SET accuracy_ratio = ( source_accuracy.accuracy_ratio * source_accuracy.sample_count + $2 * $3 ) / NULLIF(source_accuracy.sample_count + $3, 0), sample_count = source_accuracy.sample_count + $3, last_updated = $4 """, source_id, correct / total, total, now, ) except Exception: logger.warning( "Failed to update source accuracy for %s; continuing with stale data", source_id, exc_info=True, )