251 lines
9.3 KiB
Python
251 lines
9.3 KiB
Python
"""Model performance metrics collection and persistence.
|
|
|
|
Tracks extraction success/failure rates, latency percentiles, retry counts,
|
|
validation error distributions, confidence scores, and token usage estimates.
|
|
Metrics are persisted to PostgreSQL for operational dashboards and published
|
|
to the analytical lake for Trino/Superset queries.
|
|
|
|
Requirements: 5.2, 5.4, 12.1, 12.2
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
|
|
import asyncpg
|
|
|
|
from services.extractor.client import ExtractionResponse
|
|
|
|
# Logger registered under a fixed name (not this module's __name__).
logger = logging.getLogger(name="extractor_metrics")

# Heuristic divisor used for token estimates: roughly four characters of
# English text per token.
_CHARS_PER_TOKEN = 4
|
|
|
|
|
|
@dataclass
class ExtractionMetrics:
    """Flat record of the measurements taken from a single extraction run.

    Every field has a default, so instances can be built with keyword
    arguments only. ``validation_errors`` gets a fresh list per instance and
    ``recorded_at`` defaults to the current time in UTC.
    """

    # Identity of the run: source document, ticker, model and prompt/schema versions.
    document_id: str = ""
    ticker: str = ""
    model_name: str = ""
    prompt_version: str = ""
    schema_version: str = ""

    # Outcome and timing (durations in milliseconds).
    success: bool = False
    attempt_count: int = 0
    total_duration_ms: int = 0
    first_attempt_duration_ms: int = 0
    final_attempt_duration_ms: int = 0

    # Result quality and validation outcome.
    confidence: float = 0.0
    validation_status: str = "unknown"
    validation_error_count: int = 0
    validation_warning_count: int = 0
    validation_errors: list[str] = field(default_factory=list)

    # Retries and size estimates.
    retry_count: int = 0
    input_token_estimate: int = 0
    output_token_estimate: int = 0
    company_count: int = 0

    # Creation timestamp (timezone-aware, UTC).
    recorded_at: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
|
|
|
|
|
|
def collect_metrics(
    extraction_response: ExtractionResponse,
    *,
    document_id: str = "",
    ticker: str = "",
    document_text_length: int = 0,
) -> ExtractionMetrics:
    """Collect metrics from an ExtractionResponse.

    Args:
        extraction_response: The full response from OllamaClient.extract().
        document_id: UUID of the source document.
        ticker: Primary ticker symbol.
        document_text_length: Length of the input document text in characters.
            Non-positive values yield an input token estimate of 0.

    Returns:
        An ExtractionMetrics dataclass with all computed fields.

        - Durations, validation details and the output token estimate are
          taken from the first/final entries of ``extraction_response.attempts``.
        - When there are no attempts, these default to zero/empty and the
          validation status is ``"unknown"``.
        - ``retry_count`` is the number of attempts beyond the first.
    """
    attempts = extraction_response.attempts
    first_attempt = attempts[0] if attempts else None
    final_attempt = attempts[-1] if attempts else None

    first_dur = first_attempt.duration_ms if first_attempt is not None else 0
    final_dur = final_attempt.duration_ms if final_attempt is not None else 0

    # Validation details come from the final attempt only. The lists are
    # copied so the metrics record does not alias the attempt's own
    # validation lists (mutating one would otherwise change the other).
    # A None list is treated as empty.
    val_errors: list[str] = []
    val_warnings: list[str] = []
    if final_attempt and final_attempt.validation:
        val_errors = list(final_attempt.validation.errors or [])
        val_warnings = list(final_attempt.validation.warnings or [])

    # "valid" on success, "failed" if attempts were made but none succeeded,
    # "unknown" if nothing was attempted at all.
    if extraction_response.success:
        validation_status = "valid"
    elif attempts:
        validation_status = "failed"
    else:
        validation_status = "unknown"

    # Confidence and company count come from the result; both stay 0 when
    # there is no result.
    confidence = 0.0
    company_count = 0
    result = extraction_response.result
    if result:
        confidence = result.confidence
        company_count = len(result.companies)

    # Character-based token estimates (floor division by _CHARS_PER_TOKEN).
    input_tokens = max(document_text_length, 0) // _CHARS_PER_TOKEN
    output_tokens = 0
    if final_attempt and final_attempt.raw_output:
        output_tokens = len(final_attempt.raw_output) // _CHARS_PER_TOKEN

    # Tolerate a None prompt_metadata instead of raising AttributeError.
    prompt_metadata = extraction_response.prompt_metadata or {}

    return ExtractionMetrics(
        document_id=document_id,
        ticker=ticker,
        model_name=extraction_response.model,
        prompt_version=prompt_metadata.get("prompt_version", ""),
        schema_version=prompt_metadata.get("schema_version", ""),
        success=extraction_response.success,
        attempt_count=len(attempts),
        total_duration_ms=extraction_response.total_duration_ms,
        first_attempt_duration_ms=first_dur,
        final_attempt_duration_ms=final_dur,
        confidence=confidence,
        validation_status=validation_status,
        validation_error_count=len(val_errors),
        validation_warning_count=len(val_warnings),
        validation_errors=val_errors,
        retry_count=max(0, len(attempts) - 1),
        input_token_estimate=input_tokens,
        output_token_estimate=output_tokens,
        company_count=company_count,
    )
|
|
|
|
|
|
async def persist_metrics(
    pool: asyncpg.Pool,
    metrics: ExtractionMetrics,
) -> str:
    """Persist extraction metrics to the model_performance_metrics table.

    Args:
        pool: PostgreSQL connection pool.
        metrics: Collected metrics from an extraction run. An empty
            ``document_id`` is stored as NULL rather than cast to uuid.

    Returns:
        The UUID of the inserted metrics row (the ``RETURNING id`` value),
        as a string.
    """
    # ``''::uuid`` is always a PostgreSQL "invalid input syntax" error, and ""
    # is the default document_id. Send NULL instead of failing on the cast.
    # Whether NULL is accepted depends on the column definition.
    document_id = metrics.document_id or None

    row_id = await pool.fetchval(
        """INSERT INTO model_performance_metrics
            (document_id, ticker, model_name, prompt_version, schema_version,
             success, attempt_count, total_duration_ms,
             first_attempt_duration_ms, final_attempt_duration_ms,
             confidence, validation_status, validation_error_count,
             validation_warning_count, validation_errors, retry_count,
             input_token_estimate, output_token_estimate, company_count,
             recorded_at)
        VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10,
                $11, $12, $13, $14, $15::jsonb, $16, $17, $18, $19, $20)
        RETURNING id""",
        document_id,
        metrics.ticker,
        metrics.model_name,
        metrics.prompt_version,
        metrics.schema_version,
        metrics.success,
        metrics.attempt_count,
        metrics.total_duration_ms,
        metrics.first_attempt_duration_ms,
        metrics.final_attempt_duration_ms,
        metrics.confidence,
        metrics.validation_status,
        metrics.validation_error_count,
        metrics.validation_warning_count,
        # The jsonb parameter is sent as JSON text.
        json.dumps(metrics.validation_errors),
        metrics.retry_count,
        metrics.input_token_estimate,
        metrics.output_token_estimate,
        metrics.company_count,
        metrics.recorded_at,
    )
    logger.info(
        "Persisted extraction metrics %s for doc %s: success=%s duration=%dms retries=%d",
        row_id, metrics.document_id, metrics.success,
        metrics.total_duration_ms, metrics.retry_count,
    )
    return str(row_id)
|
|
|
|
|
|
def _empty_summary(hours: int) -> dict[str, object]:
    """Return a zeroed summary with the same keys as a populated one."""
    return {
        "total_extractions": 0,
        "successful": 0,
        "failed": 0,
        "success_rate": 0.0,
        "avg_duration_ms": 0.0,
        "p50_duration_ms": 0.0,
        "p95_duration_ms": 0.0,
        "p99_duration_ms": 0.0,
        "avg_retries": 0.0,
        "avg_confidence": 0.0,
        "total_input_tokens": 0,
        "total_output_tokens": 0,
        "avg_companies_per_doc": 0.0,
        "avg_validation_errors": 0.0,
        "avg_validation_warnings": 0.0,
        "hours": hours,
    }


async def get_model_performance_summary(
    pool: asyncpg.Pool,
    *,
    model_name: str | None = None,
    hours: int = 24,
) -> dict[str, object]:
    """Query aggregated model performance metrics for dashboards.

    The summary covers rows whose ``recorded_at`` falls within the last
    ``hours`` hours and contains:

    - counts of total, successful and failed extractions, plus the success rate;
    - average and p50/p95/p99 total duration;
    - average retries and average confidence;
    - total input/output token estimates;
    - averages of companies per document and of validation errors/warnings.

    The same keys are returned when the window contains no rows, with all
    values zeroed.

    Args:
        pool: PostgreSQL connection pool.
        model_name: Optional filter by model name (empty/None means all models).
        hours: Lookback window in hours (default 24).

    Returns:
        Dict with aggregated performance metrics, including the ``hours``
        window that was queried.
    """
    params: list[object] = [hours]
    model_filter = ""
    if model_name:
        params.append(model_name)
        # Only a constant fragment is interpolated; the value itself is
        # passed as bind parameter $2.
        model_filter = "AND model_name = $2"

    row = await pool.fetchrow(
        f"""SELECT
            COUNT(*) AS total_extractions,
            COUNT(*) FILTER (WHERE success) AS successful,
            COUNT(*) FILTER (WHERE NOT success) AS failed,
            ROUND(AVG(total_duration_ms)::numeric, 1) AS avg_duration_ms,
            ROUND(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p50_duration_ms,
            ROUND(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p95_duration_ms,
            ROUND(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p99_duration_ms,
            ROUND(AVG(retry_count)::numeric, 2) AS avg_retries,
            ROUND(AVG(confidence)::numeric, 3) AS avg_confidence,
            SUM(input_token_estimate) AS total_input_tokens,
            SUM(output_token_estimate) AS total_output_tokens,
            ROUND(AVG(company_count)::numeric, 2) AS avg_companies_per_doc,
            ROUND(AVG(validation_error_count)::numeric, 2) AS avg_validation_errors,
            ROUND(AVG(validation_warning_count)::numeric, 2) AS avg_validation_warnings
        FROM model_performance_metrics
        WHERE recorded_at >= NOW() - INTERVAL '1 hour' * $1
        {model_filter}""",
        *params,
    )

    if not row or row["total_extractions"] == 0:
        # Keep the same shape as a populated summary so callers can index
        # any key without special-casing an empty window.
        return _empty_summary(hours)

    total = row["total_extractions"]
    successful = row["successful"]

    # Aggregates over NUMERIC columns come back as Decimal (or None);
    # normalize them to plain floats/ints for JSON-friendly dashboards.
    return {
        "total_extractions": total,
        "successful": successful,
        "failed": row["failed"],
        "success_rate": round(successful / total, 4),
        "avg_duration_ms": float(row["avg_duration_ms"] or 0),
        "p50_duration_ms": float(row["p50_duration_ms"] or 0),
        "p95_duration_ms": float(row["p95_duration_ms"] or 0),
        "p99_duration_ms": float(row["p99_duration_ms"] or 0),
        "avg_retries": float(row["avg_retries"] or 0),
        "avg_confidence": float(row["avg_confidence"] or 0),
        "total_input_tokens": int(row["total_input_tokens"] or 0),
        "total_output_tokens": int(row["total_output_tokens"] or 0),
        "avg_companies_per_doc": float(row["avg_companies_per_doc"] or 0),
        "avg_validation_errors": float(row["avg_validation_errors"] or 0),
        "avg_validation_warnings": float(row["avg_validation_warnings"] or 0),
        "hours": hours,
    }
|