phase 14-15: docker build validation and helm deployment
This commit is contained in:
@@ -0,0 +1,268 @@
|
||||
"""Ollama client wrapper using structured output format.
|
||||
|
||||
Sends documents to a local Ollama instance via the /api/chat endpoint
|
||||
with the ``format`` parameter set to the extraction JSON schema, ensuring
|
||||
the model returns schema-compliant JSON.
|
||||
|
||||
Includes retry logic for invalid or incomplete model responses with
|
||||
exponential backoff, error classification, and full audit preservation.
|
||||
|
||||
Requirements: 5.1, 5.2, 5.4
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
|
||||
from services.extractor.prompts import (
|
||||
build_extraction_prompt,
|
||||
get_json_schema,
|
||||
get_prompt_metadata,
|
||||
)
|
||||
from services.extractor.schemas import ExtractionResult, ValidationReport, validate_extraction
|
||||
from services.shared.config import OllamaConfig
|
||||
|
||||
logger = logging.getLogger("ollama_client")
|
||||
|
||||
# Errors that should NOT be retried — the request itself is bad.
|
||||
_NON_RETRYABLE_ERRORS = frozenset({
|
||||
"http_400",
|
||||
"http_401",
|
||||
"http_403",
|
||||
"http_404",
|
||||
"http_422",
|
||||
})
|
||||
|
||||
|
||||
def _is_retryable(error: str | None) -> bool:
|
||||
"""Determine whether an extraction error warrants a retry."""
|
||||
if error is None:
|
||||
return False
|
||||
return error not in _NON_RETRYABLE_ERRORS
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionAttempt:
|
||||
"""Record of a single extraction attempt for audit."""
|
||||
|
||||
raw_output: str = ""
|
||||
validation: ValidationReport | None = None
|
||||
error: str | None = None
|
||||
duration_ms: int = 0
|
||||
model: str = ""
|
||||
retryable: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionResponse:
|
||||
"""Full response from an extraction call, including all attempts."""
|
||||
|
||||
success: bool = False
|
||||
result: ExtractionResult | None = None
|
||||
attempts: list[ExtractionAttempt] = field(default_factory=list)
|
||||
prompt_metadata: dict[str, str] = field(default_factory=dict)
|
||||
model: str = ""
|
||||
total_duration_ms: int = 0
|
||||
|
||||
|
||||
def _compute_backoff(
|
||||
attempt_num: int,
|
||||
base_delay: float,
|
||||
max_delay: float,
|
||||
multiplier: float,
|
||||
) -> float:
|
||||
"""Compute exponential backoff delay for a given attempt number."""
|
||||
delay = base_delay * (multiplier ** attempt_num)
|
||||
return min(delay, max_delay)
|
||||
|
||||
|
||||
class OllamaClient:
|
||||
"""Async client for Ollama structured extraction.
|
||||
|
||||
Usage::
|
||||
|
||||
config = OllamaConfig(base_url="http://localhost:11434", model="llama3.1:8b")
|
||||
client = OllamaClient(config)
|
||||
response = await client.extract(
|
||||
document_text="Apple reported record earnings...",
|
||||
document_type="article",
|
||||
document_id="abc-123",
|
||||
)
|
||||
if response.success:
|
||||
print(response.result)
|
||||
"""
|
||||
|
||||
_config: OllamaConfig
|
||||
_max_retries: int
|
||||
_base_delay: float
|
||||
_max_delay: float
|
||||
_backoff_multiplier: float
|
||||
_owns_client: bool
|
||||
_http: httpx.AsyncClient
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OllamaConfig,
|
||||
max_retries: int | None = None,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self._config = config
|
||||
self._max_retries = max_retries if max_retries is not None else config.max_retries
|
||||
self._base_delay = config.retry_base_delay
|
||||
self._max_delay = config.retry_max_delay
|
||||
self._backoff_multiplier = config.retry_backoff_multiplier
|
||||
self._owns_client = http_client is None
|
||||
self._http = http_client or httpx.AsyncClient(timeout=config.timeout)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying HTTP client if we own it."""
|
||||
if self._owns_client:
|
||||
await self._http.aclose()
|
||||
|
||||
async def extract(
|
||||
self,
|
||||
document_text: str,
|
||||
document_type: str = "article",
|
||||
document_id: str = "",
|
||||
known_tickers: list[str] | None = None,
|
||||
) -> ExtractionResponse:
|
||||
"""Send a document to Ollama for structured intelligence extraction.
|
||||
|
||||
Retries up to ``max_retries`` times when the model returns invalid
|
||||
or incomplete JSON. Uses exponential backoff between retries.
|
||||
Non-retryable errors (e.g. HTTP 400) stop retries immediately.
|
||||
Each attempt and its validation result are preserved for audit.
|
||||
|
||||
Args:
|
||||
document_text: Normalized text content of the document.
|
||||
document_type: One of article, filing, transcript, press_release.
|
||||
document_id: Optional document ID for traceability.
|
||||
known_tickers: Optional ticker hints for the model.
|
||||
|
||||
Returns:
|
||||
An ``ExtractionResponse`` with the parsed result on success.
|
||||
"""
|
||||
prompts = build_extraction_prompt(
|
||||
document_text=document_text,
|
||||
document_type=document_type,
|
||||
document_id=document_id,
|
||||
known_tickers=known_tickers,
|
||||
)
|
||||
json_schema = get_json_schema()
|
||||
prompt_meta = get_prompt_metadata()
|
||||
|
||||
response = ExtractionResponse(
|
||||
prompt_metadata=prompt_meta,
|
||||
model=self._config.model,
|
||||
)
|
||||
|
||||
total_start = time.monotonic()
|
||||
|
||||
for attempt_num in range(self._max_retries + 1):
|
||||
attempt = await self._call_ollama(prompts, json_schema, document_text)
|
||||
response.attempts.append(attempt)
|
||||
|
||||
if attempt.error is None and attempt.validation and attempt.validation.valid:
|
||||
response.success = True
|
||||
response.result = attempt.validation.parsed
|
||||
break
|
||||
|
||||
# Check if the error is non-retryable — stop immediately
|
||||
if not _is_retryable(attempt.error):
|
||||
attempt.retryable = False
|
||||
logger.warning(
|
||||
"Non-retryable error for doc %s: %s — stopping retries",
|
||||
document_id or "unknown",
|
||||
attempt.error,
|
||||
)
|
||||
break
|
||||
|
||||
if attempt_num < self._max_retries:
|
||||
delay = _compute_backoff(
|
||||
attempt_num,
|
||||
self._base_delay,
|
||||
self._max_delay,
|
||||
self._backoff_multiplier,
|
||||
)
|
||||
logger.warning(
|
||||
"Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
|
||||
attempt_num + 1,
|
||||
self._max_retries + 1,
|
||||
document_id or "unknown",
|
||||
attempt.error or "validation failed",
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
response.total_duration_ms = int((time.monotonic() - total_start) * 1000)
|
||||
return response
|
||||
|
||||
async def _call_ollama(
|
||||
self,
|
||||
prompts: dict[str, str],
|
||||
json_schema: dict[str, object],
|
||||
document_text: str = "",
|
||||
) -> ExtractionAttempt:
|
||||
"""Make a single call to the Ollama /api/chat endpoint."""
|
||||
attempt = ExtractionAttempt(model=self._config.model)
|
||||
start = time.monotonic()
|
||||
|
||||
payload = {
|
||||
"model": self._config.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": prompts["system"]},
|
||||
{"role": "user", "content": prompts["user"]},
|
||||
],
|
||||
"format": json_schema,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = await self._http.post(
|
||||
f"{self._config.base_url}/api/chat",
|
||||
json=payload,
|
||||
)
|
||||
_ = resp.raise_for_status()
|
||||
except httpx.TimeoutException:
|
||||
attempt.error = "timeout"
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
except httpx.HTTPStatusError as exc:
|
||||
attempt.error = f"http_{exc.response.status_code}"
|
||||
attempt.retryable = _is_retryable(attempt.error)
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
except httpx.HTTPError as exc:
|
||||
attempt.error = f"connection_error: {exc}"
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
# Parse the Ollama response envelope
|
||||
try:
|
||||
body: dict[str, object] = resp.json()
|
||||
except json.JSONDecodeError:
|
||||
attempt.error = "invalid_response_json"
|
||||
attempt.raw_output = resp.text
|
||||
return attempt
|
||||
|
||||
msg = body.get("message")
|
||||
content: str = msg.get("content", "") if isinstance(msg, dict) else ""
|
||||
attempt.raw_output = content
|
||||
|
||||
if not content:
|
||||
attempt.error = "empty_model_response"
|
||||
return attempt
|
||||
|
||||
# Validate against extraction schema
|
||||
attempt.validation = validate_extraction(content, document_text=document_text)
|
||||
if not attempt.validation.valid:
|
||||
attempt.error = "; ".join(attempt.validation.errors)
|
||||
|
||||
return attempt
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Extractor worker entrypoint - polls Redis for extraction jobs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.extractor.client import OllamaClient
|
||||
from services.extractor.worker import persist_extraction
|
||||
from services.shared.config import load_config
|
||||
from services.shared.logging import setup_logging
|
||||
from services.shared.redis_keys import QUEUE_EXTRACTION, queue_key
|
||||
|
||||
logger = logging.getLogger("extractor_main")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("extractor", level=config.log_level, json_output=config.json_logs)
|
||||
|
||||
pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
|
||||
minio_client = Minio(
|
||||
config.minio.endpoint,
|
||||
access_key=config.minio.access_key,
|
||||
secret_key=config.minio.secret_key,
|
||||
secure=config.minio.secure,
|
||||
)
|
||||
ollama = OllamaClient(config.ollama)
|
||||
|
||||
import json
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_EXTRACTION)
|
||||
logger.info("Extractor worker started, polling %s", queue)
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw = await redis_client.lpop(queue)
|
||||
if raw is None:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
payload = raw
|
||||
job = json.loads(payload)
|
||||
document_id = job.get("document_id", "")
|
||||
ticker = job.get("ticker", "")
|
||||
text = job.get("text", "")
|
||||
|
||||
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
|
||||
|
||||
try:
|
||||
extraction_response = await ollama.extract(text)
|
||||
await persist_extraction(
|
||||
pool=pool,
|
||||
minio_client=minio_client,
|
||||
document_id=document_id,
|
||||
ticker=ticker,
|
||||
extraction_response=extraction_response,
|
||||
document_text_length=len(text),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Extraction failed for doc %s", document_id)
|
||||
finally:
|
||||
await pool.close()
|
||||
await redis_client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,250 @@
|
||||
"""Model performance metrics collection and persistence.
|
||||
|
||||
Tracks extraction success/failure rates, latency percentiles, retry counts,
|
||||
validation error distributions, confidence scores, and token usage estimates.
|
||||
Metrics are persisted to PostgreSQL for operational dashboards and published
|
||||
to the analytical lake for Trino/Superset queries.
|
||||
|
||||
Requirements: 5.2, 5.4, 12.1, 12.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.extractor.client import ExtractionResponse
|
||||
|
||||
logger = logging.getLogger("extractor_metrics")
|
||||
|
||||
# Rough token estimate: ~4 chars per token for English text
|
||||
_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionMetrics:
|
||||
"""Metrics extracted from a single extraction run."""
|
||||
|
||||
document_id: str = ""
|
||||
ticker: str = ""
|
||||
model_name: str = ""
|
||||
prompt_version: str = ""
|
||||
schema_version: str = ""
|
||||
success: bool = False
|
||||
attempt_count: int = 0
|
||||
total_duration_ms: int = 0
|
||||
first_attempt_duration_ms: int = 0
|
||||
final_attempt_duration_ms: int = 0
|
||||
confidence: float = 0.0
|
||||
validation_status: str = "unknown"
|
||||
validation_error_count: int = 0
|
||||
validation_warning_count: int = 0
|
||||
validation_errors: list[str] = field(default_factory=list)
|
||||
retry_count: int = 0
|
||||
input_token_estimate: int = 0
|
||||
output_token_estimate: int = 0
|
||||
company_count: int = 0
|
||||
recorded_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
def collect_metrics(
|
||||
extraction_response: ExtractionResponse,
|
||||
*,
|
||||
document_id: str = "",
|
||||
ticker: str = "",
|
||||
document_text_length: int = 0,
|
||||
) -> ExtractionMetrics:
|
||||
"""Collect metrics from an ExtractionResponse.
|
||||
|
||||
Args:
|
||||
extraction_response: The full response from OllamaClient.extract().
|
||||
document_id: UUID of the source document.
|
||||
ticker: Primary ticker symbol.
|
||||
document_text_length: Length of the input document text in characters.
|
||||
|
||||
Returns:
|
||||
An ExtractionMetrics dataclass with all computed fields.
|
||||
"""
|
||||
attempts = extraction_response.attempts
|
||||
first_dur = attempts[0].duration_ms if attempts else 0
|
||||
final_dur = attempts[-1].duration_ms if attempts else 0
|
||||
|
||||
# Gather validation info from the final attempt
|
||||
final_attempt = attempts[-1] if attempts else None
|
||||
val_errors: list[str] = []
|
||||
val_warnings: list[str] = []
|
||||
if final_attempt and final_attempt.validation:
|
||||
val_errors = final_attempt.validation.errors
|
||||
val_warnings = final_attempt.validation.warnings
|
||||
|
||||
# Determine validation status
|
||||
if extraction_response.success:
|
||||
validation_status = "valid"
|
||||
elif attempts:
|
||||
validation_status = "failed"
|
||||
else:
|
||||
validation_status = "unknown"
|
||||
|
||||
# Confidence from the result, or 0 if failed
|
||||
confidence = 0.0
|
||||
company_count = 0
|
||||
if extraction_response.result:
|
||||
confidence = extraction_response.result.confidence
|
||||
company_count = len(extraction_response.result.companies)
|
||||
|
||||
# Token estimates
|
||||
input_tokens = document_text_length // _CHARS_PER_TOKEN if document_text_length > 0 else 0
|
||||
output_tokens = 0
|
||||
if final_attempt and final_attempt.raw_output:
|
||||
output_tokens = len(final_attempt.raw_output) // _CHARS_PER_TOKEN
|
||||
|
||||
return ExtractionMetrics(
|
||||
document_id=document_id,
|
||||
ticker=ticker,
|
||||
model_name=extraction_response.model,
|
||||
prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""),
|
||||
schema_version=extraction_response.prompt_metadata.get("schema_version", ""),
|
||||
success=extraction_response.success,
|
||||
attempt_count=len(attempts),
|
||||
total_duration_ms=extraction_response.total_duration_ms,
|
||||
first_attempt_duration_ms=first_dur,
|
||||
final_attempt_duration_ms=final_dur,
|
||||
confidence=confidence,
|
||||
validation_status=validation_status,
|
||||
validation_error_count=len(val_errors),
|
||||
validation_warning_count=len(val_warnings),
|
||||
validation_errors=val_errors,
|
||||
retry_count=max(0, len(attempts) - 1),
|
||||
input_token_estimate=input_tokens,
|
||||
output_token_estimate=output_tokens,
|
||||
company_count=company_count,
|
||||
)
|
||||
|
||||
|
||||
async def persist_metrics(
|
||||
pool: asyncpg.Pool,
|
||||
metrics: ExtractionMetrics,
|
||||
) -> str:
|
||||
"""Persist extraction metrics to the model_performance_metrics table.
|
||||
|
||||
Args:
|
||||
pool: PostgreSQL connection pool.
|
||||
metrics: Collected metrics from an extraction run.
|
||||
|
||||
Returns:
|
||||
The UUID of the inserted metrics row.
|
||||
"""
|
||||
row_id = await pool.fetchval(
|
||||
"""INSERT INTO model_performance_metrics
|
||||
(document_id, ticker, model_name, prompt_version, schema_version,
|
||||
success, attempt_count, total_duration_ms,
|
||||
first_attempt_duration_ms, final_attempt_duration_ms,
|
||||
confidence, validation_status, validation_error_count,
|
||||
validation_warning_count, validation_errors, retry_count,
|
||||
input_token_estimate, output_token_estimate, company_count,
|
||||
recorded_at)
|
||||
VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10,
|
||||
$11, $12, $13, $14, $15::jsonb, $16, $17, $18, $19, $20)
|
||||
RETURNING id""",
|
||||
metrics.document_id,
|
||||
metrics.ticker,
|
||||
metrics.model_name,
|
||||
metrics.prompt_version,
|
||||
metrics.schema_version,
|
||||
metrics.success,
|
||||
metrics.attempt_count,
|
||||
metrics.total_duration_ms,
|
||||
metrics.first_attempt_duration_ms,
|
||||
metrics.final_attempt_duration_ms,
|
||||
metrics.confidence,
|
||||
metrics.validation_status,
|
||||
metrics.validation_error_count,
|
||||
metrics.validation_warning_count,
|
||||
json.dumps(metrics.validation_errors),
|
||||
metrics.retry_count,
|
||||
metrics.input_token_estimate,
|
||||
metrics.output_token_estimate,
|
||||
metrics.company_count,
|
||||
metrics.recorded_at,
|
||||
)
|
||||
logger.info(
|
||||
"Persisted extraction metrics %s for doc %s: success=%s duration=%dms retries=%d",
|
||||
row_id, metrics.document_id, metrics.success,
|
||||
metrics.total_duration_ms, metrics.retry_count,
|
||||
)
|
||||
return str(row_id)
|
||||
|
||||
|
||||
async def get_model_performance_summary(
|
||||
pool: asyncpg.Pool,
|
||||
*,
|
||||
model_name: str | None = None,
|
||||
hours: int = 24,
|
||||
) -> dict[str, object]:
|
||||
"""Query aggregated model performance metrics for dashboards.
|
||||
|
||||
Returns a summary dict with success rate, avg latency, retry rate,
|
||||
confidence distribution, and error breakdown for the given time window.
|
||||
|
||||
Args:
|
||||
pool: PostgreSQL connection pool.
|
||||
model_name: Optional filter by model name.
|
||||
hours: Lookback window in hours (default 24).
|
||||
|
||||
Returns:
|
||||
Dict with aggregated performance metrics.
|
||||
"""
|
||||
model_filter = "AND model_name = $2" if model_name else ""
|
||||
params: list[object] = [hours]
|
||||
if model_name:
|
||||
params.append(model_name)
|
||||
|
||||
row = await pool.fetchrow(
|
||||
f"""SELECT
|
||||
COUNT(*) AS total_extractions,
|
||||
COUNT(*) FILTER (WHERE success) AS successful,
|
||||
COUNT(*) FILTER (WHERE NOT success) AS failed,
|
||||
ROUND(AVG(total_duration_ms)::numeric, 1) AS avg_duration_ms,
|
||||
ROUND(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p50_duration_ms,
|
||||
ROUND(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p95_duration_ms,
|
||||
ROUND(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p99_duration_ms,
|
||||
ROUND(AVG(retry_count)::numeric, 2) AS avg_retries,
|
||||
ROUND(AVG(confidence)::numeric, 3) AS avg_confidence,
|
||||
SUM(input_token_estimate) AS total_input_tokens,
|
||||
SUM(output_token_estimate) AS total_output_tokens,
|
||||
ROUND(AVG(company_count)::numeric, 2) AS avg_companies_per_doc,
|
||||
ROUND(AVG(validation_error_count)::numeric, 2) AS avg_validation_errors,
|
||||
ROUND(AVG(validation_warning_count)::numeric, 2) AS avg_validation_warnings
|
||||
FROM model_performance_metrics
|
||||
WHERE recorded_at >= NOW() - INTERVAL '1 hour' * $1
|
||||
{model_filter}""",
|
||||
*params,
|
||||
)
|
||||
|
||||
if not row or row["total_extractions"] == 0:
|
||||
return {"total_extractions": 0, "success_rate": 0.0}
|
||||
|
||||
total = row["total_extractions"]
|
||||
successful = row["successful"]
|
||||
|
||||
return {
|
||||
"total_extractions": total,
|
||||
"successful": successful,
|
||||
"failed": row["failed"],
|
||||
"success_rate": round(successful / total, 4) if total > 0 else 0.0,
|
||||
"avg_duration_ms": float(row["avg_duration_ms"] or 0),
|
||||
"p50_duration_ms": float(row["p50_duration_ms"] or 0),
|
||||
"p95_duration_ms": float(row["p95_duration_ms"] or 0),
|
||||
"p99_duration_ms": float(row["p99_duration_ms"] or 0),
|
||||
"avg_retries": float(row["avg_retries"] or 0),
|
||||
"avg_confidence": float(row["avg_confidence"] or 0),
|
||||
"total_input_tokens": int(row["total_input_tokens"] or 0),
|
||||
"total_output_tokens": int(row["total_output_tokens"] or 0),
|
||||
"avg_companies_per_doc": float(row["avg_companies_per_doc"] or 0),
|
||||
"avg_validation_errors": float(row["avg_validation_errors"] or 0),
|
||||
"avg_validation_warnings": float(row["avg_validation_warnings"] or 0),
|
||||
"hours": hours,
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Extraction prompt templates with anti-hallucination instructions.
|
||||
|
||||
Builds structured prompts for Ollama document intelligence extraction.
|
||||
Each prompt includes the target JSON schema, anti-hallucination rules,
|
||||
and document-type-specific guidance.
|
||||
|
||||
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from services.extractor.schemas import generate_json_schema, SCHEMA_VERSION
|
||||
from services.shared.schemas import (
|
||||
DocumentType,
|
||||
)
|
||||
|
||||
PROMPT_VERSION = "document-intel-v1"
|
||||
|
||||
# --- JSON schema for structured output (generated from Pydantic models) ---
|
||||
|
||||
EXTRACTION_JSON_SCHEMA: dict[str, Any] = generate_json_schema()
|
||||
|
||||
# --- Anti-hallucination system prompt ---
|
||||
|
||||
SYSTEM_PROMPT = """\
|
||||
You are a financial document analysis system. You extract structured intelligence \
|
||||
from financial documents into JSON.
|
||||
|
||||
STRICT RULES — VIOLATIONS WILL INVALIDATE YOUR OUTPUT:
|
||||
|
||||
1. ONLY extract information explicitly stated in the document text provided.
|
||||
2. NEVER fabricate facts, quotes, numbers, dates, or company names.
|
||||
3. NEVER infer information that is not directly supported by the text.
|
||||
4. If the document does not mention a company, do NOT include that company.
|
||||
5. If the document is ambiguous about sentiment or impact, use "neutral" or "mixed" \
|
||||
and set confidence lower.
|
||||
6. evidence_spans MUST be short verbatim quotes copied from the document. \
|
||||
Do NOT paraphrase or invent quotes.
|
||||
7. key_facts MUST be directly stated in the document. Do NOT add external knowledge.
|
||||
8. If you are uncertain about any field, lower the confidence score and add a warning \
|
||||
to extraction_warnings.
|
||||
9. If the document text is too short, garbled, or uninformative, return an empty \
|
||||
companies array, set confidence below 0.3, and add "insufficient_content" to warnings.
|
||||
10. Return ONLY valid JSON matching the provided schema. No commentary, no markdown fences."""
|
||||
|
||||
# --- Document-type-specific guidance ---
|
||||
|
||||
_DOCTYPE_GUIDANCE: dict[str, str] = {
|
||||
DocumentType.ARTICLE: (
|
||||
"This is a news article. Focus on reported facts, quoted sources, and stated "
|
||||
"analyst opinions. Distinguish between the journalist's framing and actual "
|
||||
"company developments. Do not treat speculative language as confirmed fact."
|
||||
),
|
||||
DocumentType.FILING: (
|
||||
"This is a regulatory filing (e.g. SEC 10-K, 10-Q, 8-K). Extract concrete "
|
||||
"financial figures, risk factors, and material events as stated. Filings use "
|
||||
"precise legal language — preserve that precision in your extraction."
|
||||
),
|
||||
DocumentType.TRANSCRIPT: (
|
||||
"This is an earnings call or event transcript. Distinguish between management "
|
||||
"forward-looking statements and reported results. Flag forward-looking language "
|
||||
"as lower confidence. Extract specific guidance numbers when stated."
|
||||
),
|
||||
DocumentType.PRESS_RELEASE: (
|
||||
"This is a company press release. Be aware that press releases are promotional. "
|
||||
"Extract stated facts and figures but note that sentiment may be biased positive. "
|
||||
"Look for concrete metrics rather than marketing language."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _get_doctype_guidance(document_type: str) -> str:
|
||||
"""Return document-type-specific extraction guidance."""
|
||||
return _DOCTYPE_GUIDANCE.get(document_type, _DOCTYPE_GUIDANCE[DocumentType.ARTICLE])
|
||||
|
||||
|
||||
# --- Prompt builder ---
|
||||
|
||||
def build_extraction_prompt(
|
||||
document_text: str,
|
||||
document_type: str = DocumentType.ARTICLE,
|
||||
known_tickers: list[str] | None = None,
|
||||
document_id: str = "",
|
||||
) -> dict[str, str]:
|
||||
"""Build system and user prompts for Ollama structured extraction.
|
||||
|
||||
Args:
|
||||
document_text: Normalized text content of the document.
|
||||
document_type: One of the DocumentType enum values.
|
||||
known_tickers: Optional list of tickers the document may reference.
|
||||
Helps the model focus but does NOT mean all tickers are relevant.
|
||||
document_id: Optional document ID for traceability.
|
||||
|
||||
Returns:
|
||||
Dict with 'system' and 'user' prompt strings.
|
||||
"""
|
||||
doctype_guidance = _get_doctype_guidance(document_type)
|
||||
|
||||
ticker_hint = ""
|
||||
if known_tickers:
|
||||
tickers_str = ", ".join(known_tickers)
|
||||
ticker_hint = (
|
||||
f"\nThe following tickers may be referenced in this document: {tickers_str}\n"
|
||||
"Only include a ticker in your output if the document actually discusses that company. "
|
||||
"Do NOT include a ticker just because it appears in this hint."
|
||||
)
|
||||
|
||||
schema_str = json.dumps(EXTRACTION_JSON_SCHEMA, indent=2)
|
||||
|
||||
doc_id_line = f"Document ID: {document_id}\n" if document_id else ""
|
||||
|
||||
user_prompt = f"""\
|
||||
Extract structured intelligence from the following document.
|
||||
|
||||
{doc_id_line}Document type: {document_type}
|
||||
{doctype_guidance}
|
||||
{ticker_hint}
|
||||
Your output MUST be a single JSON object conforming to this schema:
|
||||
{schema_str}
|
||||
|
||||
REMEMBER:
|
||||
- Only extract what is explicitly in the text below.
|
||||
- evidence_spans must be verbatim quotes from the text.
|
||||
- If the text is insufficient, return empty companies and low confidence.
|
||||
- Return ONLY the JSON object. No other text.
|
||||
|
||||
--- DOCUMENT TEXT ---
|
||||
{document_text}
|
||||
--- END DOCUMENT TEXT ---"""
|
||||
|
||||
return {
|
||||
"system": SYSTEM_PROMPT,
|
||||
"user": user_prompt,
|
||||
}
|
||||
|
||||
|
||||
def get_prompt_metadata() -> dict[str, str]:
|
||||
"""Return metadata about the current prompt version for audit trails."""
|
||||
return {
|
||||
"prompt_version": PROMPT_VERSION,
|
||||
"schema_version": SCHEMA_VERSION,
|
||||
}
|
||||
|
||||
|
||||
def get_json_schema() -> dict[str, Any]:
|
||||
"""Return the extraction JSON schema for Ollama structured output format parameter."""
|
||||
return EXTRACTION_JSON_SCHEMA
|
||||
@@ -0,0 +1,250 @@
|
||||
"""Replay dataset loader and runner for deterministic extraction testing.
|
||||
|
||||
Loads archived document fixtures from JSON files, validates their expected
|
||||
extraction outputs against the current schema, and provides a runner that
|
||||
can compare live Ollama extraction results against expected baselines.
|
||||
|
||||
This enables:
|
||||
- Schema regression testing: verify expected outputs still pass validation
|
||||
- Prompt regression testing: detect drift when prompts or schemas change
|
||||
- End-to-end replay: run fixtures through a live Ollama and compare
|
||||
|
||||
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from services.extractor.schemas import (
|
||||
ExtractionResult,
|
||||
ValidationReport,
|
||||
get_schema_version,
|
||||
validate_extraction,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("extractor_replay")
|
||||
|
||||
FIXTURES_DIR = Path(__file__).resolve().parent.parent.parent / "tests" / "replay_fixtures"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplayFixture:
|
||||
"""A single replay fixture loaded from disk."""
|
||||
|
||||
document_id: str
|
||||
document_type: str
|
||||
document_text: str
|
||||
known_tickers: list[str]
|
||||
expected_extraction: dict[str, Any]
|
||||
metadata: dict[str, str]
|
||||
source_path: str = ""
|
||||
|
||||
@property
|
||||
def expected_result(self) -> ExtractionResult:
|
||||
"""Parse expected_extraction into a validated ExtractionResult."""
|
||||
return ExtractionResult.model_validate(self.expected_extraction)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplayValidationResult:
|
||||
"""Result of validating a single fixture against the current schema."""
|
||||
|
||||
fixture_id: str
|
||||
schema_valid: bool = False
|
||||
validation_report: ValidationReport | None = None
|
||||
schema_version: str = ""
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplayComparisonResult:
|
||||
"""Result of comparing a live extraction against the expected baseline."""
|
||||
|
||||
fixture_id: str
|
||||
expected_companies: list[str] = field(default_factory=list)
|
||||
actual_companies: list[str] = field(default_factory=list)
|
||||
companies_match: bool = False
|
||||
expected_sentiment_map: dict[str, str] = field(default_factory=dict)
|
||||
actual_sentiment_map: dict[str, str] = field(default_factory=dict)
|
||||
sentiment_match: bool = False
|
||||
expected_catalyst_map: dict[str, str] = field(default_factory=dict)
|
||||
actual_catalyst_map: dict[str, str] = field(default_factory=dict)
|
||||
catalyst_match: bool = False
|
||||
actual_schema_valid: bool = False
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def load_fixture(path: Path) -> ReplayFixture:
|
||||
"""Load a single replay fixture from a JSON file.
|
||||
|
||||
Args:
|
||||
path: Path to the fixture JSON file.
|
||||
|
||||
Returns:
|
||||
A ReplayFixture with all fields populated.
|
||||
|
||||
Raises:
|
||||
ValueError: If the fixture is missing required fields.
|
||||
json.JSONDecodeError: If the file is not valid JSON.
|
||||
"""
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
required = {"document_id", "document_type", "document_text", "expected_extraction"}
|
||||
missing = required - set(data.keys())
|
||||
if missing:
|
||||
raise ValueError(f"Fixture {path.name} missing required fields: {missing}")
|
||||
|
||||
return ReplayFixture(
|
||||
document_id=data["document_id"],
|
||||
document_type=data["document_type"],
|
||||
document_text=data["document_text"],
|
||||
known_tickers=data.get("known_tickers", []),
|
||||
expected_extraction=data["expected_extraction"],
|
||||
metadata=data.get("metadata", {}),
|
||||
source_path=str(path),
|
||||
)
|
||||
|
||||
|
||||
def load_all_fixtures(fixtures_dir: Path | None = None) -> list[ReplayFixture]:
|
||||
"""Load all replay fixtures from the fixtures directory.
|
||||
|
||||
Args:
|
||||
fixtures_dir: Override path to fixtures directory.
|
||||
Defaults to tests/replay_fixtures/.
|
||||
|
||||
Returns:
|
||||
List of loaded ReplayFixture objects, sorted by document_id.
|
||||
"""
|
||||
directory = fixtures_dir or FIXTURES_DIR
|
||||
if not directory.is_dir():
|
||||
logger.warning("Fixtures directory not found: %s", directory)
|
||||
return []
|
||||
|
||||
fixtures: list[ReplayFixture] = []
|
||||
for path in sorted(directory.glob("*.json")):
|
||||
try:
|
||||
fixture = load_fixture(path)
|
||||
fixtures.append(fixture)
|
||||
except (ValueError, json.JSONDecodeError) as exc:
|
||||
logger.warning("Skipping invalid fixture %s: %s", path.name, exc)
|
||||
|
||||
logger.info("Loaded %d replay fixtures from %s", len(fixtures), directory)
|
||||
return fixtures
|
||||
|
||||
|
||||
def validate_fixture(fixture: ReplayFixture) -> ReplayValidationResult:
|
||||
"""Validate a fixture's expected extraction against the current schema.
|
||||
|
||||
This is the core deterministic test: the expected output must still
|
||||
pass schema and semantic validation with the current code. If it
|
||||
doesn't, either the fixture is stale or the schema has regressed.
|
||||
|
||||
Args:
|
||||
fixture: The replay fixture to validate.
|
||||
|
||||
Returns:
|
||||
A ReplayValidationResult indicating pass/fail.
|
||||
"""
|
||||
result = ReplayValidationResult(
|
||||
fixture_id=fixture.document_id,
|
||||
schema_version=get_schema_version(),
|
||||
)
|
||||
|
||||
try:
|
||||
report = validate_extraction(
|
||||
fixture.expected_extraction,
|
||||
document_text=fixture.document_text,
|
||||
)
|
||||
result.validation_report = report
|
||||
result.schema_valid = report.valid
|
||||
except Exception as exc: # noqa: BLE001
|
||||
result.error = str(exc)
|
||||
result.schema_valid = False
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def validate_all_fixtures(
|
||||
fixtures_dir: Path | None = None,
|
||||
) -> list[ReplayValidationResult]:
|
||||
"""Load and validate all fixtures against the current schema.
|
||||
|
||||
Args:
|
||||
fixtures_dir: Override path to fixtures directory.
|
||||
|
||||
Returns:
|
||||
List of validation results, one per fixture.
|
||||
"""
|
||||
fixtures = load_all_fixtures(fixtures_dir)
|
||||
return [validate_fixture(f) for f in fixtures]
|
||||
|
||||
|
||||
def compare_extraction(
|
||||
fixture: ReplayFixture,
|
||||
actual_result: ExtractionResult,
|
||||
) -> ReplayComparisonResult:
|
||||
"""Compare a live extraction result against the fixture's expected output.
|
||||
|
||||
Checks structural alignment (same companies detected, same sentiments,
|
||||
same catalyst types) rather than exact string equality, since LLM
|
||||
outputs vary in wording across runs.
|
||||
|
||||
Args:
|
||||
fixture: The replay fixture with expected output.
|
||||
actual_result: The ExtractionResult from a live extraction.
|
||||
|
||||
Returns:
|
||||
A ReplayComparisonResult with match details.
|
||||
"""
|
||||
expected = fixture.expected_result
|
||||
comparison = ReplayComparisonResult(fixture_id=fixture.document_id)
|
||||
|
||||
# Company ticker sets
|
||||
comparison.expected_companies = sorted(c.ticker for c in expected.companies)
|
||||
comparison.actual_companies = sorted(c.ticker for c in actual_result.companies)
|
||||
comparison.companies_match = (
|
||||
set(comparison.expected_companies) == set(comparison.actual_companies)
|
||||
)
|
||||
|
||||
# Sentiment by ticker
|
||||
comparison.expected_sentiment_map = {
|
||||
c.ticker: c.sentiment for c in expected.companies
|
||||
}
|
||||
comparison.actual_sentiment_map = {
|
||||
c.ticker: c.sentiment for c in actual_result.companies
|
||||
}
|
||||
comparison.sentiment_match = (
|
||||
comparison.expected_sentiment_map == comparison.actual_sentiment_map
|
||||
)
|
||||
|
||||
# Catalyst type by ticker
|
||||
comparison.expected_catalyst_map = {
|
||||
c.ticker: c.catalyst_type for c in expected.companies
|
||||
}
|
||||
comparison.actual_catalyst_map = {
|
||||
c.ticker: c.catalyst_type for c in actual_result.companies
|
||||
}
|
||||
comparison.catalyst_match = (
|
||||
comparison.expected_catalyst_map == comparison.actual_catalyst_map
|
||||
)
|
||||
|
||||
# Schema validity of actual result
|
||||
actual_report = validate_extraction(
|
||||
actual_result.model_dump(mode="json"),
|
||||
document_text=fixture.document_text,
|
||||
)
|
||||
comparison.actual_schema_valid = actual_report.valid
|
||||
if actual_report.warnings:
|
||||
comparison.warnings = actual_report.warnings
|
||||
|
||||
if not comparison.companies_match:
|
||||
comparison.warnings.append(
|
||||
f"company_mismatch: expected={comparison.expected_companies} actual={comparison.actual_companies}"
|
||||
)
|
||||
|
||||
return comparison
|
||||
@@ -0,0 +1,316 @@
|
||||
"""JSON schema definitions for document intelligence extraction.
|
||||
|
||||
Generates Ollama-compatible JSON schemas from Pydantic models so the
|
||||
extraction contract stays in sync with the shared data models. Also
|
||||
provides schema validation and semantic validation helpers.
|
||||
|
||||
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from services.shared.schemas import (
|
||||
CatalystType,
|
||||
Sentiment,
|
||||
)
|
||||
|
||||
SCHEMA_VERSION = "2.0.0"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic model that mirrors the Ollama extraction output contract.
|
||||
# This is the *response* shape we ask the model to produce — it intentionally
|
||||
# omits server-side fields like document_id, source_credibility, and model
|
||||
# metadata that are attached after extraction.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CompanyExtractionItem(BaseModel):
|
||||
"""Per-company extraction output expected from the model.
|
||||
|
||||
All fields are required (no defaults) so the generated JSON schema
|
||||
forces the model to produce every field explicitly.
|
||||
"""
|
||||
|
||||
ticker: str = Field(description="Stock ticker symbol mentioned in the document.")
|
||||
company_name: str = Field(description="Full company name as referenced in the document.")
|
||||
relevance: float = Field(
|
||||
ge=0,
|
||||
le=1,
|
||||
description="How relevant the document is to this company. 0=tangential, 1=primary subject.",
|
||||
)
|
||||
sentiment: Sentiment = Field(description="Overall sentiment toward this company in the document.")
|
||||
impact_score: float = Field(
|
||||
ge=0,
|
||||
le=1,
|
||||
description="Estimated magnitude of impact. 0=negligible, 1=highly material.",
|
||||
)
|
||||
impact_horizon: str = Field(
|
||||
description="One of: intraday, 1d, 1d_7d, 1d_30d, 30d_90d, 90d_plus",
|
||||
)
|
||||
catalyst_type: CatalystType = Field(description="Primary catalyst category.")
|
||||
key_facts: list[str] = Field(
|
||||
description="Facts explicitly stated in the document. Do NOT infer or fabricate.",
|
||||
)
|
||||
risks: list[str] = Field(
|
||||
description="Risks explicitly mentioned in the document.",
|
||||
)
|
||||
evidence_spans: list[str] = Field(
|
||||
description="Short verbatim quotes from the document supporting the analysis.",
|
||||
)
|
||||
|
||||
|
||||
class ExtractionResult(BaseModel):
|
||||
"""Top-level structured output the model must return.
|
||||
|
||||
All fields are required (no defaults) so the generated JSON schema
|
||||
forces the model to produce every field explicitly.
|
||||
"""
|
||||
|
||||
summary: str = Field(
|
||||
description="A concise 1-3 sentence summary of the document's main point.",
|
||||
)
|
||||
companies: list[CompanyExtractionItem] = Field(
|
||||
description="Per-company intelligence extracted from the document.",
|
||||
)
|
||||
macro_themes: list[str] = Field(
|
||||
description="Broad economic or market themes mentioned (e.g. rates, inflation, ai_capex).",
|
||||
)
|
||||
novelty_score: float = Field(
|
||||
ge=0,
|
||||
le=1,
|
||||
description="How novel or surprising the information is. 0=routine, 1=highly novel.",
|
||||
)
|
||||
confidence: float = Field(
|
||||
ge=0,
|
||||
le=1,
|
||||
description="Model confidence in the accuracy of this extraction. Lower if text is ambiguous.",
|
||||
)
|
||||
extraction_warnings: list[str] = Field(
|
||||
description="Any issues encountered: ambiguous_ticker, incomplete_text, low_confidence, etc.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def generate_json_schema() -> dict[str, Any]:
|
||||
"""Generate the JSON schema from the Pydantic model.
|
||||
|
||||
Returns a plain JSON Schema dict suitable for Ollama's ``format``
|
||||
parameter. Pydantic ``$defs`` are inlined so the schema is
|
||||
self-contained.
|
||||
"""
|
||||
raw = ExtractionResult.model_json_schema()
|
||||
# Inline $defs so the schema is flat and Ollama-friendly
|
||||
return _inline_defs(raw)
|
||||
|
||||
|
||||
def get_schema_version() -> str:
|
||||
"""Return the current schema version string."""
|
||||
return SCHEMA_VERSION
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ValidationReport(BaseModel):
|
||||
"""Result of validating a raw model response."""
|
||||
|
||||
valid: bool = False
|
||||
errors: list[str] = Field(default_factory=list)
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
parsed: ExtractionResult | None = None
|
||||
|
||||
|
||||
def validate_extraction(
|
||||
raw_json: str | dict[str, Any],
|
||||
*,
|
||||
document_text: str = "",
|
||||
) -> ValidationReport:
|
||||
"""Validate raw model output against the extraction schema.
|
||||
|
||||
Performs structural (JSON / Pydantic) validation followed by semantic
|
||||
checks that catch hallucination indicators, cross-field inconsistencies,
|
||||
and data-quality issues.
|
||||
|
||||
Args:
|
||||
raw_json: Either a JSON string or an already-parsed dict.
|
||||
document_text: Optional original document text used for evidence
|
||||
span verification.
|
||||
|
||||
Returns:
|
||||
A ``ValidationReport`` with parsed result on success.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
# --- Parse JSON string if needed ---
|
||||
if isinstance(raw_json, str):
|
||||
try:
|
||||
data = json.loads(raw_json)
|
||||
except json.JSONDecodeError as exc:
|
||||
return ValidationReport(valid=False, errors=[f"Invalid JSON: {exc}"])
|
||||
else:
|
||||
data = raw_json
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return ValidationReport(valid=False, errors=["Expected a JSON object at top level."])
|
||||
|
||||
# --- Pydantic structural validation ---
|
||||
try:
|
||||
result = ExtractionResult.model_validate(data)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return ValidationReport(valid=False, errors=[f"Schema validation failed: {exc}"])
|
||||
|
||||
# --- Semantic checks ---
|
||||
sem_errors, sem_warnings = _semantic_checks(result, document_text)
|
||||
errors.extend(sem_errors)
|
||||
warnings.extend(sem_warnings)
|
||||
|
||||
# Semantic errors make the report invalid — the caller should retry.
|
||||
valid = len(errors) == 0
|
||||
return ValidationReport(
|
||||
valid=valid,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
parsed=result,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Known valid impact horizons
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
VALID_IMPACT_HORIZONS = frozenset({
|
||||
"intraday",
|
||||
"1d",
|
||||
"1d_7d",
|
||||
"1d_30d",
|
||||
"30d_90d",
|
||||
"90d_plus",
|
||||
})
|
||||
|
||||
# Ticker: 1-5 uppercase letters (covers NYSE, NASDAQ, etc.)
|
||||
_TICKER_RE = re.compile(r"^[A-Z]{1,5}$")
|
||||
|
||||
# Evidence span length bounds (characters)
|
||||
_MIN_EVIDENCE_LEN = 8
|
||||
_MAX_EVIDENCE_LEN = 500
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Semantic validation rules
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _semantic_checks(
|
||||
result: ExtractionResult,
|
||||
document_text: str = "",
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Run semantic checks on a parsed extraction.
|
||||
|
||||
Returns a tuple of (errors, warnings). Errors are issues severe enough
|
||||
to warrant a retry; warnings are informational.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
|
||||
# --- Top-level checks ---
|
||||
if not result.summary:
|
||||
warnings.append("empty_summary")
|
||||
|
||||
if result.confidence < 0.3 and len(result.companies) > 0:
|
||||
warnings.append("low_confidence_with_companies")
|
||||
|
||||
# Duplicate tickers across company entries
|
||||
tickers_seen: list[str] = []
|
||||
for comp in result.companies:
|
||||
if comp.ticker in tickers_seen:
|
||||
errors.append(f"duplicate_ticker_{comp.ticker}")
|
||||
tickers_seen.append(comp.ticker)
|
||||
|
||||
# --- Per-company checks ---
|
||||
for comp in result.companies:
|
||||
tag = comp.ticker or "unknown"
|
||||
|
||||
# Ticker format
|
||||
if not comp.ticker:
|
||||
errors.append("company_missing_ticker")
|
||||
elif not _TICKER_RE.match(comp.ticker):
|
||||
warnings.append(f"invalid_ticker_format_{tag}")
|
||||
|
||||
# Impact horizon must be a known value
|
||||
if comp.impact_horizon not in VALID_IMPACT_HORIZONS:
|
||||
errors.append(f"invalid_impact_horizon_{comp.impact_horizon}_for_{tag}")
|
||||
|
||||
# Evidence spans
|
||||
if not comp.evidence_spans:
|
||||
warnings.append(f"no_evidence_spans_for_{tag}")
|
||||
else:
|
||||
for idx, span in enumerate(comp.evidence_spans):
|
||||
if len(span) < _MIN_EVIDENCE_LEN:
|
||||
warnings.append(f"evidence_span_too_short_for_{tag}_{idx}")
|
||||
if len(span) > _MAX_EVIDENCE_LEN:
|
||||
warnings.append(f"evidence_span_too_long_for_{tag}_{idx}")
|
||||
|
||||
# Cross-field: high impact but no facts
|
||||
if not comp.key_facts and comp.impact_score > 0.5:
|
||||
warnings.append(f"high_impact_no_facts_for_{tag}")
|
||||
|
||||
# Cross-field: very low relevance
|
||||
if comp.relevance < 0.2:
|
||||
warnings.append(f"very_low_relevance_for_{tag}")
|
||||
|
||||
# Cross-field: strong sentiment but low impact
|
||||
if comp.sentiment in (Sentiment.POSITIVE, Sentiment.NEGATIVE) and comp.impact_score < 0.1:
|
||||
warnings.append(f"strong_sentiment_low_impact_for_{tag}")
|
||||
|
||||
# --- Evidence grounding check (when source text is available) ---
|
||||
if document_text:
|
||||
doc_lower = document_text.lower()
|
||||
for comp in result.companies:
|
||||
for idx, span in enumerate(comp.evidence_spans):
|
||||
if span.lower() not in doc_lower:
|
||||
warnings.append(
|
||||
f"evidence_span_not_found_in_document_for_{comp.ticker or 'unknown'}_{idx}"
|
||||
)
|
||||
|
||||
return errors, warnings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _inline_defs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively inline ``$defs`` / ``$ref`` so the schema is self-contained."""
|
||||
defs = schema.pop("$defs", {})
|
||||
return _resolve_refs(schema, defs)
|
||||
|
||||
|
||||
def _resolve_refs(node: Any, defs: dict[str, Any]) -> Any:
|
||||
"""Walk the schema tree and replace ``$ref`` pointers with their definitions."""
|
||||
if isinstance(node, dict):
|
||||
if "$ref" in node:
|
||||
ref_path = node["$ref"] # e.g. "#/$defs/CompanyExtractionItem"
|
||||
ref_name = ref_path.rsplit("/", 1)[-1]
|
||||
if ref_name in defs:
|
||||
resolved = defs[ref_name].copy()
|
||||
# The resolved def may itself contain refs
|
||||
return _resolve_refs(resolved, defs)
|
||||
return node # unresolvable ref, leave as-is
|
||||
return {k: _resolve_refs(v, defs) for k, v in node.items()}
|
||||
if isinstance(node, list):
|
||||
return [_resolve_refs(item, defs) for item in node]
|
||||
return node
|
||||
@@ -1 +1,291 @@
|
||||
"""Extraction worker - sends documents to Ollama for structured intelligence extraction."""
|
||||
"""Extraction worker - sends documents to Ollama for structured intelligence extraction.
|
||||
|
||||
Orchestrates the full extraction pipeline for a single document:
|
||||
1. Calls OllamaClient to get structured extraction
|
||||
2. Uploads prompts, raw outputs, and validation reports to MinIO
|
||||
3. Persists the final intelligence object and per-company impact records to PostgreSQL
|
||||
4. Updates document status
|
||||
|
||||
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 9.1, 9.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.extractor.client import ExtractionResponse
|
||||
from services.extractor.metrics import collect_metrics, persist_metrics
|
||||
from services.shared.metadata import (
|
||||
persist_document_impact,
|
||||
persist_document_intelligence,
|
||||
update_document_status,
|
||||
)
|
||||
from services.shared.storage import (
|
||||
upload_extraction_intelligence,
|
||||
upload_extraction_prompt,
|
||||
upload_extraction_raw_output,
|
||||
upload_extraction_validation,
|
||||
)
|
||||
from services.shared.logging import Span
|
||||
from services.shared.metrics import (
|
||||
EXTRACTION_ATTEMPTS,
|
||||
EXTRACTION_CONFIDENCE,
|
||||
EXTRACTION_DURATION,
|
||||
EXTRACTION_JOBS_TOTAL,
|
||||
EXTRACTION_RETRIES,
|
||||
EXTRACTION_TOKEN_ESTIMATE,
|
||||
EXTRACTION_VALIDATION_ERRORS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("extractor_worker")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionPersistResult:
|
||||
"""Result of persisting an extraction to storage and database."""
|
||||
|
||||
intelligence_id: str | None = None
|
||||
prompt_ref: str | None = None
|
||||
raw_output_ref: str | None = None
|
||||
validation_ref: str | None = None
|
||||
intelligence_ref: str | None = None
|
||||
impact_ids: list[str] | None = None
|
||||
metrics_id: str | None = None
|
||||
success: bool = False
|
||||
|
||||
|
||||
async def persist_extraction(
|
||||
*,
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
document_id: str,
|
||||
ticker: str,
|
||||
extraction_response: ExtractionResponse,
|
||||
company_id_map: dict[str, str] | None = None,
|
||||
source_credibility: float = 0.5,
|
||||
timestamp: datetime | None = None,
|
||||
document_text_length: int = 0,
|
||||
) -> ExtractionPersistResult:
|
||||
"""Persist all extraction artifacts to MinIO and PostgreSQL.
|
||||
|
||||
Uploads prompts, raw model outputs, validation reports, and the final
|
||||
intelligence object to MinIO. Persists the intelligence record and
|
||||
per-company impact records to PostgreSQL. Updates document status.
|
||||
Also collects and persists model performance metrics.
|
||||
|
||||
Args:
|
||||
pool: PostgreSQL connection pool.
|
||||
minio_client: MinIO client.
|
||||
document_id: UUID of the source document.
|
||||
ticker: Primary ticker for path construction.
|
||||
extraction_response: Full response from OllamaClient.extract().
|
||||
company_id_map: Optional mapping of ticker -> company UUID for impact records.
|
||||
source_credibility: Credibility score to attach to the intelligence record.
|
||||
timestamp: Override timestamp for MinIO paths (defaults to UTC now).
|
||||
document_text_length: Length of the input document text for token estimation.
|
||||
|
||||
Returns:
|
||||
ExtractionPersistResult with references to all persisted artifacts.
|
||||
"""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
result = ExtractionPersistResult()
|
||||
company_id_map = company_id_map or {}
|
||||
|
||||
# 1. Upload prompt metadata to MinIO
|
||||
prompt_payload = json.dumps({
|
||||
"prompt_metadata": extraction_response.prompt_metadata,
|
||||
"model": extraction_response.model,
|
||||
}, indent=2).encode()
|
||||
result.prompt_ref = upload_extraction_prompt(
|
||||
minio_client, ticker, document_id, prompt_payload, timestamp=ts,
|
||||
)
|
||||
|
||||
# 2. Upload raw outputs for each attempt
|
||||
attempts_data: list[dict[str, object]] = []
|
||||
for idx, attempt in enumerate(extraction_response.attempts):
|
||||
attempt_record: dict[str, object] = {
|
||||
"attempt_index": idx,
|
||||
"raw_output": attempt.raw_output,
|
||||
"error": attempt.error,
|
||||
"duration_ms": attempt.duration_ms,
|
||||
"model": attempt.model,
|
||||
"retryable": attempt.retryable,
|
||||
}
|
||||
if attempt.validation:
|
||||
attempt_record["validation"] = {
|
||||
"valid": attempt.validation.valid,
|
||||
"errors": attempt.validation.errors,
|
||||
"warnings": attempt.validation.warnings,
|
||||
}
|
||||
attempts_data.append(attempt_record)
|
||||
|
||||
raw_output_payload = json.dumps({
|
||||
"document_id": document_id,
|
||||
"attempts": attempts_data,
|
||||
"total_duration_ms": extraction_response.total_duration_ms,
|
||||
"success": extraction_response.success,
|
||||
}, indent=2).encode()
|
||||
result.raw_output_ref = upload_extraction_raw_output(
|
||||
minio_client, ticker, document_id, raw_output_payload, timestamp=ts,
|
||||
)
|
||||
|
||||
# 3. Upload validation report
|
||||
final_attempt = extraction_response.attempts[-1] if extraction_response.attempts else None
|
||||
validation_payload = json.dumps({
|
||||
"document_id": document_id,
|
||||
"success": extraction_response.success,
|
||||
"attempt_count": len(extraction_response.attempts),
|
||||
"final_validation": {
|
||||
"valid": final_attempt.validation.valid if final_attempt and final_attempt.validation else False,
|
||||
"errors": final_attempt.validation.errors if final_attempt and final_attempt.validation else [],
|
||||
"warnings": final_attempt.validation.warnings if final_attempt and final_attempt.validation else [],
|
||||
} if final_attempt else None,
|
||||
}, indent=2).encode()
|
||||
result.validation_ref = upload_extraction_validation(
|
||||
minio_client, ticker, document_id, validation_payload, timestamp=ts,
|
||||
)
|
||||
|
||||
# 4. Determine validation status and persist intelligence
|
||||
if extraction_response.success and extraction_response.result:
|
||||
extraction = extraction_response.result
|
||||
validation_status = "valid"
|
||||
validation_errors: list[str] = []
|
||||
|
||||
# Upload final intelligence object to MinIO
|
||||
intelligence_payload = json.dumps(
|
||||
extraction.model_dump(mode="json"), indent=2,
|
||||
).encode()
|
||||
result.intelligence_ref = upload_extraction_intelligence(
|
||||
minio_client, ticker, document_id, intelligence_payload, timestamp=ts,
|
||||
)
|
||||
|
||||
# Persist to PostgreSQL
|
||||
intel_id = await persist_document_intelligence(
|
||||
pool,
|
||||
document_id=document_id,
|
||||
summary=extraction.summary,
|
||||
macro_themes=extraction.macro_themes,
|
||||
novelty_score=extraction.novelty_score,
|
||||
source_credibility=source_credibility,
|
||||
extraction_warnings=extraction.extraction_warnings,
|
||||
confidence=extraction.confidence,
|
||||
model_provider="ollama",
|
||||
model_name=extraction_response.model,
|
||||
prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""),
|
||||
schema_version=extraction_response.prompt_metadata.get("schema_version", ""),
|
||||
raw_output_ref=result.raw_output_ref,
|
||||
prompt_ref=result.prompt_ref,
|
||||
validation_status=validation_status,
|
||||
validation_errors=validation_errors,
|
||||
retry_count=len(extraction_response.attempts) - 1,
|
||||
)
|
||||
result.intelligence_id = intel_id
|
||||
|
||||
# Persist per-company impact records
|
||||
result.impact_ids = []
|
||||
for company in extraction.companies:
|
||||
cid = company_id_map.get(company.ticker)
|
||||
if not cid:
|
||||
logger.warning(
|
||||
"No company_id for ticker %s in doc %s, skipping impact record",
|
||||
company.ticker, document_id,
|
||||
)
|
||||
continue
|
||||
impact_id = await persist_document_impact(
|
||||
pool,
|
||||
intelligence_id=intel_id,
|
||||
company_id=cid,
|
||||
ticker=company.ticker,
|
||||
relevance=company.relevance,
|
||||
sentiment=company.sentiment,
|
||||
impact_score=company.impact_score,
|
||||
impact_horizon=company.impact_horizon,
|
||||
catalyst_type=company.catalyst_type,
|
||||
key_facts=company.key_facts,
|
||||
risks=company.risks,
|
||||
evidence_spans=company.evidence_spans,
|
||||
)
|
||||
result.impact_ids.append(impact_id)
|
||||
|
||||
await update_document_status(pool, document_id=document_id, status="extracted")
|
||||
result.success = True
|
||||
logger.info(
|
||||
"Extraction persisted for doc %s: intel=%s, impacts=%d",
|
||||
document_id, intel_id, len(result.impact_ids),
|
||||
)
|
||||
else:
|
||||
# Failed extraction — still persist the attempt data
|
||||
all_errors: list[str] = []
|
||||
for attempt in extraction_response.attempts:
|
||||
if attempt.error:
|
||||
all_errors.append(attempt.error)
|
||||
|
||||
intel_id = await persist_document_intelligence(
|
||||
pool,
|
||||
document_id=document_id,
|
||||
summary="",
|
||||
macro_themes=[],
|
||||
novelty_score=0.0,
|
||||
source_credibility=source_credibility,
|
||||
extraction_warnings=["extraction_failed"],
|
||||
confidence=0.0,
|
||||
model_provider="ollama",
|
||||
model_name=extraction_response.model,
|
||||
prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""),
|
||||
schema_version=extraction_response.prompt_metadata.get("schema_version", ""),
|
||||
raw_output_ref=result.raw_output_ref,
|
||||
prompt_ref=result.prompt_ref,
|
||||
validation_status="failed",
|
||||
validation_errors=all_errors,
|
||||
retry_count=len(extraction_response.attempts),
|
||||
)
|
||||
result.intelligence_id = intel_id
|
||||
|
||||
await update_document_status(pool, document_id=document_id, status="extraction_failed")
|
||||
logger.warning(
|
||||
"Extraction failed for doc %s after %d attempts: %s",
|
||||
document_id, len(extraction_response.attempts), "; ".join(all_errors),
|
||||
)
|
||||
|
||||
# Collect and persist model performance metrics
|
||||
try:
|
||||
metrics = collect_metrics(
|
||||
extraction_response,
|
||||
document_id=document_id,
|
||||
ticker=ticker,
|
||||
document_text_length=document_text_length,
|
||||
)
|
||||
metrics.recorded_at = ts
|
||||
metrics_id = await persist_metrics(pool, metrics)
|
||||
result.metrics_id = metrics_id
|
||||
except Exception:
|
||||
logger.exception("Failed to persist extraction metrics for doc %s", document_id)
|
||||
|
||||
# Prometheus metrics
|
||||
EXTRACTION_ATTEMPTS.inc(len(extraction_response.attempts))
|
||||
EXTRACTION_DURATION.observe(extraction_response.total_duration_ms / 1000.0)
|
||||
retry_count = max(0, len(extraction_response.attempts) - 1)
|
||||
if retry_count > 0:
|
||||
EXTRACTION_RETRIES.inc(retry_count)
|
||||
if extraction_response.success:
|
||||
EXTRACTION_JOBS_TOTAL.labels(status="success").inc()
|
||||
if extraction_response.result:
|
||||
EXTRACTION_CONFIDENCE.observe(extraction_response.result.confidence)
|
||||
else:
|
||||
EXTRACTION_JOBS_TOTAL.labels(status="failed").inc()
|
||||
# Count validation errors from final attempt
|
||||
final = extraction_response.attempts[-1] if extraction_response.attempts else None
|
||||
if final and final.validation and final.validation.errors:
|
||||
EXTRACTION_VALIDATION_ERRORS.inc(len(final.validation.errors))
|
||||
# Token estimates
|
||||
if document_text_length > 0:
|
||||
EXTRACTION_TOKEN_ESTIMATE.labels(direction="input").inc(document_text_length // 4)
|
||||
if final and final.raw_output:
|
||||
EXTRACTION_TOKEN_ESTIMATE.labels(direction="output").inc(len(final.raw_output) // 4)
|
||||
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user