317 lines
11 KiB
Python
317 lines
11 KiB
Python
"""JSON schema definitions for document intelligence extraction.
|
|
|
|
Generates Ollama-compatible JSON schemas from Pydantic models so the
|
|
extraction contract stays in sync with the shared data models. Also
|
|
provides schema validation and semantic validation helpers.
|
|
|
|
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import re
|
|
from typing import Any
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from services.shared.schemas import (
|
|
CatalystType,
|
|
Sentiment,
|
|
)
|
|
|
|
# Version string for the extraction schema; returned by get_schema_version().
SCHEMA_VERSION = "2.0.0"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pydantic model that mirrors the Ollama extraction output contract.
|
|
# This is the *response* shape we ask the model to produce — it intentionally
|
|
# omits server-side fields like document_id, source_credibility, and model
|
|
# metadata that are attached after extraction.
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class CompanyExtractionItem(BaseModel):
    """One company's worth of extraction output, as produced by the model.

    No field declares a default, so every field is required by the JSON
    schema generated from this class and the model has to emit each one
    explicitly. Field order is kept as declared because it carries over
    into the generated schema.
    """

    # --- Identification ---
    ticker: str = Field(description="Stock ticker symbol mentioned in the document.")
    company_name: str = Field(description="Full company name as referenced in the document.")

    # --- Scoring (numeric scores are bounded to the closed interval [0, 1]) ---
    relevance: float = Field(
        ge=0, le=1,
        description="How relevant the document is to this company. 0=tangential, 1=primary subject.",
    )
    sentiment: Sentiment = Field(description="Overall sentiment toward this company in the document.")
    impact_score: float = Field(
        ge=0, le=1,
        description="Estimated magnitude of impact. 0=negligible, 1=highly material.",
    )

    # Typed as a plain string here; the allowed values are only listed in the
    # description and are enforced later by the semantic checks.
    impact_horizon: str = Field(description="One of: intraday, 1d, 1d_7d, 1d_30d, 30d_90d, 90d_plus")
    catalyst_type: CatalystType = Field(description="Primary catalyst category.")

    # --- Supporting material taken from the document ---
    key_facts: list[str] = Field(description="Facts explicitly stated in the document. Do NOT infer or fabricate.")
    risks: list[str] = Field(description="Risks explicitly mentioned in the document.")
    evidence_spans: list[str] = Field(description="Short verbatim quotes from the document supporting the analysis.")
|
|
|
|
|
|
class ExtractionResult(BaseModel):
    """Root object the model is asked to return for a single document.

    As with ``CompanyExtractionItem``, no field has a default, so the
    generated JSON schema marks every field as required. Field order is
    kept as declared because it carries over into the generated schema.
    """

    summary: str = Field(description="A concise 1-3 sentence summary of the document's main point.")
    companies: list[CompanyExtractionItem] = Field(description="Per-company intelligence extracted from the document.")
    macro_themes: list[str] = Field(
        description="Broad economic or market themes mentioned (e.g. rates, inflation, ai_capex).",
    )

    # Document-level scores, each bounded to the closed interval [0, 1].
    novelty_score: float = Field(
        ge=0, le=1,
        description="How novel or surprising the information is. 0=routine, 1=highly novel.",
    )
    confidence: float = Field(
        ge=0, le=1,
        description="Model confidence in the accuracy of this extraction. Lower if text is ambiguous.",
    )

    extraction_warnings: list[str] = Field(
        description="Any issues encountered: ambiguous_ticker, incomplete_text, low_confidence, etc.",
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Schema generation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def generate_json_schema() -> dict[str, Any]:
    """Build the extraction JSON schema from ``ExtractionResult``.

    The schema emitted by Pydantic is passed through ``_inline_defs`` so
    that ``$ref`` pointers are replaced by the ``$defs`` entries they name.

    Returns:
        A self-contained JSON Schema dict intended for Ollama's ``format``
        parameter.
    """
    # Flatten $defs in one step; the flat form is what gets sent to Ollama.
    return _inline_defs(ExtractionResult.model_json_schema())
|
|
|
|
|
|
def get_schema_version() -> str:
    """Report the schema version declared in ``SCHEMA_VERSION``."""
    current_version = SCHEMA_VERSION
    return current_version
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Validation helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class ValidationReport(BaseModel):
    """Outcome of checking one raw model response."""

    # A report is invalid unless explicitly marked valid.
    valid: bool = False

    # Each instance gets its own fresh list via ``default_factory``.
    errors: list[str] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)

    # The parsed extraction; stays None when no parsed result was supplied.
    parsed: ExtractionResult | None = None
|
|
|
|
|
|
def validate_extraction(
    raw_json: str | dict[str, Any],
    *,
    document_text: str = "",
) -> ValidationReport:
    """Check a raw model response against the extraction contract.

    Validation runs in stages and stops at the first structural failure:
    JSON decoding (for string input), a top-level object check, Pydantic
    validation against ``ExtractionResult``, and finally the semantic
    checks in ``_semantic_checks``.

    Args:
        raw_json: The model output, either as a JSON string or as an
            already-decoded dict.
        document_text: Source document text, forwarded to the semantic
            checks for evidence-span verification. Optional.

    Returns:
        A ``ValidationReport``. Structural failures yield an invalid report
        carrying a single error and no parsed result. Once structural
        validation passes, ``parsed`` is always populated and ``valid`` is
        true only when the semantic checks produced no errors.
    """
    payload: Any = raw_json

    # Stage 1: decode string input.
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except json.JSONDecodeError as exc:
            return ValidationReport(valid=False, errors=[f"Invalid JSON: {exc}"])

    # Stage 2: the contract is a JSON object, not an array or scalar.
    if not isinstance(payload, dict):
        return ValidationReport(valid=False, errors=["Expected a JSON object at top level."])

    # Stage 3: structural validation. Any failure here is reported rather
    # than propagated, hence the deliberately broad except.
    try:
        extraction = ExtractionResult.model_validate(payload)
    except Exception as exc:  # noqa: BLE001
        return ValidationReport(valid=False, errors=[f"Schema validation failed: {exc}"])

    # Stage 4: semantic checks. Errors invalidate the report (the caller
    # should retry); warnings are passed along without affecting validity.
    semantic_errors, semantic_warnings = _semantic_checks(extraction, document_text)
    return ValidationReport(
        valid=not semantic_errors,
        errors=list(semantic_errors),
        warnings=list(semantic_warnings),
        parsed=extraction,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Known valid impact horizons
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Immutable set of the ``impact_horizon`` strings the semantic checks accept.
VALID_IMPACT_HORIZONS = frozenset(
    (
        "intraday",
        "1d",
        "1d_7d",
        "1d_30d",
        "30d_90d",
        "90d_plus",
    )
)
|
|
|
|
# Ticker: 1-5 uppercase letters (covers NYSE, NASDAQ, etc.)
|
|
_TICKER_RE = re.compile(r"^[A-Z]{1,5}$")
|
|
|
|
# Evidence span length bounds (characters)
|
|
_MIN_EVIDENCE_LEN = 8
|
|
_MAX_EVIDENCE_LEN = 500
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Semantic validation rules
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _semantic_checks(
|
|
result: ExtractionResult,
|
|
document_text: str = "",
|
|
) -> tuple[list[str], list[str]]:
|
|
"""Run semantic checks on a parsed extraction.
|
|
|
|
Returns a tuple of (errors, warnings). Errors are issues severe enough
|
|
to warrant a retry; warnings are informational.
|
|
"""
|
|
errors: list[str] = []
|
|
warnings: list[str] = []
|
|
|
|
# --- Top-level checks ---
|
|
if not result.summary:
|
|
warnings.append("empty_summary")
|
|
|
|
if result.confidence < 0.3 and len(result.companies) > 0:
|
|
warnings.append("low_confidence_with_companies")
|
|
|
|
# Duplicate tickers across company entries
|
|
tickers_seen: list[str] = []
|
|
for comp in result.companies:
|
|
if comp.ticker in tickers_seen:
|
|
errors.append(f"duplicate_ticker_{comp.ticker}")
|
|
tickers_seen.append(comp.ticker)
|
|
|
|
# --- Per-company checks ---
|
|
for comp in result.companies:
|
|
tag = comp.ticker or "unknown"
|
|
|
|
# Ticker format
|
|
if not comp.ticker:
|
|
errors.append("company_missing_ticker")
|
|
elif not _TICKER_RE.match(comp.ticker):
|
|
warnings.append(f"invalid_ticker_format_{tag}")
|
|
|
|
# Impact horizon must be a known value
|
|
if comp.impact_horizon not in VALID_IMPACT_HORIZONS:
|
|
errors.append(f"invalid_impact_horizon_{comp.impact_horizon}_for_{tag}")
|
|
|
|
# Evidence spans
|
|
if not comp.evidence_spans:
|
|
warnings.append(f"no_evidence_spans_for_{tag}")
|
|
else:
|
|
for idx, span in enumerate(comp.evidence_spans):
|
|
if len(span) < _MIN_EVIDENCE_LEN:
|
|
warnings.append(f"evidence_span_too_short_for_{tag}_{idx}")
|
|
if len(span) > _MAX_EVIDENCE_LEN:
|
|
warnings.append(f"evidence_span_too_long_for_{tag}_{idx}")
|
|
|
|
# Cross-field: high impact but no facts
|
|
if not comp.key_facts and comp.impact_score > 0.5:
|
|
warnings.append(f"high_impact_no_facts_for_{tag}")
|
|
|
|
# Cross-field: very low relevance
|
|
if comp.relevance < 0.2:
|
|
warnings.append(f"very_low_relevance_for_{tag}")
|
|
|
|
# Cross-field: strong sentiment but low impact
|
|
if comp.sentiment in (Sentiment.POSITIVE, Sentiment.NEGATIVE) and comp.impact_score < 0.1:
|
|
warnings.append(f"strong_sentiment_low_impact_for_{tag}")
|
|
|
|
# --- Evidence grounding check (when source text is available) ---
|
|
if document_text:
|
|
doc_lower = document_text.lower()
|
|
for comp in result.companies:
|
|
for idx, span in enumerate(comp.evidence_spans):
|
|
if span.lower() not in doc_lower:
|
|
warnings.append(
|
|
f"evidence_span_not_found_in_document_for_{comp.ticker or 'unknown'}_{idx}"
|
|
)
|
|
|
|
return errors, warnings
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Internal helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _inline_defs(schema: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of ``schema`` with ``$defs`` inlined at each ``$ref``.

    The input dict is left untouched: ``$defs`` is read rather than popped,
    so a caller that keeps a reference to the raw schema still sees it
    intact.

    Args:
        schema: A JSON Schema dict, optionally carrying a top-level
            ``$defs`` mapping.

    Returns:
        A new dict without the top-level ``$defs`` key, in which ``$ref``
        pointers have been replaced by ``_resolve_refs``.
    """
    defs = schema.get("$defs", {})
    # Drop $defs from the tree being walked without mutating the argument.
    body = {key: value for key, value in schema.items() if key != "$defs"}
    return _resolve_refs(body, defs)
|
|
|
|
|
|
def _resolve_refs(node: Any, defs: dict[str, Any]) -> Any:
|
|
"""Walk the schema tree and replace ``$ref`` pointers with their definitions."""
|
|
if isinstance(node, dict):
|
|
if "$ref" in node:
|
|
ref_path = node["$ref"] # e.g. "#/$defs/CompanyExtractionItem"
|
|
ref_name = ref_path.rsplit("/", 1)[-1]
|
|
if ref_name in defs:
|
|
resolved = defs[ref_name].copy()
|
|
# The resolved def may itself contain refs
|
|
return _resolve_refs(resolved, defs)
|
|
return node # unresolvable ref, leave as-is
|
|
return {k: _resolve_refs(v, defs) for k, v in node.items()}
|
|
if isinstance(node, list):
|
|
return [_resolve_refs(item, defs) for item in node]
|
|
return node
|