"""Model performance metrics collection and persistence. Tracks extraction success/failure rates, latency percentiles, retry counts, validation error distributions, confidence scores, and token usage estimates. Metrics are persisted to PostgreSQL for operational dashboards and published to the analytical lake for Trino/Superset queries. Requirements: 5.2, 5.4, 12.1, 12.2 """ from __future__ import annotations import json import logging from dataclasses import dataclass, field from datetime import datetime, timezone import asyncpg from services.extractor.client import ExtractionResponse logger = logging.getLogger("extractor_metrics") # Rough token estimate: ~4 chars per token for English text _CHARS_PER_TOKEN = 4 @dataclass class ExtractionMetrics: """Metrics extracted from a single extraction run.""" document_id: str = "" ticker: str = "" model_name: str = "" prompt_version: str = "" schema_version: str = "" success: bool = False attempt_count: int = 0 total_duration_ms: int = 0 first_attempt_duration_ms: int = 0 final_attempt_duration_ms: int = 0 confidence: float = 0.0 validation_status: str = "unknown" validation_error_count: int = 0 validation_warning_count: int = 0 validation_errors: list[str] = field(default_factory=list) retry_count: int = 0 input_token_estimate: int = 0 output_token_estimate: int = 0 company_count: int = 0 recorded_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) def collect_metrics( extraction_response: ExtractionResponse, *, document_id: str = "", ticker: str = "", document_text_length: int = 0, ) -> ExtractionMetrics: """Collect metrics from an ExtractionResponse. Args: extraction_response: The full response from OllamaClient.extract(). document_id: UUID of the source document. ticker: Primary ticker symbol. document_text_length: Length of the input document text in characters. Returns: An ExtractionMetrics dataclass with all computed fields. """ attempts = extraction_response.attempts first_dur = attempts[0].duration_ms if attempts else 0 final_dur = attempts[-1].duration_ms if attempts else 0 # Gather validation info from the final attempt final_attempt = attempts[-1] if attempts else None val_errors: list[str] = [] val_warnings: list[str] = [] if final_attempt and final_attempt.validation: val_errors = final_attempt.validation.errors val_warnings = final_attempt.validation.warnings # Determine validation status if extraction_response.success: validation_status = "valid" elif attempts: validation_status = "failed" else: validation_status = "unknown" # Confidence from the result, or 0 if failed confidence = 0.0 company_count = 0 if extraction_response.result: confidence = extraction_response.result.confidence company_count = len(extraction_response.result.companies) # Token estimates input_tokens = document_text_length // _CHARS_PER_TOKEN if document_text_length > 0 else 0 output_tokens = 0 if final_attempt and final_attempt.raw_output: output_tokens = len(final_attempt.raw_output) // _CHARS_PER_TOKEN return ExtractionMetrics( document_id=document_id, ticker=ticker, model_name=extraction_response.model, prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""), schema_version=extraction_response.prompt_metadata.get("schema_version", ""), success=extraction_response.success, attempt_count=len(attempts), total_duration_ms=extraction_response.total_duration_ms, first_attempt_duration_ms=first_dur, final_attempt_duration_ms=final_dur, confidence=confidence, validation_status=validation_status, validation_error_count=len(val_errors), validation_warning_count=len(val_warnings), validation_errors=val_errors, retry_count=max(0, len(attempts) - 1), input_token_estimate=input_tokens, output_token_estimate=output_tokens, company_count=company_count, ) async def persist_metrics( pool: asyncpg.Pool, metrics: ExtractionMetrics, ) -> str: """Persist extraction metrics to the model_performance_metrics table. Args: pool: PostgreSQL connection pool. metrics: Collected metrics from an extraction run. Returns: The UUID of the inserted metrics row. """ row_id = await pool.fetchval( """INSERT INTO model_performance_metrics (document_id, ticker, model_name, prompt_version, schema_version, success, attempt_count, total_duration_ms, first_attempt_duration_ms, final_attempt_duration_ms, confidence, validation_status, validation_error_count, validation_warning_count, validation_errors, retry_count, input_token_estimate, output_token_estimate, company_count, recorded_at) VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15::jsonb, $16, $17, $18, $19, $20) RETURNING id""", metrics.document_id, metrics.ticker, metrics.model_name, metrics.prompt_version, metrics.schema_version, metrics.success, metrics.attempt_count, metrics.total_duration_ms, metrics.first_attempt_duration_ms, metrics.final_attempt_duration_ms, metrics.confidence, metrics.validation_status, metrics.validation_error_count, metrics.validation_warning_count, json.dumps(metrics.validation_errors), metrics.retry_count, metrics.input_token_estimate, metrics.output_token_estimate, metrics.company_count, metrics.recorded_at, ) logger.info( "Persisted extraction metrics %s for doc %s: success=%s duration=%dms retries=%d", row_id, metrics.document_id, metrics.success, metrics.total_duration_ms, metrics.retry_count, ) return str(row_id) async def get_model_performance_summary( pool: asyncpg.Pool, *, model_name: str | None = None, hours: int = 24, ) -> dict[str, object]: """Query aggregated model performance metrics for dashboards. Returns a summary dict with success rate, avg latency, retry rate, confidence distribution, and error breakdown for the given time window. Args: pool: PostgreSQL connection pool. model_name: Optional filter by model name. hours: Lookback window in hours (default 24). Returns: Dict with aggregated performance metrics. """ model_filter = "AND model_name = $2" if model_name else "" params: list[object] = [hours] if model_name: params.append(model_name) row = await pool.fetchrow( f"""SELECT COUNT(*) AS total_extractions, COUNT(*) FILTER (WHERE success) AS successful, COUNT(*) FILTER (WHERE NOT success) AS failed, ROUND(AVG(total_duration_ms)::numeric, 1) AS avg_duration_ms, ROUND(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p50_duration_ms, ROUND(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p95_duration_ms, ROUND(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p99_duration_ms, ROUND(AVG(retry_count)::numeric, 2) AS avg_retries, ROUND(AVG(confidence)::numeric, 3) AS avg_confidence, SUM(input_token_estimate) AS total_input_tokens, SUM(output_token_estimate) AS total_output_tokens, ROUND(AVG(company_count)::numeric, 2) AS avg_companies_per_doc, ROUND(AVG(validation_error_count)::numeric, 2) AS avg_validation_errors, ROUND(AVG(validation_warning_count)::numeric, 2) AS avg_validation_warnings FROM model_performance_metrics WHERE recorded_at >= NOW() - INTERVAL '1 hour' * $1 {model_filter}""", *params, ) if not row or row["total_extractions"] == 0: return {"total_extractions": 0, "success_rate": 0.0} total = row["total_extractions"] successful = row["successful"] return { "total_extractions": total, "successful": successful, "failed": row["failed"], "success_rate": round(successful / total, 4) if total > 0 else 0.0, "avg_duration_ms": float(row["avg_duration_ms"] or 0), "p50_duration_ms": float(row["p50_duration_ms"] or 0), "p95_duration_ms": float(row["p95_duration_ms"] or 0), "p99_duration_ms": float(row["p99_duration_ms"] or 0), "avg_retries": float(row["avg_retries"] or 0), "avg_confidence": float(row["avg_confidence"] or 0), "total_input_tokens": int(row["total_input_tokens"] or 0), "total_output_tokens": int(row["total_output_tokens"] or 0), "avg_companies_per_doc": float(row["avg_companies_per_doc"] or 0), "avg_validation_errors": float(row["avg_validation_errors"] or 0), "avg_validation_warnings": float(row["avg_validation_warnings"] or 0), "hours": hours, }