Files
stonks-oracle/services/extractor/metrics.py
T

251 lines
9.3 KiB
Python

"""Model performance metrics collection and persistence.
Tracks extraction success/failure rates, latency percentiles, retry counts,
validation error distributions, confidence scores, and token usage estimates.
Metrics are persisted to PostgreSQL for operational dashboards and published
to the analytical lake for Trino/Superset queries.
Requirements: 5.2, 5.4, 12.1, 12.2
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
import asyncpg
from services.extractor.client import ExtractionResponse
# Module-level logger. The name is the fixed string "extractor_metrics",
# not derived from __name__.
logger = logging.getLogger("extractor_metrics")
# Rough token estimate: ~4 chars per token for English text.
# Used as an integer divisor when converting character counts to token counts.
_CHARS_PER_TOKEN = 4
def _utc_now() -> datetime:
    """Return the current moment as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


@dataclass
class ExtractionMetrics:
    """Flat record of the measurements taken for one extraction run.

    Every field carries a default, so the class can be instantiated with no
    arguments; ``recorded_at`` is stamped with the current UTC time when the
    instance is created.
    """

    # --- identity of the run ---
    document_id: str = ""
    ticker: str = ""
    model_name: str = ""
    prompt_version: str = ""
    schema_version: str = ""

    # --- outcome and timing (durations in milliseconds) ---
    success: bool = False
    attempt_count: int = 0
    total_duration_ms: int = 0
    first_attempt_duration_ms: int = 0
    final_attempt_duration_ms: int = 0

    # --- quality / validation ---
    confidence: float = 0.0
    validation_status: str = "unknown"
    validation_error_count: int = 0
    validation_warning_count: int = 0
    # Fresh list per instance; a shared mutable default would leak between runs.
    validation_errors: list[str] = field(default_factory=list)
    retry_count: int = 0

    # --- size estimates ---
    input_token_estimate: int = 0
    output_token_estimate: int = 0
    company_count: int = 0

    # --- bookkeeping ---
    recorded_at: datetime = field(default_factory=_utc_now)
def collect_metrics(
    extraction_response: ExtractionResponse,
    *,
    document_id: str = "",
    ticker: str = "",
    document_text_length: int = 0,
) -> ExtractionMetrics:
    """Summarise one extraction response as an ExtractionMetrics record.

    Args:
        extraction_response: The full response from OllamaClient.extract().
        document_id: UUID of the source document.
        ticker: Primary ticker symbol.
        document_text_length: Length of the input document text in characters.

    Returns:
        An ExtractionMetrics instance populated from the response; validation
        details and the output token estimate are taken from the last attempt.
    """
    runs = extraction_response.attempts

    # Timing of the first and last attempt; both are 0 when nothing was attempted.
    if runs:
        last_run = runs[-1]
        first_ms = runs[0].duration_ms
        final_ms = last_run.duration_ms
    else:
        last_run = None
        first_ms = final_ms = 0

    # Validation errors and warnings are read from the last attempt only.
    errors: list[str] = []
    warnings: list[str] = []
    outcome = last_run.validation if last_run else None
    if outcome:
        errors = outcome.errors
        warnings = outcome.warnings

    # "valid" on success, "failed" when attempts exist but none succeeded,
    # "unknown" when there were no attempts at all.
    if extraction_response.success:
        status = "valid"
    else:
        status = "failed" if runs else "unknown"

    # Confidence and company count come from the parsed result when present.
    parsed = extraction_response.result
    if parsed:
        score = parsed.confidence
        n_companies = len(parsed.companies)
    else:
        score = 0.0
        n_companies = 0

    # Character counts are converted to token estimates by integer division.
    tokens_in = 0
    if document_text_length > 0:
        tokens_in = document_text_length // _CHARS_PER_TOKEN
    tokens_out = (
        len(last_run.raw_output) // _CHARS_PER_TOKEN
        if last_run and last_run.raw_output
        else 0
    )

    meta = extraction_response.prompt_metadata
    attempt_total = len(runs)
    return ExtractionMetrics(
        document_id=document_id,
        ticker=ticker,
        model_name=extraction_response.model,
        prompt_version=meta.get("prompt_version", ""),
        schema_version=meta.get("schema_version", ""),
        success=extraction_response.success,
        attempt_count=attempt_total,
        total_duration_ms=extraction_response.total_duration_ms,
        first_attempt_duration_ms=first_ms,
        final_attempt_duration_ms=final_ms,
        confidence=score,
        validation_status=status,
        validation_error_count=len(errors),
        validation_warning_count=len(warnings),
        validation_errors=errors,
        # Every attempt after the first counts as a retry.
        retry_count=max(0, attempt_total - 1),
        input_token_estimate=tokens_in,
        output_token_estimate=tokens_out,
        company_count=n_companies,
    )
async def persist_metrics(
    pool: asyncpg.Pool,
    metrics: ExtractionMetrics,
) -> str:
    """Persist extraction metrics to the model_performance_metrics table.

    Args:
        pool: PostgreSQL connection pool.
        metrics: Collected metrics from an extraction run.

    Returns:
        The id of the inserted metrics row, converted to ``str``.

    Note:
        ``ExtractionMetrics.document_id`` defaults to ``""``, which can never
        be encoded for the ``$1::uuid`` parameter. An empty id is therefore
        bound as SQL NULL instead of failing the whole insert.
        NOTE(review): this assumes ``model_performance_metrics.document_id``
        is nullable — confirm against the table definition.
    """
    # "" (the dataclass default) is not a valid UUID; send NULL in that case.
    document_id = metrics.document_id or None

    row_id = await pool.fetchval(
        """INSERT INTO model_performance_metrics
               (document_id, ticker, model_name, prompt_version, schema_version,
                success, attempt_count, total_duration_ms,
                first_attempt_duration_ms, final_attempt_duration_ms,
                confidence, validation_status, validation_error_count,
                validation_warning_count, validation_errors, retry_count,
                input_token_estimate, output_token_estimate, company_count,
                recorded_at)
           VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10,
                   $11, $12, $13, $14, $15::jsonb, $16, $17, $18, $19, $20)
           RETURNING id""",
        document_id,
        metrics.ticker,
        metrics.model_name,
        metrics.prompt_version,
        metrics.schema_version,
        metrics.success,
        metrics.attempt_count,
        metrics.total_duration_ms,
        metrics.first_attempt_duration_ms,
        metrics.final_attempt_duration_ms,
        metrics.confidence,
        metrics.validation_status,
        metrics.validation_error_count,
        metrics.validation_warning_count,
        # The error list is sent as JSON text and cast to jsonb by the query.
        json.dumps(metrics.validation_errors),
        metrics.retry_count,
        metrics.input_token_estimate,
        metrics.output_token_estimate,
        metrics.company_count,
        metrics.recorded_at,
    )
    logger.info(
        "Persisted extraction metrics %s for doc %s: success=%s duration=%dms retries=%d",
        row_id, metrics.document_id, metrics.success,
        metrics.total_duration_ms, metrics.retry_count,
    )
    return str(row_id)
async def get_model_performance_summary(
    pool: asyncpg.Pool,
    *,
    model_name: str | None = None,
    hours: int = 24,
) -> dict[str, object]:
    """Query aggregated model performance metrics for dashboards.

    Runs one aggregate query over ``model_performance_metrics`` restricted to
    rows whose ``recorded_at`` falls within the last ``hours`` hours, and
    optionally to a single model.

    Args:
        pool: PostgreSQL connection pool.
        model_name: Optional filter by model name. A falsy value (``None`` or
            ``""``) disables the filter.
        hours: Lookback window in hours (default 24).

    Returns:
        Dict with counts, success rate, latency average and percentiles,
        average retries and confidence, token totals, per-document averages,
        and the ``hours`` window used. The same keys are returned whether or
        not any rows matched; with no matching rows every metric is zero, so
        consumers can index any key without special-casing an empty window.
    """
    # The filter fragment is a fixed string; the model name itself is always
    # passed as a bound parameter, never interpolated into the SQL.
    model_filter = "AND model_name = $2" if model_name else ""
    params: list[object] = [hours]
    if model_name:
        params.append(model_name)

    row = await pool.fetchrow(
        f"""SELECT
                COUNT(*) AS total_extractions,
                COUNT(*) FILTER (WHERE success) AS successful,
                COUNT(*) FILTER (WHERE NOT success) AS failed,
                ROUND(AVG(total_duration_ms)::numeric, 1) AS avg_duration_ms,
                ROUND(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p50_duration_ms,
                ROUND(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p95_duration_ms,
                ROUND(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p99_duration_ms,
                ROUND(AVG(retry_count)::numeric, 2) AS avg_retries,
                ROUND(AVG(confidence)::numeric, 3) AS avg_confidence,
                SUM(input_token_estimate) AS total_input_tokens,
                SUM(output_token_estimate) AS total_output_tokens,
                ROUND(AVG(company_count)::numeric, 2) AS avg_companies_per_doc,
                ROUND(AVG(validation_error_count)::numeric, 2) AS avg_validation_errors,
                ROUND(AVG(validation_warning_count)::numeric, 2) AS avg_validation_warnings
            FROM model_performance_metrics
            WHERE recorded_at >= NOW() - INTERVAL '1 hour' * $1
            {model_filter}""",
        *params,
    )

    # No row, or a row that counted zero extractions, means an empty window.
    empty = not row or row["total_extractions"] == 0
    total = 0 if empty else row["total_extractions"]
    successful = 0 if empty else row["successful"]
    failed = 0 if empty else row["failed"]

    def _value(column: str, cast: type) -> object:
        # Aggregates come back as NULL/None when nothing was aggregated;
        # report those (and the empty window) as a zero of the right type.
        return cast(0) if empty else cast(row[column] or 0)

    return {
        "total_extractions": total,
        "successful": successful,
        "failed": failed,
        # total is non-zero whenever the window is not empty.
        "success_rate": 0.0 if empty else round(successful / total, 4),
        "avg_duration_ms": _value("avg_duration_ms", float),
        "p50_duration_ms": _value("p50_duration_ms", float),
        "p95_duration_ms": _value("p95_duration_ms", float),
        "p99_duration_ms": _value("p99_duration_ms", float),
        "avg_retries": _value("avg_retries", float),
        "avg_confidence": _value("avg_confidence", float),
        "total_input_tokens": _value("total_input_tokens", int),
        "total_output_tokens": _value("total_output_tokens", int),
        "avg_companies_per_doc": _value("avg_companies_per_doc", float),
        "avg_validation_errors": _value("avg_validation_errors", float),
        "avg_validation_warnings": _value("avg_validation_warnings", float),
        "hours": hours,
    }