phase 14-15: docker build validation and helm deployment

This commit is contained in:
Celes Renata
2026-04-11 11:59:45 -07:00
parent 7394d241c9
commit ce10afa034
179 changed files with 32559 additions and 576 deletions
+268
View File
@@ -0,0 +1,268 @@
"""Ollama client wrapper using structured output format.
Sends documents to a local Ollama instance via the /api/chat endpoint
with the ``format`` parameter set to the extraction JSON schema, ensuring
the model returns schema-compliant JSON.
Includes retry logic for invalid or incomplete model responses with
exponential backoff, error classification, and full audit preservation.
Requirements: 5.1, 5.2, 5.4
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
import httpx
from services.extractor.prompts import (
build_extraction_prompt,
get_json_schema,
get_prompt_metadata,
)
from services.extractor.schemas import ExtractionResult, ValidationReport, validate_extraction
from services.shared.config import OllamaConfig
logger = logging.getLogger("ollama_client")
# Errors that should NOT be retried — the request itself is bad.
_NON_RETRYABLE_ERRORS = frozenset({
"http_400",
"http_401",
"http_403",
"http_404",
"http_422",
})
def _is_retryable(error: str | None) -> bool:
"""Determine whether an extraction error warrants a retry."""
if error is None:
return False
return error not in _NON_RETRYABLE_ERRORS
@dataclass
class ExtractionAttempt:
    """Audit record for one call to the model.

    One instance is created per request; the client appends each of them to
    ``ExtractionResponse.attempts`` whether or not the attempt succeeded.
    """

    # Message content returned by the model, or the raw HTTP body when the
    # response envelope could not be decoded.
    raw_output: str = ""
    # Schema/semantic validation outcome; stays None when the attempt failed
    # before validation could run.
    validation: ValidationReport | None = None
    # Error code or joined validation errors; None means no error occurred.
    error: str | None = None
    # Elapsed time of the HTTP round trip in milliseconds.
    duration_ms: int = 0
    # Name of the model the request was sent to.
    model: str = ""
    # Set to False once the error has been classified as non-retryable.
    retryable: bool = True
@dataclass
class ExtractionResponse:
    """Aggregate outcome of an extraction call across all of its attempts."""

    # True once an attempt produced a valid extraction.
    success: bool = False
    # Parsed extraction from the successful attempt; None on failure.
    result: ExtractionResult | None = None
    # Every attempt made, in order, kept for audit.
    attempts: list[ExtractionAttempt] = field(default_factory=list)
    # Prompt/schema version information recorded alongside the result.
    prompt_metadata: dict[str, str] = field(default_factory=dict)
    # Name of the model that served the request.
    model: str = ""
    # Wall-clock time for the whole call in milliseconds.
    total_duration_ms: int = 0
def _compute_backoff(
attempt_num: int,
base_delay: float,
max_delay: float,
multiplier: float,
) -> float:
"""Compute exponential backoff delay for a given attempt number."""
delay = base_delay * (multiplier ** attempt_num)
return min(delay, max_delay)
class OllamaClient:
    """Async client for Ollama structured extraction.

    Posts chat requests to ``{base_url}/api/chat`` with the extraction JSON
    schema as the ``format`` parameter. Failed attempts are retried with
    exponential backoff, and every attempt is kept on the returned
    ``ExtractionResponse`` for audit.

    Usage::

        config = OllamaConfig(base_url="http://localhost:11434", model="llama3.1:8b")
        client = OllamaClient(config)
        response = await client.extract(
            document_text="Apple reported record earnings...",
            document_type="article",
            document_id="abc-123",
        )
        if response.success:
            print(response.result)
    """

    _config: OllamaConfig
    _max_retries: int
    _base_delay: float
    _max_delay: float
    _backoff_multiplier: float
    _owns_client: bool
    _http: httpx.AsyncClient

    def __init__(
        self,
        config: OllamaConfig,
        max_retries: int | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a client.

        Args:
            config: Ollama connection, model and retry settings.
            max_retries: Overrides ``config.max_retries`` when given.
            http_client: Existing HTTP client to reuse. When omitted, a new
                one is created with ``config.timeout`` and is closed by
                ``close()``.
        """
        self._config = config
        requested_retries = max_retries if max_retries is not None else config.max_retries
        # A negative budget would make ``extract`` perform zero attempts and
        # return a failure without ever contacting the model.
        self._max_retries = max(0, requested_retries)
        self._base_delay = config.retry_base_delay
        self._max_delay = config.retry_max_delay
        self._backoff_multiplier = config.retry_backoff_multiplier
        self._owns_client = http_client is None
        self._http = http_client or httpx.AsyncClient(timeout=config.timeout)

    async def close(self) -> None:
        """Close the underlying HTTP client if we own it."""
        if self._owns_client:
            await self._http.aclose()

    async def extract(
        self,
        document_text: str,
        document_type: str = "article",
        document_id: str = "",
        known_tickers: list[str] | None = None,
    ) -> ExtractionResponse:
        """Send a document to Ollama for structured intelligence extraction.

        Retries up to ``max_retries`` times when the model returns invalid
        or incomplete JSON. Uses exponential backoff between retries.
        Non-retryable errors (e.g. HTTP 400) stop retries immediately.
        Each attempt and its validation result are preserved for audit.

        Args:
            document_text: Normalized text content of the document.
            document_type: One of article, filing, transcript, press_release.
            document_id: Optional document ID for traceability.
            known_tickers: Optional ticker hints for the model.

        Returns:
            An ``ExtractionResponse`` with the parsed result on success.
        """
        prompts = build_extraction_prompt(
            document_text=document_text,
            document_type=document_type,
            document_id=document_id,
            known_tickers=known_tickers,
        )
        json_schema = get_json_schema()
        prompt_meta = get_prompt_metadata()
        response = ExtractionResponse(
            prompt_metadata=prompt_meta,
            model=self._config.model,
        )
        total_start = time.monotonic()

        for attempt_num in range(self._max_retries + 1):
            attempt = await self._call_ollama(prompts, json_schema, document_text)
            response.attempts.append(attempt)

            if attempt.error is None and attempt.validation and attempt.validation.valid:
                response.success = True
                response.result = attempt.validation.parsed
                break

            # Check if the error is non-retryable — stop immediately
            if not _is_retryable(attempt.error):
                attempt.retryable = False
                logger.warning(
                    "Non-retryable error for doc %s: %s — stopping retries",
                    document_id or "unknown",
                    attempt.error,
                )
                break

            # Only sleep when another attempt will actually follow.
            if attempt_num < self._max_retries:
                delay = _compute_backoff(
                    attempt_num,
                    self._base_delay,
                    self._max_delay,
                    self._backoff_multiplier,
                )
                logger.warning(
                    "Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
                    attempt_num + 1,
                    self._max_retries + 1,
                    document_id or "unknown",
                    attempt.error or "validation failed",
                    delay,
                )
                await asyncio.sleep(delay)

        response.total_duration_ms = int((time.monotonic() - total_start) * 1000)
        return response

    async def _call_ollama(
        self,
        prompts: dict[str, str],
        json_schema: dict[str, object],
        document_text: str = "",
    ) -> ExtractionAttempt:
        """Make a single call to the Ollama /api/chat endpoint.

        Never raises for transport, HTTP-status or decoding problems; those
        are recorded on the returned attempt's ``error`` field instead.

        Args:
            prompts: Mapping with ``"system"`` and ``"user"`` prompt strings.
            json_schema: Schema passed as the request's ``format`` parameter.
            document_text: Source text forwarded to ``validate_extraction``.

        Returns:
            The populated ``ExtractionAttempt`` for this call.
        """
        attempt = ExtractionAttempt(model=self._config.model)
        start = time.monotonic()
        payload = {
            "model": self._config.model,
            "messages": [
                {"role": "system", "content": prompts["system"]},
                {"role": "user", "content": prompts["user"]},
            ],
            "format": json_schema,
            "stream": False,
        }

        try:
            resp = await self._http.post(
                f"{self._config.base_url}/api/chat",
                json=payload,
            )
            _ = resp.raise_for_status()
        except httpx.TimeoutException:
            attempt.error = "timeout"
        except httpx.HTTPStatusError as exc:
            attempt.error = f"http_{exc.response.status_code}"
            attempt.retryable = _is_retryable(attempt.error)
        except httpx.HTTPError as exc:
            attempt.error = f"connection_error: {exc}"
        finally:
            # Record the round-trip time on every path, success or failure.
            attempt.duration_ms = int((time.monotonic() - start) * 1000)

        if attempt.error is not None:
            return attempt

        # Parse the Ollama response envelope. json.JSONDecodeError is a
        # ValueError subclass, so ValueError covers undecodable bodies.
        try:
            body = resp.json()
        except ValueError:
            body = None
        if not isinstance(body, dict):
            # Valid JSON that is not an object (list, string, ...) has no
            # "message" to read; report it like an undecodable body and keep
            # the raw text for audit instead of crashing on ``body.get``.
            attempt.error = "invalid_response_json"
            attempt.raw_output = resp.text
            return attempt

        msg = body.get("message")
        content: str = msg.get("content", "") if isinstance(msg, dict) else ""
        attempt.raw_output = content
        if not content:
            attempt.error = "empty_model_response"
            return attempt

        # Validate against extraction schema
        attempt.validation = validate_extraction(content, document_text=document_text)
        if not attempt.validation.valid:
            attempt.error = "; ".join(attempt.validation.errors)
        return attempt
+72
View File
@@ -0,0 +1,72 @@
"""Extractor worker entrypoint - polls Redis for extraction jobs."""
from __future__ import annotations
import asyncio
import logging
import asyncpg
from minio import Minio
from services.extractor.client import OllamaClient
from services.extractor.worker import persist_extraction
from services.shared.config import load_config
from services.shared.logging import setup_logging
from services.shared.redis_keys import QUEUE_EXTRACTION, queue_key
logger = logging.getLogger("extractor_main")
async def main() -> None:
    """Run the extractor worker loop.

    Loads configuration, opens the PostgreSQL pool and the MinIO, Ollama and
    Redis clients, then pops jobs from the extraction queue forever. Each job
    is a JSON object from which ``document_id``, ``ticker`` and ``text`` are
    read. A failure while handling one job is logged and the loop moves on to
    the next job.
    """
    import json

    import redis.asyncio as aioredis

    config = load_config()
    setup_logging("extractor", level=config.log_level, json_output=config.json_logs)
    pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
    minio_client = Minio(
        config.minio.endpoint,
        access_key=config.minio.access_key,
        secret_key=config.minio.secret_key,
        secure=config.minio.secure,
    )
    ollama = OllamaClient(config.ollama)
    redis_client = aioredis.from_url(config.redis.url)
    queue = queue_key(QUEUE_EXTRACTION)
    logger.info("Extractor worker started, polling %s", queue)

    try:
        while True:
            raw = await redis_client.lpop(queue)
            if raw is None:
                # Queue is empty: back off briefly instead of busy-polling.
                await asyncio.sleep(1)
                continue

            # A malformed message must not take the whole worker down, so
            # decoding problems are logged and the job is dropped.
            try:
                job = json.loads(raw)
            except ValueError:
                logger.exception("Discarding malformed extraction job: %r", raw[:200])
                continue
            if not isinstance(job, dict):
                logger.error(
                    "Discarding extraction job that is not a JSON object: %r",
                    raw[:200],
                )
                continue

            document_id = job.get("document_id", "")
            ticker = job.get("ticker", "")
            text = job.get("text", "")
            logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
            try:
                # Passing the id lets the client's retry logs name the document.
                extraction_response = await ollama.extract(text, document_id=document_id)
                await persist_extraction(
                    pool=pool,
                    minio_client=minio_client,
                    document_id=document_id,
                    ticker=ticker,
                    extraction_response=extraction_response,
                    document_text_length=len(text),
                )
            except Exception:
                logger.exception("Extraction failed for doc %s", document_id)
    finally:
        await pool.close()
        await redis_client.close()
        # The Ollama client owns an HTTP connection pool; release it as well.
        await ollama.close()


if __name__ == "__main__":
    asyncio.run(main())
+250
View File
@@ -0,0 +1,250 @@
"""Model performance metrics collection and persistence.
Tracks extraction success/failure rates, latency percentiles, retry counts,
validation error distributions, confidence scores, and token usage estimates.
Metrics are persisted to PostgreSQL for operational dashboards and published
to the analytical lake for Trino/Superset queries.
Requirements: 5.2, 5.4, 12.1, 12.2
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
import asyncpg
from services.extractor.client import ExtractionResponse
logger = logging.getLogger("extractor_metrics")
# Rough token estimate: ~4 chars per token for English text
_CHARS_PER_TOKEN = 4
@dataclass
class ExtractionMetrics:
    """Flat record of the measurements taken from one extraction run."""

    # --- identity of the run ---
    document_id: str = ""
    ticker: str = ""
    model_name: str = ""
    prompt_version: str = ""
    schema_version: str = ""

    # --- outcome and timing ---
    success: bool = False
    attempt_count: int = 0
    total_duration_ms: int = 0
    first_attempt_duration_ms: int = 0
    final_attempt_duration_ms: int = 0
    confidence: float = 0.0

    # --- validation of the final attempt ---
    validation_status: str = "unknown"
    validation_error_count: int = 0
    validation_warning_count: int = 0
    validation_errors: list[str] = field(default_factory=list)

    # --- volume ---
    retry_count: int = 0
    input_token_estimate: int = 0
    output_token_estimate: int = 0
    company_count: int = 0

    # Timezone-aware UTC timestamp taken when the record is created.
    recorded_at: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
def collect_metrics(
    extraction_response: ExtractionResponse,
    *,
    document_id: str = "",
    ticker: str = "",
    document_text_length: int = 0,
) -> ExtractionMetrics:
    """Collect metrics from an ExtractionResponse.

    Args:
        extraction_response: The full response from OllamaClient.extract().
        document_id: UUID of the source document.
        ticker: Primary ticker symbol.
        document_text_length: Length of the input document text in characters.

    Returns:
        An ExtractionMetrics dataclass with all computed fields. Its
        ``validation_errors`` list is a copy, independent of the response.
    """
    attempts = extraction_response.attempts
    first_attempt = attempts[0] if attempts else None
    final_attempt = attempts[-1] if attempts else None

    # Validation details come from the final attempt only. The lists are
    # copied so that later changes to the metrics record cannot alter the
    # audit data held by the response (and vice versa).
    val_errors: list[str] = []
    val_warnings: list[str] = []
    if final_attempt and final_attempt.validation:
        val_errors = list(final_attempt.validation.errors)
        val_warnings = list(final_attempt.validation.warnings)

    # "unknown" is reserved for responses that recorded no attempt at all.
    if extraction_response.success:
        validation_status = "valid"
    elif attempts:
        validation_status = "failed"
    else:
        validation_status = "unknown"

    # Confidence and company count exist only when a result was parsed.
    confidence = 0.0
    company_count = 0
    if extraction_response.result:
        confidence = extraction_response.result.confidence
        company_count = len(extraction_response.result.companies)

    # Token counts are rough estimates derived from character counts.
    input_tokens = document_text_length // _CHARS_PER_TOKEN if document_text_length > 0 else 0
    output_tokens = 0
    if final_attempt and final_attempt.raw_output:
        output_tokens = len(final_attempt.raw_output) // _CHARS_PER_TOKEN

    return ExtractionMetrics(
        document_id=document_id,
        ticker=ticker,
        model_name=extraction_response.model,
        prompt_version=extraction_response.prompt_metadata.get("prompt_version", ""),
        schema_version=extraction_response.prompt_metadata.get("schema_version", ""),
        success=extraction_response.success,
        attempt_count=len(attempts),
        total_duration_ms=extraction_response.total_duration_ms,
        first_attempt_duration_ms=first_attempt.duration_ms if first_attempt else 0,
        final_attempt_duration_ms=final_attempt.duration_ms if final_attempt else 0,
        confidence=confidence,
        validation_status=validation_status,
        validation_error_count=len(val_errors),
        validation_warning_count=len(val_warnings),
        validation_errors=val_errors,
        # Every attempt after the first one counts as a retry.
        retry_count=max(0, len(attempts) - 1),
        input_token_estimate=input_tokens,
        output_token_estimate=output_tokens,
        company_count=company_count,
    )
async def persist_metrics(
    pool: asyncpg.Pool,
    metrics: ExtractionMetrics,
) -> str:
    """Insert one metrics record into ``model_performance_metrics``.

    Args:
        pool: PostgreSQL connection pool.
        metrics: Collected metrics from an extraction run.

    Returns:
        The id of the inserted row, converted to ``str``.
    """
    insert_sql = """INSERT INTO model_performance_metrics
(document_id, ticker, model_name, prompt_version, schema_version,
success, attempt_count, total_duration_ms,
first_attempt_duration_ms, final_attempt_duration_ms,
confidence, validation_status, validation_error_count,
validation_warning_count, validation_errors, retry_count,
input_token_estimate, output_token_estimate, company_count,
recorded_at)
VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, $9, $10,
$11, $12, $13, $14, $15::jsonb, $16, $17, $18, $19, $20)
RETURNING id"""

    # Positional arguments, in the same order as the column list above.
    values = (
        metrics.document_id,
        metrics.ticker,
        metrics.model_name,
        metrics.prompt_version,
        metrics.schema_version,
        metrics.success,
        metrics.attempt_count,
        metrics.total_duration_ms,
        metrics.first_attempt_duration_ms,
        metrics.final_attempt_duration_ms,
        metrics.confidence,
        metrics.validation_status,
        metrics.validation_error_count,
        metrics.validation_warning_count,
        # The error list is sent as JSON text and cast to jsonb by the query.
        json.dumps(metrics.validation_errors),
        metrics.retry_count,
        metrics.input_token_estimate,
        metrics.output_token_estimate,
        metrics.company_count,
        metrics.recorded_at,
    )
    row_id = await pool.fetchval(insert_sql, *values)

    logger.info(
        "Persisted extraction metrics %s for doc %s: success=%s duration=%dms retries=%d",
        row_id,
        metrics.document_id,
        metrics.success,
        metrics.total_duration_ms,
        metrics.retry_count,
    )
    return str(row_id)
async def get_model_performance_summary(
    pool: asyncpg.Pool,
    *,
    model_name: str | None = None,
    hours: int = 24,
) -> dict[str, object]:
    """Aggregate recent rows of ``model_performance_metrics`` for dashboards.

    Args:
        pool: PostgreSQL connection pool.
        model_name: When given, only rows for this model are aggregated.
        hours: Lookback window in hours (default 24).

    Returns:
        Dict with counts, success rate, latency average and percentiles,
        retry/confidence/validation averages, token totals and the ``hours``
        window. When no rows match, only ``total_extractions`` and
        ``success_rate`` are returned.
    """
    # $1 is always the window; $2 exists only when a model filter is applied.
    params: list[object] = [hours]
    model_filter = ""
    if model_name:
        model_filter = "AND model_name = $2"
        params.append(model_name)

    row = await pool.fetchrow(
        f"""SELECT
COUNT(*) AS total_extractions,
COUNT(*) FILTER (WHERE success) AS successful,
COUNT(*) FILTER (WHERE NOT success) AS failed,
ROUND(AVG(total_duration_ms)::numeric, 1) AS avg_duration_ms,
ROUND(PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p50_duration_ms,
ROUND(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p95_duration_ms,
ROUND(PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY total_duration_ms)::numeric, 1) AS p99_duration_ms,
ROUND(AVG(retry_count)::numeric, 2) AS avg_retries,
ROUND(AVG(confidence)::numeric, 3) AS avg_confidence,
SUM(input_token_estimate) AS total_input_tokens,
SUM(output_token_estimate) AS total_output_tokens,
ROUND(AVG(company_count)::numeric, 2) AS avg_companies_per_doc,
ROUND(AVG(validation_error_count)::numeric, 2) AS avg_validation_errors,
ROUND(AVG(validation_warning_count)::numeric, 2) AS avg_validation_warnings
FROM model_performance_metrics
WHERE recorded_at >= NOW() - INTERVAL '1 hour' * $1
{model_filter}""",
        *params,
    )
    if not row or row["total_extractions"] == 0:
        return {"total_extractions": 0, "success_rate": 0.0}

    total = row["total_extractions"]
    successful = row["successful"]
    summary: dict[str, object] = {
        "total_extractions": total,
        "successful": successful,
        "failed": row["failed"],
        "success_rate": round(successful / total, 4) if total > 0 else 0.0,
    }

    # Aggregates may come back as NULL; each one is coerced to a plain number
    # with 0 as the fallback.
    coerced_columns = (
        ("avg_duration_ms", float),
        ("p50_duration_ms", float),
        ("p95_duration_ms", float),
        ("p99_duration_ms", float),
        ("avg_retries", float),
        ("avg_confidence", float),
        ("total_input_tokens", int),
        ("total_output_tokens", int),
        ("avg_companies_per_doc", float),
        ("avg_validation_errors", float),
        ("avg_validation_warnings", float),
    )
    for column, coerce in coerced_columns:
        summary[column] = coerce(row[column] or 0)

    summary["hours"] = hours
    return summary
+149
View File
@@ -0,0 +1,149 @@
"""Extraction prompt templates with anti-hallucination instructions.
Builds structured prompts for Ollama document intelligence extraction.
Each prompt includes the target JSON schema, anti-hallucination rules,
and document-type-specific guidance.
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
"""
from __future__ import annotations
import json
from typing import Any
from services.extractor.schemas import generate_json_schema, SCHEMA_VERSION
from services.shared.schemas import (
DocumentType,
)
PROMPT_VERSION = "document-intel-v1"
# --- JSON schema for structured output (generated from Pydantic models) ---
EXTRACTION_JSON_SCHEMA: dict[str, Any] = generate_json_schema()
# --- Anti-hallucination system prompt ---
# Returned as the "system" entry by build_extraction_prompt(). A trailing
# backslash inside the literal joins the next source line onto the same prompt
# line, so wrapped rules stay single lines in the text the model receives.
# NOTE: this text is runtime behavior — editing any rule changes what the
# model is instructed to do.
SYSTEM_PROMPT = """\
You are a financial document analysis system. You extract structured intelligence \
from financial documents into JSON.
STRICT RULES — VIOLATIONS WILL INVALIDATE YOUR OUTPUT:
1. ONLY extract information explicitly stated in the document text provided.
2. NEVER fabricate facts, quotes, numbers, dates, or company names.
3. NEVER infer information that is not directly supported by the text.
4. If the document does not mention a company, do NOT include that company.
5. If the document is ambiguous about sentiment or impact, use "neutral" or "mixed" \
and set confidence lower.
6. evidence_spans MUST be short verbatim quotes copied from the document. \
Do NOT paraphrase or invent quotes.
7. key_facts MUST be directly stated in the document. Do NOT add external knowledge.
8. If you are uncertain about any field, lower the confidence score and add a warning \
to extraction_warnings.
9. If the document text is too short, garbled, or uninformative, return an empty \
companies array, set confidence below 0.3, and add "insufficient_content" to warnings.
10. Return ONLY valid JSON matching the provided schema. No commentary, no markdown fences."""
# --- Document-type-specific guidance ---
# Maps a document type to the paragraph inserted into the user prompt for that
# type. Lookups are done with the plain ``document_type`` string passed to
# build_extraction_prompt().
# NOTE(review): keys are DocumentType members while the annotation says str —
# presumably DocumentType is a str-based enum so string lookups hit; confirm
# against services.shared.schemas.
_DOCTYPE_GUIDANCE: dict[str, str] = {
    DocumentType.ARTICLE: (
        "This is a news article. Focus on reported facts, quoted sources, and stated "
        "analyst opinions. Distinguish between the journalist's framing and actual "
        "company developments. Do not treat speculative language as confirmed fact."
    ),
    DocumentType.FILING: (
        "This is a regulatory filing (e.g. SEC 10-K, 10-Q, 8-K). Extract concrete "
        "financial figures, risk factors, and material events as stated. Filings use "
        "precise legal language — preserve that precision in your extraction."
    ),
    DocumentType.TRANSCRIPT: (
        "This is an earnings call or event transcript. Distinguish between management "
        "forward-looking statements and reported results. Flag forward-looking language "
        "as lower confidence. Extract specific guidance numbers when stated."
    ),
    DocumentType.PRESS_RELEASE: (
        "This is a company press release. Be aware that press releases are promotional. "
        "Extract stated facts and figures but note that sentiment may be biased positive. "
        "Look for concrete metrics rather than marketing language."
    ),
}
def _get_doctype_guidance(document_type: str) -> str:
    """Look up the guidance paragraph for a document type.

    Types without an entry in ``_DOCTYPE_GUIDANCE`` fall back to the article
    guidance.
    """
    fallback = _DOCTYPE_GUIDANCE[DocumentType.ARTICLE]
    return _DOCTYPE_GUIDANCE.get(document_type, fallback)
# --- Prompt builder ---
def build_extraction_prompt(
    document_text: str,
    document_type: str = DocumentType.ARTICLE,
    known_tickers: list[str] | None = None,
    document_id: str = "",
) -> dict[str, str]:
    """Assemble the system and user prompts for one extraction request.

    Args:
        document_text: Normalized text content of the document.
        document_type: One of the DocumentType enum values.
        known_tickers: Optional tickers the document may reference. They are
            offered to the model as a hint only; the prompt tells it not to
            include a ticker merely because it was hinted.
        document_id: Optional document ID, added to the prompt when non-empty.

    Returns:
        Dict with 'system' and 'user' prompt strings.
    """
    # Optional fragments collapse to the empty string when not applicable.
    doc_id_line = f"Document ID: {document_id}\n" if document_id else ""

    if known_tickers:
        tickers_str = ", ".join(known_tickers)
        ticker_hint = (
            f"\nThe following tickers may be referenced in this document: {tickers_str}\n"
            "Only include a ticker in your output if the document actually discusses that company. "
            "Do NOT include a ticker just because it appears in this hint."
        )
    else:
        ticker_hint = ""

    doctype_guidance = _get_doctype_guidance(document_type)
    # The schema is embedded in the prompt text as pretty-printed JSON.
    schema_str = json.dumps(EXTRACTION_JSON_SCHEMA, indent=2)

    user_prompt = f"""\
Extract structured intelligence from the following document.
{doc_id_line}Document type: {document_type}
{doctype_guidance}
{ticker_hint}
Your output MUST be a single JSON object conforming to this schema:
{schema_str}
REMEMBER:
- Only extract what is explicitly in the text below.
- evidence_spans must be verbatim quotes from the text.
- If the text is insufficient, return empty companies and low confidence.
- Return ONLY the JSON object. No other text.
--- DOCUMENT TEXT ---
{document_text}
--- END DOCUMENT TEXT ---"""

    return {"system": SYSTEM_PROMPT, "user": user_prompt}
def get_prompt_metadata() -> dict[str, str]:
    """Report the prompt and schema versions in use, for audit trails."""
    metadata = dict(
        prompt_version=PROMPT_VERSION,
        schema_version=SCHEMA_VERSION,
    )
    return metadata
def get_json_schema() -> dict[str, Any]:
    """Return the extraction JSON schema for Ollama's ``format`` parameter.

    The module-level ``EXTRACTION_JSON_SCHEMA`` object itself is returned,
    not a copy.
    """
    schema = EXTRACTION_JSON_SCHEMA
    return schema
+250
View File
@@ -0,0 +1,250 @@
"""Replay dataset loader and runner for deterministic extraction testing.
Loads archived document fixtures from JSON files, validates their expected
extraction outputs against the current schema, and provides a runner that
can compare live Ollama extraction results against expected baselines.
This enables:
- Schema regression testing: verify expected outputs still pass validation
- Prompt regression testing: detect drift when prompts or schemas change
- End-to-end replay: run fixtures through a live Ollama and compare
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from services.extractor.schemas import (
ExtractionResult,
ValidationReport,
get_schema_version,
validate_extraction,
)
logger = logging.getLogger("extractor_replay")
FIXTURES_DIR = Path(__file__).resolve().parent.parent.parent / "tests" / "replay_fixtures"
@dataclass
class ReplayFixture:
    """One archived document plus its expected extraction, loaded from disk."""

    # --- the archived document ---
    document_id: str
    document_type: str
    document_text: str
    known_tickers: list[str]

    # --- the baseline to compare against ---
    # Raw dict as stored in the fixture file; parsed on demand below.
    expected_extraction: dict[str, Any]
    metadata: dict[str, str]

    # Path of the file the fixture was read from, when known.
    source_path: str = ""

    @property
    def expected_result(self) -> ExtractionResult:
        """Build an ``ExtractionResult`` from ``expected_extraction``.

        The dict is validated on every access; nothing is cached.
        """
        raw_expected = self.expected_extraction
        return ExtractionResult.model_validate(raw_expected)
@dataclass
class ReplayValidationResult:
    """Outcome of checking one fixture's expected output against the schema."""

    # document_id of the fixture that was checked.
    fixture_id: str
    # True only when validation ran and reported the expected output valid.
    schema_valid: bool = False
    # Full report from validation; None when validation raised.
    validation_report: ValidationReport | None = None
    # Schema version in effect at validation time.
    schema_version: str = ""
    # Text of the exception raised during validation, if any.
    error: str | None = None
@dataclass
class ReplayComparisonResult:
    """Outcome of comparing a live extraction with a fixture's baseline."""

    # document_id of the fixture being compared.
    fixture_id: str

    # --- companies: sorted ticker lists and whether the sets agree ---
    expected_companies: list[str] = field(default_factory=list)
    actual_companies: list[str] = field(default_factory=list)
    companies_match: bool = False

    # --- sentiment per ticker ---
    expected_sentiment_map: dict[str, str] = field(default_factory=dict)
    actual_sentiment_map: dict[str, str] = field(default_factory=dict)
    sentiment_match: bool = False

    # --- catalyst type per ticker ---
    expected_catalyst_map: dict[str, str] = field(default_factory=dict)
    actual_catalyst_map: dict[str, str] = field(default_factory=dict)
    catalyst_match: bool = False

    # Whether the live result itself passed validation.
    actual_schema_valid: bool = False
    # Validation warnings plus any mismatch notes added by the comparison.
    warnings: list[str] = field(default_factory=list)
def load_fixture(path: Path) -> ReplayFixture:
    """Load a single replay fixture from a JSON file.

    Args:
        path: Path to the fixture JSON file.

    Returns:
        A ReplayFixture with all fields populated. ``known_tickers`` and
        ``metadata`` default to empty when absent from the file.

    Raises:
        ValueError: If the file does not hold a JSON object or the object is
            missing required fields.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Fixtures are read as UTF-8 explicitly so the result does not depend on
    # the platform's default encoding.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    # A top-level list or scalar has no keys to check; report it the same way
    # as any other malformed fixture.
    if not isinstance(data, dict):
        raise ValueError(
            f"Fixture {path.name} must contain a JSON object, got {type(data).__name__}"
        )

    required = {"document_id", "document_type", "document_text", "expected_extraction"}
    missing = required - set(data.keys())
    if missing:
        raise ValueError(f"Fixture {path.name} missing required fields: {missing}")

    return ReplayFixture(
        document_id=data["document_id"],
        document_type=data["document_type"],
        document_text=data["document_text"],
        known_tickers=data.get("known_tickers", []),
        expected_extraction=data["expected_extraction"],
        metadata=data.get("metadata", {}),
        source_path=str(path),
    )
def load_all_fixtures(fixtures_dir: Path | None = None) -> list[ReplayFixture]:
    """Load all replay fixtures from the fixtures directory.

    Files that cannot be read or parsed are skipped with a warning so that
    one bad fixture does not prevent the others from loading.

    Args:
        fixtures_dir: Override path to fixtures directory.
            Defaults to tests/replay_fixtures/.

    Returns:
        List of loaded ReplayFixture objects, in sorted file-path order.
        Empty when the directory does not exist.
    """
    directory = fixtures_dir or FIXTURES_DIR
    if not directory.is_dir():
        logger.warning("Fixtures directory not found: %s", directory)
        return []

    fixtures: list[ReplayFixture] = []
    for path in sorted(directory.glob("*.json")):
        try:
            fixtures.append(load_fixture(path))
        except (OSError, ValueError) as exc:
            # ValueError also covers json.JSONDecodeError (a subclass);
            # OSError covers files that vanish or cannot be opened.
            logger.warning("Skipping invalid fixture %s: %s", path.name, exc)

    logger.info("Loaded %d replay fixtures from %s", len(fixtures), directory)
    return fixtures
def validate_fixture(fixture: ReplayFixture) -> ReplayValidationResult:
    """Check that a fixture's expected extraction still validates.

    The expected output is run through ``validate_extraction`` together with
    the fixture's document text. A failure means either the fixture is stale
    or the schema has regressed.

    Args:
        fixture: The replay fixture to validate.

    Returns:
        A ReplayValidationResult. Any exception raised by validation is
        captured in ``error`` rather than propagated.
    """
    outcome = ReplayValidationResult(
        fixture_id=fixture.document_id,
        schema_version=get_schema_version(),
    )
    try:
        report = validate_extraction(
            fixture.expected_extraction,
            document_text=fixture.document_text,
        )
        outcome.validation_report = report
        outcome.schema_valid = report.valid
    except Exception as exc:  # noqa: BLE001
        # Deliberately broad: a crash in validation counts as a failed fixture.
        outcome.error = str(exc)
        outcome.schema_valid = False
    return outcome
def validate_all_fixtures(
    fixtures_dir: Path | None = None,
) -> list[ReplayValidationResult]:
    """Validate every loadable fixture against the current schema.

    Args:
        fixtures_dir: Override path to fixtures directory.

    Returns:
        One validation result per loaded fixture, in load order.
    """
    results: list[ReplayValidationResult] = []
    for fixture in load_all_fixtures(fixtures_dir):
        results.append(validate_fixture(fixture))
    return results
def compare_extraction(
    fixture: ReplayFixture,
    actual_result: ExtractionResult,
) -> ReplayComparisonResult:
    """Compare a live extraction result against the fixture's expected output.

    Checks structural alignment (same companies detected, same sentiments,
    same catalyst types) rather than exact string equality, since LLM
    outputs vary in wording across runs.

    Args:
        fixture: The replay fixture with expected output.
        actual_result: The ExtractionResult from a live extraction.

    Returns:
        A ReplayComparisonResult with match details. Its ``warnings`` holds
        the validation warnings for ``actual_result`` plus a
        ``company_mismatch`` entry when the ticker sets differ.
    """
    expected = fixture.expected_result
    comparison = ReplayComparisonResult(fixture_id=fixture.document_id)

    # Company ticker sets
    comparison.expected_companies = sorted(c.ticker for c in expected.companies)
    comparison.actual_companies = sorted(c.ticker for c in actual_result.companies)
    comparison.companies_match = (
        set(comparison.expected_companies) == set(comparison.actual_companies)
    )

    # Sentiment by ticker
    comparison.expected_sentiment_map = {
        c.ticker: c.sentiment for c in expected.companies
    }
    comparison.actual_sentiment_map = {
        c.ticker: c.sentiment for c in actual_result.companies
    }
    comparison.sentiment_match = (
        comparison.expected_sentiment_map == comparison.actual_sentiment_map
    )

    # Catalyst type by ticker
    comparison.expected_catalyst_map = {
        c.ticker: c.catalyst_type for c in expected.companies
    }
    comparison.actual_catalyst_map = {
        c.ticker: c.catalyst_type for c in actual_result.companies
    }
    comparison.catalyst_match = (
        comparison.expected_catalyst_map == comparison.actual_catalyst_map
    )

    # Schema validity of actual result
    actual_report = validate_extraction(
        actual_result.model_dump(mode="json"),
        document_text=fixture.document_text,
    )
    comparison.actual_schema_valid = actual_report.valid
    if actual_report.warnings:
        # Copy the list: the mismatch note appended below belongs to the
        # comparison only and must not leak into the validation report.
        comparison.warnings = list(actual_report.warnings)

    if not comparison.companies_match:
        comparison.warnings.append(
            f"company_mismatch: expected={comparison.expected_companies} actual={comparison.actual_companies}"
        )
    return comparison
+316
View File
@@ -0,0 +1,316 @@
"""JSON schema definitions for document intelligence extraction.
Generates Ollama-compatible JSON schemas from Pydantic models so the
extraction contract stays in sync with the shared data models. Also
provides schema validation and semantic validation helpers.
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
"""
from __future__ import annotations
import json
import re
from typing import Any
from pydantic import BaseModel, Field
from services.shared.schemas import (
CatalystType,
Sentiment,
)
SCHEMA_VERSION = "2.0.0"
# ---------------------------------------------------------------------------
# Pydantic model that mirrors the Ollama extraction output contract.
# This is the *response* shape we ask the model to produce — it intentionally
# omits server-side fields like document_id, source_credibility, and model
# metadata that are attached after extraction.
# ---------------------------------------------------------------------------
class CompanyExtractionItem(BaseModel):
    """Per-company extraction output expected from the model.
    All fields are required (no defaults) so the generated JSON schema
    forces the model to produce every field explicitly.
    """

    # NOTE(review): the docstring above and each ``description`` below are
    # presumably emitted into the JSON schema sent to the model — treat them
    # as runtime text, not as free-form comments; confirm before rewording.
    ticker: str = Field(description="Stock ticker symbol mentioned in the document.")
    company_name: str = Field(description="Full company name as referenced in the document.")
    # Constrained to the closed interval [0, 1].
    relevance: float = Field(
        ge=0,
        le=1,
        description="How relevant the document is to this company. 0=tangential, 1=primary subject.",
    )
    sentiment: Sentiment = Field(description="Overall sentiment toward this company in the document.")
    # Constrained to the closed interval [0, 1].
    impact_score: float = Field(
        ge=0,
        le=1,
        description="Estimated magnitude of impact. 0=negligible, 1=highly material.",
    )
    # NOTE(review): the allowed values are listed only in the description;
    # the plain ``str`` type does not restrict them in this declaration.
    impact_horizon: str = Field(
        description="One of: intraday, 1d, 1d_7d, 1d_30d, 30d_90d, 90d_plus",
    )
    catalyst_type: CatalystType = Field(description="Primary catalyst category.")
    key_facts: list[str] = Field(
        description="Facts explicitly stated in the document. Do NOT infer or fabricate.",
    )
    risks: list[str] = Field(
        description="Risks explicitly mentioned in the document.",
    )
    evidence_spans: list[str] = Field(
        description="Short verbatim quotes from the document supporting the analysis.",
    )
class ExtractionResult(BaseModel):
    """Top-level structured output the model must return.
    All fields are required (no defaults) so the generated JSON schema
    forces the model to produce every field explicitly.
    """
    # NOTE(review): Pydantic is expected to surface this class docstring and
    # every Field ``description`` in ``model_json_schema()`` output, i.e. in
    # the schema handed to the model. Treat that text as part of the
    # model-facing contract and confirm before rewording it.

    # Required but may be an empty string; _semantic_checks() only warns
    # ("empty_summary") in that case.
    summary: str = Field(
        description="A concise 1-3 sentence summary of the document's main point.",
    )
    # One entry per company; _semantic_checks() reports repeated tickers
    # across entries as errors.
    companies: list[CompanyExtractionItem] = Field(
        description="Per-company intelligence extracted from the document.",
    )
    macro_themes: list[str] = Field(
        description="Broad economic or market themes mentioned (e.g. rates, inflation, ai_capex).",
    )
    # Score constrained to the closed interval [0, 1] by ge/le.
    novelty_score: float = Field(
        ge=0,
        le=1,
        description="How novel or surprising the information is. 0=routine, 1=highly novel.",
    )
    # Score constrained to [0, 1]; _semantic_checks() warns when it is below
    # 0.3 while at least one company is present.
    confidence: float = Field(
        ge=0,
        le=1,
        description="Model confidence in the accuracy of this extraction. Lower if text is ambiguous.",
    )
    # Warnings reported by the model itself; distinct from the warnings that
    # validate_extraction() computes locally.
    extraction_warnings: list[str] = Field(
        description="Any issues encountered: ambiguous_ticker, incomplete_text, low_confidence, etc.",
    )
# ---------------------------------------------------------------------------
# Schema generation
# ---------------------------------------------------------------------------
def generate_json_schema() -> dict[str, Any]:
    """Build the JSON schema that describes ``ExtractionResult``.

    The schema emitted by Pydantic is passed through ``_inline_defs`` so
    that ``$defs`` entries are substituted at their ``$ref`` sites and the
    result is a single self-contained dict, intended for use as Ollama's
    ``format`` parameter.

    Returns:
        The flattened JSON Schema as a plain dict.
    """
    return _inline_defs(ExtractionResult.model_json_schema())
def get_schema_version() -> str:
    """Expose the module-level ``SCHEMA_VERSION`` string to callers."""
    version: str = SCHEMA_VERSION
    return version
# ---------------------------------------------------------------------------
# Validation helpers
# ---------------------------------------------------------------------------
class ValidationReport(BaseModel):
    """Result of validating a raw model response."""
    # True only when validate_extraction() found no errors at all; every
    # early-exit path (bad JSON, non-object payload, schema failure) leaves
    # it False.
    valid: bool = False
    # Blocking problems: parse/schema failures or semantic errors.
    errors: list[str] = Field(default_factory=list)
    # Informational findings from the semantic checks; they do not affect
    # ``valid``.
    warnings: list[str] = Field(default_factory=list)
    # The parsed extraction. Set whenever Pydantic validation succeeded, so
    # it can be present even when semantic errors make ``valid`` False.
    parsed: ExtractionResult | None = None
def validate_extraction(
    raw_json: str | dict[str, Any],
    *,
    document_text: str = "",
) -> ValidationReport:
    """Check a raw model response against the extraction contract.

    Validation happens in three stages, each of which can end the call:
    JSON decoding (only when a string is supplied), structural validation
    through ``ExtractionResult``, and finally the semantic rules in
    ``_semantic_checks``.

    Args:
        raw_json: The model output, either as JSON text or as an
            already-decoded dict.
        document_text: Source document text. When non-empty it is forwarded
            to the semantic checks for evidence-span verification.

    Returns:
        A ``ValidationReport``. ``parsed`` is populated once structural
        validation succeeded; ``valid`` is True only if no semantic error
        was reported.
    """
    payload: Any = raw_json

    # Stage 1: decode JSON text; dict input skips this step.
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except json.JSONDecodeError as decode_err:
            return ValidationReport(valid=False, errors=[f"Invalid JSON: {decode_err}"])

    if not isinstance(payload, dict):
        return ValidationReport(valid=False, errors=["Expected a JSON object at top level."])

    # Stage 2: structural validation. Any failure is reported as a schema
    # error rather than propagated to the caller.
    try:
        extraction = ExtractionResult.model_validate(payload)
    except Exception as schema_err:  # noqa: BLE001
        return ValidationReport(valid=False, errors=[f"Schema validation failed: {schema_err}"])

    # Stage 3: semantic rules. Errors invalidate the report (the caller
    # should retry); warnings are carried along for information only.
    semantic_errors, semantic_warnings = _semantic_checks(extraction, document_text)
    return ValidationReport(
        valid=not semantic_errors,
        errors=list(semantic_errors),
        warnings=list(semantic_warnings),
        parsed=extraction,
    )
# ---------------------------------------------------------------------------
# Known valid impact horizons
# ---------------------------------------------------------------------------
VALID_IMPACT_HORIZONS = frozenset({
"intraday",
"1d",
"1d_7d",
"1d_30d",
"30d_90d",
"90d_plus",
})
# Ticker: 1-5 uppercase letters (covers NYSE, NASDAQ, etc.)
_TICKER_RE = re.compile(r"^[A-Z]{1,5}$")
# Evidence span length bounds (characters)
_MIN_EVIDENCE_LEN = 8
_MAX_EVIDENCE_LEN = 500
# ---------------------------------------------------------------------------
# Semantic validation rules
# ---------------------------------------------------------------------------
def _semantic_checks(
    result: ExtractionResult,
    document_text: str = "",
) -> tuple[list[str], list[str]]:
    """Run semantic checks on a structurally valid extraction.

    Args:
        result: The parsed extraction to inspect.
        document_text: Source document text. When non-empty, every evidence
            span is looked up in it (case-insensitive substring match).

    Returns:
        A tuple ``(errors, warnings)``. Errors are issues severe enough to
        warrant a retry (duplicate or missing tickers, unknown impact
        horizon); warnings are informational.
    """
    errors: list[str] = []
    warnings: list[str] = []

    # --- Top-level checks ---
    if not result.summary:
        warnings.append("empty_summary")
    if result.confidence < 0.3 and result.companies:
        warnings.append("low_confidence_with_companies")

    # Duplicate tickers across company entries. A set gives O(1) membership
    # tests; each repeat occurrence still yields one error, as before.
    seen_tickers: set[str] = set()
    for comp in result.companies:
        if comp.ticker in seen_tickers:
            errors.append(f"duplicate_ticker_{comp.ticker}")
        seen_tickers.add(comp.ticker)

    # --- Per-company checks ---
    for comp in result.companies:
        tag = comp.ticker or "unknown"

        # Ticker format. fullmatch() is used so that the whole string must be
        # a ticker; match() with a ``$`` anchor would also accept a value
        # that ends in a newline (e.g. "AAPL\n").
        if not comp.ticker:
            errors.append("company_missing_ticker")
        elif not _TICKER_RE.fullmatch(comp.ticker):
            warnings.append(f"invalid_ticker_format_{tag}")

        # Impact horizon must be a known value
        if comp.impact_horizon not in VALID_IMPACT_HORIZONS:
            errors.append(f"invalid_impact_horizon_{comp.impact_horizon}_for_{tag}")

        # Evidence spans
        if not comp.evidence_spans:
            warnings.append(f"no_evidence_spans_for_{tag}")
        else:
            for idx, span in enumerate(comp.evidence_spans):
                if len(span) < _MIN_EVIDENCE_LEN:
                    warnings.append(f"evidence_span_too_short_for_{tag}_{idx}")
                if len(span) > _MAX_EVIDENCE_LEN:
                    warnings.append(f"evidence_span_too_long_for_{tag}_{idx}")

        # Cross-field: high impact but no facts
        if not comp.key_facts and comp.impact_score > 0.5:
            warnings.append(f"high_impact_no_facts_for_{tag}")

        # Cross-field: very low relevance
        if comp.relevance < 0.2:
            warnings.append(f"very_low_relevance_for_{tag}")

        # Cross-field: strong sentiment but low impact
        if comp.sentiment in (Sentiment.POSITIVE, Sentiment.NEGATIVE) and comp.impact_score < 0.1:
            warnings.append(f"strong_sentiment_low_impact_for_{tag}")

    # --- Evidence grounding check (when source text is available) ---
    if document_text:
        # Lower-case the document once; spans are compared case-insensitively.
        doc_lower = document_text.lower()
        for comp in result.companies:
            for idx, span in enumerate(comp.evidence_spans):
                if span.lower() not in doc_lower:
                    warnings.append(
                        f"evidence_span_not_found_in_document_for_{comp.ticker or 'unknown'}_{idx}"
                    )

    return errors, warnings
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _inline_defs(schema: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *schema* with ``$defs`` inlined at every ``$ref``.

    The input dict is left untouched: ``$defs`` is read rather than popped,
    so a caller that keeps a reference to the original schema does not see
    it change underneath them.

    Args:
        schema: A JSON Schema dict, optionally carrying a top-level ``$defs``.

    Returns:
        A new schema without the top-level ``$defs`` key in which resolvable
        ``$ref`` pointers have been replaced by their definitions.
    """
    defs = schema.get("$defs", {})
    without_defs = {key: value for key, value in schema.items() if key != "$defs"}
    return _resolve_refs(without_defs, defs)
def _resolve_refs(node: Any, defs: dict[str, Any]) -> Any:
"""Walk the schema tree and replace ``$ref`` pointers with their definitions."""
if isinstance(node, dict):
if "$ref" in node:
ref_path = node["$ref"] # e.g. "#/$defs/CompanyExtractionItem"
ref_name = ref_path.rsplit("/", 1)[-1]
if ref_name in defs:
resolved = defs[ref_name].copy()
# The resolved def may itself contain refs
return _resolve_refs(resolved, defs)
return node # unresolvable ref, leave as-is
return {k: _resolve_refs(v, defs) for k, v in node.items()}
if isinstance(node, list):
return [_resolve_refs(item, defs) for item in node]
return node
+291 -1
View File
@@ -1 +1,291 @@
"""Extraction worker - sends documents to Ollama for structured intelligence extraction."""
"""Extraction worker - sends documents to Ollama for structured intelligence extraction.
Orchestrates the full extraction pipeline for a single document:
1. Calls OllamaClient to get structured extraction
2. Uploads prompts, raw outputs, and validation reports to MinIO
3. Persists the final intelligence object and per-company impact records to PostgreSQL
4. Updates document status
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 9.1, 9.2
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from datetime import datetime, timezone
import asyncpg
from minio import Minio
from services.extractor.client import ExtractionResponse
from services.extractor.metrics import collect_metrics, persist_metrics
from services.shared.metadata import (
persist_document_impact,
persist_document_intelligence,
update_document_status,
)
from services.shared.storage import (
upload_extraction_intelligence,
upload_extraction_prompt,
upload_extraction_raw_output,
upload_extraction_validation,
)
from services.shared.logging import Span
from services.shared.metrics import (
EXTRACTION_ATTEMPTS,
EXTRACTION_CONFIDENCE,
EXTRACTION_DURATION,
EXTRACTION_JOBS_TOTAL,
EXTRACTION_RETRIES,
EXTRACTION_TOKEN_ESTIMATE,
EXTRACTION_VALIDATION_ERRORS,
)
# Logger used by persist_extraction() for success, failure, and skipped-impact
# messages.
logger = logging.getLogger("extractor_worker")
@dataclass
class ExtractionPersistResult:
    """Result of persisting an extraction to storage and database.

    Returned by persist_extraction(). Every field starts out unset and is
    filled in as the corresponding step completes.
    """
    # Value returned by persist_document_intelligence(); set on both the
    # success and the failure path.
    intelligence_id: str | None = None
    # Values returned by the upload_extraction_* helpers.
    # NOTE(review): presumably MinIO object paths/keys — confirm against
    # services.shared.storage.
    prompt_ref: str | None = None
    raw_output_ref: str | None = None
    validation_ref: str | None = None
    # Only set when the extraction succeeded.
    intelligence_ref: str | None = None
    # List of ids returned by persist_document_impact() on success; stays
    # None when the extraction failed.
    impact_ids: list[str] | None = None
    # Value returned by persist_metrics(); stays None if metrics persistence
    # raised.
    metrics_id: str | None = None
    # True only when the success path ran to completion.
    success: bool = False
async def persist_extraction(
    *,
    pool: asyncpg.Pool,
    minio_client: Minio,
    document_id: str,
    ticker: str,
    extraction_response: ExtractionResponse,
    company_id_map: dict[str, str] | None = None,
    source_credibility: float = 0.5,
    timestamp: datetime | None = None,
    document_text_length: int = 0,
) -> ExtractionPersistResult:
    """Persist all extraction artifacts to MinIO and PostgreSQL.

    Uploads the prompt metadata, every attempt's raw model output, and a
    validation report to MinIO. When the extraction succeeded it also
    uploads the final intelligence object, stores the intelligence record
    plus one impact record per company with a known company id, and marks
    the document ``extracted``. Otherwise it stores a placeholder
    intelligence record carrying the collected attempt errors and marks the
    document ``extraction_failed``. Finally it records model performance
    metrics (best effort; failures are logged, not raised) and updates the
    Prometheus metrics.

    "Succeeded" means ``extraction_response.success`` is true AND a result is
    present; the same flag drives both the persistence branch and the
    success/failed job counter. ``retry_count`` is ``max(0, attempts - 1)``
    everywhere it is recorded (database and Prometheus).

    Args:
        pool: PostgreSQL connection pool.
        minio_client: MinIO client.
        document_id: UUID of the source document.
        ticker: Primary ticker for path construction.
        extraction_response: Full response from OllamaClient.extract().
        company_id_map: Optional mapping of ticker -> company UUID for impact
            records. Companies whose ticker is missing from the map are
            skipped with a warning.
        source_credibility: Credibility score to attach to the intelligence record.
        timestamp: Override timestamp for MinIO paths and metrics
            (defaults to UTC now).
        document_text_length: Length of the input document text for token estimation.

    Returns:
        ExtractionPersistResult with references to all persisted artifacts.
    """
    ts = timestamp or datetime.now(timezone.utc)
    result = ExtractionPersistResult()
    company_id_map = company_id_map or {}

    # Values derived from the response once, so every consumer below (MinIO
    # payloads, database rows, Prometheus) agrees on them.
    attempts = extraction_response.attempts
    attempt_count = len(attempts)
    retry_count = max(0, attempt_count - 1)
    final_attempt = attempts[-1] if attempts else None
    final_validation = final_attempt.validation if final_attempt else None
    prompt_version = extraction_response.prompt_metadata.get("prompt_version", "")
    schema_version = extraction_response.prompt_metadata.get("schema_version", "")
    succeeded = bool(extraction_response.success and extraction_response.result)

    # NOTE(review): the upload_extraction_* helpers are called without
    # ``await``. If they perform blocking network I/O they stall the event
    # loop for the duration of each upload — confirm, and consider
    # ``asyncio.to_thread`` if so.

    # 1. Upload prompt metadata to MinIO
    prompt_payload = json.dumps({
        "prompt_metadata": extraction_response.prompt_metadata,
        "model": extraction_response.model,
    }, indent=2).encode()
    result.prompt_ref = upload_extraction_prompt(
        minio_client, ticker, document_id, prompt_payload, timestamp=ts,
    )

    # 2. Upload raw outputs for each attempt
    attempts_data: list[dict[str, object]] = []
    for idx, attempt in enumerate(attempts):
        attempt_record: dict[str, object] = {
            "attempt_index": idx,
            "raw_output": attempt.raw_output,
            "error": attempt.error,
            "duration_ms": attempt.duration_ms,
            "model": attempt.model,
            "retryable": attempt.retryable,
        }
        if attempt.validation:
            attempt_record["validation"] = {
                "valid": attempt.validation.valid,
                "errors": attempt.validation.errors,
                "warnings": attempt.validation.warnings,
            }
        attempts_data.append(attempt_record)

    raw_output_payload = json.dumps({
        "document_id": document_id,
        "attempts": attempts_data,
        "total_duration_ms": extraction_response.total_duration_ms,
        "success": extraction_response.success,
    }, indent=2).encode()
    result.raw_output_ref = upload_extraction_raw_output(
        minio_client, ticker, document_id, raw_output_payload, timestamp=ts,
    )

    # 3. Upload validation report (final attempt only; null when there were
    # no attempts at all)
    validation_payload = json.dumps({
        "document_id": document_id,
        "success": extraction_response.success,
        "attempt_count": attempt_count,
        "final_validation": {
            "valid": final_validation.valid if final_validation else False,
            "errors": final_validation.errors if final_validation else [],
            "warnings": final_validation.warnings if final_validation else [],
        } if final_attempt else None,
    }, indent=2).encode()
    result.validation_ref = upload_extraction_validation(
        minio_client, ticker, document_id, validation_payload, timestamp=ts,
    )

    # 4. Determine validation status and persist intelligence
    if succeeded:
        extraction = extraction_response.result

        # Upload final intelligence object to MinIO
        intelligence_payload = json.dumps(
            extraction.model_dump(mode="json"), indent=2,
        ).encode()
        result.intelligence_ref = upload_extraction_intelligence(
            minio_client, ticker, document_id, intelligence_payload, timestamp=ts,
        )

        # Persist to PostgreSQL
        intel_id = await persist_document_intelligence(
            pool,
            document_id=document_id,
            summary=extraction.summary,
            macro_themes=extraction.macro_themes,
            novelty_score=extraction.novelty_score,
            source_credibility=source_credibility,
            extraction_warnings=extraction.extraction_warnings,
            confidence=extraction.confidence,
            model_provider="ollama",
            model_name=extraction_response.model,
            prompt_version=prompt_version,
            schema_version=schema_version,
            raw_output_ref=result.raw_output_ref,
            prompt_ref=result.prompt_ref,
            validation_status="valid",
            validation_errors=[],
            retry_count=retry_count,
        )
        result.intelligence_id = intel_id

        # Persist per-company impact records
        result.impact_ids = []
        for company in extraction.companies:
            cid = company_id_map.get(company.ticker)
            if not cid:
                logger.warning(
                    "No company_id for ticker %s in doc %s, skipping impact record",
                    company.ticker, document_id,
                )
                continue
            impact_id = await persist_document_impact(
                pool,
                intelligence_id=intel_id,
                company_id=cid,
                ticker=company.ticker,
                relevance=company.relevance,
                sentiment=company.sentiment,
                impact_score=company.impact_score,
                impact_horizon=company.impact_horizon,
                catalyst_type=company.catalyst_type,
                key_facts=company.key_facts,
                risks=company.risks,
                evidence_spans=company.evidence_spans,
            )
            result.impact_ids.append(impact_id)

        await update_document_status(pool, document_id=document_id, status="extracted")
        result.success = True
        logger.info(
            "Extraction persisted for doc %s: intel=%s, impacts=%d",
            document_id, intel_id, len(result.impact_ids),
        )
    else:
        # Failed extraction — still persist the attempt data
        all_errors: list[str] = [
            attempt.error for attempt in attempts if attempt.error
        ]
        intel_id = await persist_document_intelligence(
            pool,
            document_id=document_id,
            summary="",
            macro_themes=[],
            novelty_score=0.0,
            source_credibility=source_credibility,
            extraction_warnings=["extraction_failed"],
            confidence=0.0,
            model_provider="ollama",
            model_name=extraction_response.model,
            prompt_version=prompt_version,
            schema_version=schema_version,
            raw_output_ref=result.raw_output_ref,
            prompt_ref=result.prompt_ref,
            validation_status="failed",
            validation_errors=all_errors,
            # Retries, not attempts: same definition as the success path and
            # the EXTRACTION_RETRIES counter below.
            retry_count=retry_count,
        )
        result.intelligence_id = intel_id
        await update_document_status(pool, document_id=document_id, status="extraction_failed")
        logger.warning(
            "Extraction failed for doc %s after %d attempts: %s",
            document_id, attempt_count, "; ".join(all_errors),
        )

    # Collect and persist model performance metrics. Best effort: a failure
    # here must not undo or mask the persistence work already done above.
    try:
        metrics = collect_metrics(
            extraction_response,
            document_id=document_id,
            ticker=ticker,
            document_text_length=document_text_length,
        )
        metrics.recorded_at = ts
        metrics_id = await persist_metrics(pool, metrics)
        result.metrics_id = metrics_id
    except Exception:
        logger.exception("Failed to persist extraction metrics for doc %s", document_id)

    # Prometheus metrics
    EXTRACTION_ATTEMPTS.inc(attempt_count)
    EXTRACTION_DURATION.observe(extraction_response.total_duration_ms / 1000.0)
    if retry_count > 0:
        EXTRACTION_RETRIES.inc(retry_count)

    # Uses the same flag as the persistence branch so a response that is
    # flagged successful but carries no result is not counted as a success
    # while the document is marked extraction_failed.
    if succeeded:
        EXTRACTION_JOBS_TOTAL.labels(status="success").inc()
        EXTRACTION_CONFIDENCE.observe(extraction_response.result.confidence)
    else:
        EXTRACTION_JOBS_TOTAL.labels(status="failed").inc()

    # Count validation errors from final attempt
    if final_attempt and final_attempt.validation and final_attempt.validation.errors:
        EXTRACTION_VALIDATION_ERRORS.inc(len(final_attempt.validation.errors))

    # Token estimates (rough heuristic: four characters per token)
    if document_text_length > 0:
        EXTRACTION_TOKEN_ESTIMATE.labels(direction="input").inc(document_text_length // 4)
    if final_attempt and final_attempt.raw_output:
        EXTRACTION_TOKEN_ESTIMATE.labels(direction="output").inc(len(final_attempt.raw_output) // 4)

    return result