Files
stonks-oracle/services/extractor/schemas.py
T

317 lines
11 KiB
Python

"""JSON schema definitions for document intelligence extraction.
Generates Ollama-compatible JSON schemas from Pydantic models so the
extraction contract stays in sync with the shared data models. Also
provides schema validation and semantic validation helpers.
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
"""
from __future__ import annotations
import json
import re
from typing import Any
from pydantic import BaseModel, Field
from services.shared.schemas import (
CatalystType,
Sentiment,
)
SCHEMA_VERSION = "2.0.0"
# ---------------------------------------------------------------------------
# Pydantic model that mirrors the Ollama extraction output contract.
# This is the *response* shape we ask the model to produce — it intentionally
# omits server-side fields like document_id, source_credibility, and model
# metadata that are attached after extraction.
# ---------------------------------------------------------------------------
class CompanyExtractionItem(BaseModel):
    """Model-produced intelligence for one company referenced in a document.

    No field declares a default, so every field is marked required in the
    generated JSON schema and the model has to emit each one explicitly.
    """

    # --- Identity ---
    ticker: str = Field(description="Stock ticker symbol mentioned in the document.")
    company_name: str = Field(description="Full company name as referenced in the document.")

    # --- Scoring (numeric scores are bounded to the closed interval [0, 1]) ---
    relevance: float = Field(
        description="How relevant the document is to this company. 0=tangential, 1=primary subject.",
        ge=0,
        le=1,
    )
    sentiment: Sentiment = Field(description="Overall sentiment toward this company in the document.")
    impact_score: float = Field(
        description="Estimated magnitude of impact. 0=negligible, 1=highly material.",
        ge=0,
        le=1,
    )
    # Declared as a plain string here; the permitted values are only spelled
    # out in the description text.
    impact_horizon: str = Field(description="One of: intraday, 1d, 1d_7d, 1d_30d, 30d_90d, 90d_plus")
    catalyst_type: CatalystType = Field(description="Primary catalyst category.")

    # --- Supporting material taken from the document ---
    key_facts: list[str] = Field(
        description="Facts explicitly stated in the document. Do NOT infer or fabricate.",
    )
    risks: list[str] = Field(description="Risks explicitly mentioned in the document.")
    evidence_spans: list[str] = Field(
        description="Short verbatim quotes from the document supporting the analysis.",
    )
class ExtractionResult(BaseModel):
    """Root object the model is asked to return for a document.

    As with the per-company items, no field has a default, so the generated
    JSON schema lists every field as required.
    """

    # --- Document-level content ---
    summary: str = Field(description="A concise 1-3 sentence summary of the document's main point.")
    companies: list[CompanyExtractionItem] = Field(
        description="Per-company intelligence extracted from the document.",
    )
    macro_themes: list[str] = Field(
        description="Broad economic or market themes mentioned (e.g. rates, inflation, ai_capex).",
    )

    # --- Document-level scores, bounded to the closed interval [0, 1] ---
    novelty_score: float = Field(
        description="How novel or surprising the information is. 0=routine, 1=highly novel.",
        ge=0,
        le=1,
    )
    confidence: float = Field(
        description="Model confidence in the accuracy of this extraction. Lower if text is ambiguous.",
        ge=0,
        le=1,
    )

    # --- Self-reported problems ---
    extraction_warnings: list[str] = Field(
        description="Any issues encountered: ambiguous_ticker, incomplete_text, low_confidence, etc.",
    )
# ---------------------------------------------------------------------------
# Schema generation
# ---------------------------------------------------------------------------
def generate_json_schema() -> dict[str, Any]:
    """Build the extraction JSON schema from ``ExtractionResult``.

    The schema Pydantic generates is passed through ``_inline_defs`` so that
    the returned dict carries no ``$defs`` table and is self-contained, which
    is the shape intended for Ollama's ``format`` parameter.

    Returns:
        A plain JSON Schema dict.
    """
    # Flatten in one step: Pydantic emits the schema, _inline_defs expands it.
    return _inline_defs(ExtractionResult.model_json_schema())
def get_schema_version() -> str:
    """Expose the module-level ``SCHEMA_VERSION`` string to callers."""
    return SCHEMA_VERSION
# ---------------------------------------------------------------------------
# Validation helpers
# ---------------------------------------------------------------------------
class ValidationReport(BaseModel):
    """Outcome of validating one raw model response."""

    # Overall verdict; defaults to False so a bare report is never "valid".
    valid: bool = Field(default=False)
    # Problems that make the response unusable.
    errors: list[str] = Field(default_factory=list)
    # Informational findings that do not invalidate the response.
    warnings: list[str] = Field(default_factory=list)
    # The structurally valid extraction, or None when parsing did not succeed.
    parsed: ExtractionResult | None = Field(default=None)
def validate_extraction(
    raw_json: str | dict[str, Any],
    *,
    document_text: str = "",
) -> ValidationReport:
    """Check a raw model response against the extraction contract.

    Validation runs in three stages, stopping at the first stage that fails:
    JSON decoding (string input only), Pydantic structural validation, and
    finally the semantic checks performed by ``_semantic_checks``.

    Args:
        raw_json: The model output, either as a JSON string or as an
            already-decoded dict.
        document_text: Optional source document text, forwarded to the
            semantic checks for evidence-span verification.

    Returns:
        A ``ValidationReport``. ``parsed`` is populated whenever structural
        validation succeeded, even if semantic errors mark the report invalid.
    """
    # Stage 1: decode. Non-string input is taken as already decoded.
    payload: Any = raw_json
    if isinstance(raw_json, str):
        try:
            payload = json.loads(raw_json)
        except json.JSONDecodeError as exc:
            return ValidationReport(valid=False, errors=[f"Invalid JSON: {exc}"])

    if not isinstance(payload, dict):
        return ValidationReport(valid=False, errors=["Expected a JSON object at top level."])

    # Stage 2: structural validation against the Pydantic model.
    try:
        extraction = ExtractionResult.model_validate(payload)
    except Exception as exc:  # noqa: BLE001
        return ValidationReport(valid=False, errors=[f"Schema validation failed: {exc}"])

    # Stage 3: semantic checks. Any semantic error invalidates the report so
    # the caller can retry; warnings never affect validity.
    semantic_errors, semantic_warnings = _semantic_checks(extraction, document_text)
    return ValidationReport(
        valid=not semantic_errors,
        errors=list(semantic_errors),
        warnings=list(semantic_warnings),
        parsed=extraction,
    )
# ---------------------------------------------------------------------------
# Known valid impact horizons
# ---------------------------------------------------------------------------
VALID_IMPACT_HORIZONS = frozenset({
    "intraday",
    "1d",
    "1d_7d",
    "1d_30d",
    "30d_90d",
    "90d_plus",
})

# Ticker: 1-5 uppercase letters (covers NYSE, NASDAQ, etc.)
# The end is anchored with ``\Z`` rather than ``$`` because ``$`` also matches
# just before a trailing newline, which would let "AAPL\n" pass as well-formed.
_TICKER_RE = re.compile(r"^[A-Z]{1,5}\Z")

# Evidence span length bounds (characters)
_MIN_EVIDENCE_LEN = 8
_MAX_EVIDENCE_LEN = 500
# ---------------------------------------------------------------------------
# Semantic validation rules
# ---------------------------------------------------------------------------
def _semantic_checks(
    result: ExtractionResult,
    document_text: str = "",
) -> tuple[list[str], list[str]]:
    """Run semantic checks on a parsed extraction.

    Args:
        result: A structurally valid extraction.
        document_text: Optional source text. When non-empty, every evidence
            span is looked up in it as a case-insensitive exact substring.

    Returns:
        A tuple of ``(errors, warnings)``. Errors are issues severe enough to
        warrant a retry (duplicate or missing tickers, unknown impact
        horizons); warnings are informational.
    """
    errors: list[str] = []
    warnings: list[str] = []

    # --- Top-level checks ---
    if not result.summary:
        warnings.append("empty_summary")
    if result.confidence < 0.3 and result.companies:
        warnings.append("low_confidence_with_companies")

    # Duplicate tickers across company entries. A set keeps membership tests
    # O(1); one error is still reported per repeated occurrence.
    seen_tickers: set[str] = set()
    for comp in result.companies:
        if comp.ticker in seen_tickers:
            errors.append(f"duplicate_ticker_{comp.ticker}")
        seen_tickers.add(comp.ticker)

    # --- Per-company checks ---
    for comp in result.companies:
        tag = comp.ticker or "unknown"

        # Ticker format. fullmatch requires the whole string to match, so a
        # ticker with a trailing newline is flagged instead of slipping
        # through a ``$`` anchor.
        if not comp.ticker:
            errors.append("company_missing_ticker")
        elif not _TICKER_RE.fullmatch(comp.ticker):
            warnings.append(f"invalid_ticker_format_{tag}")

        # Impact horizon must be a known value
        if comp.impact_horizon not in VALID_IMPACT_HORIZONS:
            errors.append(f"invalid_impact_horizon_{comp.impact_horizon}_for_{tag}")

        # Evidence spans: presence, then length bounds per span
        if not comp.evidence_spans:
            warnings.append(f"no_evidence_spans_for_{tag}")
        else:
            for idx, span in enumerate(comp.evidence_spans):
                if len(span) < _MIN_EVIDENCE_LEN:
                    warnings.append(f"evidence_span_too_short_for_{tag}_{idx}")
                if len(span) > _MAX_EVIDENCE_LEN:
                    warnings.append(f"evidence_span_too_long_for_{tag}_{idx}")

        # Cross-field: high impact but no facts
        if not comp.key_facts and comp.impact_score > 0.5:
            warnings.append(f"high_impact_no_facts_for_{tag}")

        # Cross-field: very low relevance
        if comp.relevance < 0.2:
            warnings.append(f"very_low_relevance_for_{tag}")

        # Cross-field: strong sentiment but low impact
        if comp.sentiment in (Sentiment.POSITIVE, Sentiment.NEGATIVE) and comp.impact_score < 0.1:
            warnings.append(f"strong_sentiment_low_impact_for_{tag}")

    # --- Evidence grounding check (when source text is available) ---
    if document_text:
        # Lower-case the document once rather than once per span.
        doc_lower = document_text.lower()
        for comp in result.companies:
            for idx, span in enumerate(comp.evidence_spans):
                if span.lower() not in doc_lower:
                    warnings.append(
                        f"evidence_span_not_found_in_document_for_{comp.ticker or 'unknown'}_{idx}"
                    )

    return errors, warnings
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _inline_defs(schema: dict[str, Any]) -> dict[str, Any]:
    """Inline ``$defs`` / ``$ref`` so the schema is self-contained.

    Args:
        schema: A JSON Schema dict, optionally carrying a top-level ``$defs``
            table.

    Returns:
        The result of ``_resolve_refs`` applied to the schema with its
        top-level ``$defs`` entry left out.
    """
    defs = schema.get("$defs", {})
    # Build a copy without "$defs" instead of popping it, so the caller's
    # dict is left exactly as it was passed in.
    body = {key: value for key, value in schema.items() if key != "$defs"}
    return _resolve_refs(body, defs)
def _resolve_refs(
    node: Any,
    defs: dict[str, Any],
    _active_refs: frozenset[str] = frozenset(),
) -> Any:
    """Walk the schema tree and replace ``$ref`` pointers with their definitions.

    Args:
        node: Any JSON Schema fragment (dict, list, or scalar).
        defs: Definition table keyed by the last path segment of a ``$ref``
            (e.g. ``"CompanyExtractionItem"`` for
            ``"#/$defs/CompanyExtractionItem"``).
        _active_refs: Internal bookkeeping: names of the definitions currently
            being expanded. Callers should not pass it.

    Returns:
        A new structure with every resolvable ``$ref`` expanded. Keys that sit
        next to a ``$ref`` (such as a field-level ``description``) are kept and
        take precedence over same-named keys of the definition. A ``$ref``
        that is unknown, or that points back into a definition already being
        expanded, is returned unchanged.
    """
    if isinstance(node, list):
        return [_resolve_refs(item, defs, _active_refs) for item in node]
    if not isinstance(node, dict):
        return node

    ref_path = node.get("$ref")
    # Only a string is a reference; a non-string "$ref" value (for example a
    # property that happens to be named "$ref") is walked like any other key.
    if isinstance(ref_path, str):
        ref_name = ref_path.rsplit("/", 1)[-1]  # "#/$defs/Name" -> "Name"
        if ref_name not in defs or ref_name in _active_refs:
            # Unresolvable, or expanding it again would recurse forever.
            return node

        # The resolved definition may itself contain refs.
        resolved = _resolve_refs(defs[ref_name], defs, _active_refs | {ref_name})
        if not isinstance(resolved, dict):
            return resolved

        # Keep whatever was declared beside the $ref instead of dropping it.
        siblings = {
            key: _resolve_refs(value, defs, _active_refs)
            for key, value in node.items()
            if key != "$ref"
        }
        return {**resolved, **siblings}

    return {key: _resolve_refs(value, defs, _active_refs) for key, value in node.items()}