"""JSON schema definitions for document intelligence extraction. Generates Ollama-compatible JSON schemas from Pydantic models so the extraction contract stays in sync with the shared data models. Also provides schema validation and semantic validation helpers. Requirements: 5.1, 5.2, 5.3, 5.4, 5.5 """ from __future__ import annotations import json import re from typing import Any from pydantic import BaseModel, Field from services.shared.schemas import ( CatalystType, Sentiment, ) SCHEMA_VERSION = "2.0.0" # --------------------------------------------------------------------------- # Pydantic model that mirrors the Ollama extraction output contract. # This is the *response* shape we ask the model to produce — it intentionally # omits server-side fields like document_id, source_credibility, and model # metadata that are attached after extraction. # --------------------------------------------------------------------------- class CompanyExtractionItem(BaseModel): """Per-company extraction output expected from the model. All fields are required (no defaults) so the generated JSON schema forces the model to produce every field explicitly. """ ticker: str = Field(description="Stock ticker symbol mentioned in the document.") company_name: str = Field(description="Full company name as referenced in the document.") relevance: float = Field( ge=0, le=1, description="How relevant the document is to this company. 0=tangential, 1=primary subject.", ) sentiment: Sentiment = Field(description="Overall sentiment toward this company in the document.") impact_score: float = Field( ge=0, le=1, description="Estimated magnitude of impact. 0=negligible, 1=highly material.", ) impact_horizon: str = Field( description="One of: intraday, 1d, 1d_7d, 1d_30d, 30d_90d, 90d_plus", ) catalyst_type: CatalystType = Field(description="Primary catalyst category.") key_facts: list[str] = Field( description="Facts explicitly stated in the document. 
Do NOT infer or fabricate.", ) risks: list[str] = Field( description="Risks explicitly mentioned in the document.", ) evidence_spans: list[str] = Field( description="Short verbatim quotes from the document supporting the analysis.", ) class ExtractionResult(BaseModel): """Top-level structured output the model must return. All fields are required (no defaults) so the generated JSON schema forces the model to produce every field explicitly. """ summary: str = Field( description="A concise 1-3 sentence summary of the document's main point.", ) companies: list[CompanyExtractionItem] = Field( description="Per-company intelligence extracted from the document.", ) macro_themes: list[str] = Field( description="Broad economic or market themes mentioned (e.g. rates, inflation, ai_capex).", ) novelty_score: float = Field( ge=0, le=1, description="How novel or surprising the information is. 0=routine, 1=highly novel.", ) confidence: float = Field( ge=0, le=1, description="Model confidence in the accuracy of this extraction. Lower if text is ambiguous.", ) extraction_warnings: list[str] = Field( description="Any issues encountered: ambiguous_ticker, incomplete_text, low_confidence, etc.", ) # --------------------------------------------------------------------------- # Schema generation # --------------------------------------------------------------------------- def generate_json_schema() -> dict[str, Any]: """Generate the JSON schema from the Pydantic model. Returns a plain JSON Schema dict suitable for Ollama's ``format`` parameter. Pydantic ``$defs`` are inlined so the schema is self-contained. 
""" raw = ExtractionResult.model_json_schema() # Inline $defs so the schema is flat and Ollama-friendly return _inline_defs(raw) def get_schema_version() -> str: """Return the current schema version string.""" return SCHEMA_VERSION # --------------------------------------------------------------------------- # Validation helpers # --------------------------------------------------------------------------- class ValidationReport(BaseModel): """Result of validating a raw model response.""" valid: bool = False errors: list[str] = Field(default_factory=list) warnings: list[str] = Field(default_factory=list) parsed: ExtractionResult | None = None def validate_extraction( raw_json: str | dict[str, Any], *, document_text: str = "", ) -> ValidationReport: """Validate raw model output against the extraction schema. Performs structural (JSON / Pydantic) validation followed by semantic checks that catch hallucination indicators, cross-field inconsistencies, and data-quality issues. Args: raw_json: Either a JSON string or an already-parsed dict. document_text: Optional original document text used for evidence span verification. Returns: A ``ValidationReport`` with parsed result on success. 
""" errors: list[str] = [] warnings: list[str] = [] # --- Parse JSON string if needed --- if isinstance(raw_json, str): try: data = json.loads(raw_json) except json.JSONDecodeError as exc: return ValidationReport(valid=False, errors=[f"Invalid JSON: {exc}"]) else: data = raw_json if not isinstance(data, dict): return ValidationReport(valid=False, errors=["Expected a JSON object at top level."]) # --- Normalize common model output issues before validation --- data = _normalize_extraction_data(data) # --- Pydantic structural validation --- try: result = ExtractionResult.model_validate(data) except Exception as exc: # noqa: BLE001 return ValidationReport(valid=False, errors=[f"Schema validation failed: {exc}"]) # --- Semantic checks --- sem_errors, sem_warnings = _semantic_checks(result, document_text) errors.extend(sem_errors) warnings.extend(sem_warnings) # Semantic errors make the report invalid — the caller should retry. valid = len(errors) == 0 return ValidationReport( valid=valid, errors=errors, warnings=warnings, parsed=result, ) # --------------------------------------------------------------------------- # Normalize model output before validation # --------------------------------------------------------------------------- _CATALYST_ALIASES: dict[str, str] = { "strategic pivot": "other", "strategic": "other", "restructuring": "other", "partnership": "other", "acquisition": "m_and_a", "merger": "m_and_a", "buyout": "m_and_a", "lawsuit": "legal", "regulation": "legal", "regulatory": "legal", "upgrade": "rating_change", "downgrade": "rating_change", "price target": "rating_change", "inflation": "macro", "interest rate": "macro", "interest rates": "macro", "tariff": "macro", "tariffs": "macro", "launch": "product", "product launch": "product", "revenue": "earnings", "profit": "earnings", "guidance": "earnings", "supply": "supply_chain", "shortage": "supply_chain", } _VALID_CATALYSTS = frozenset({ "earnings", "product", "legal", "macro", "supply_chain", "m_and_a", 
"rating_change", "other", }) _HORIZON_MAP: dict[str, str] = { "long-term": "90d_plus", "long_term": "90d_plus", "long": "90d_plus", "longterm": "90d_plus", "medium-term": "30d_90d", "medium_term": "1d_30d", "medium": "1d_30d", "short-term": "1d_7d", "short_term": "1d", "short": "1d", "immediate": "intraday", "near-term": "1d_7d", "near_term": "1d_7d", "mid-term": "1d_30d", "mid_term": "1d_30d", } def _normalize_extraction_data(data: dict[str, Any]) -> dict[str, Any]: """Fix common model output issues before Pydantic validation.""" # Clamp novelty_score and confidence to [0, 1] for field in ("novelty_score", "confidence"): val = data.get(field) if isinstance(val, (int, float)): data[field] = max(0.0, min(1.0, float(val))) # Normalize company entries companies = data.get("companies", []) if isinstance(companies, list): for comp in companies: if not isinstance(comp, dict): continue # Clamp numeric fields for f in ("relevance", "impact_score"): v = comp.get(f) if isinstance(v, (int, float)): comp[f] = max(0.0, min(1.0, float(v))) # Map impact_horizon alternatives horizon = comp.get("impact_horizon", "") if isinstance(horizon, str): h = horizon.lower().strip() if h not in VALID_IMPACT_HORIZONS: comp["impact_horizon"] = _HORIZON_MAP.get(h, "1d_30d") # Map catalyst_type alternatives cat = comp.get("catalyst_type", "") if isinstance(cat, str) and cat.lower().strip() not in _VALID_CATALYSTS: mapped_cat = _CATALYST_ALIASES.get(cat.lower().strip(), "other") comp["catalyst_type"] = mapped_cat return data # --------------------------------------------------------------------------- # Known valid impact horizons # --------------------------------------------------------------------------- VALID_IMPACT_HORIZONS = frozenset({ "intraday", "1d", "1d_7d", "1d_30d", "30d_90d", "90d_plus", }) # Ticker: 1-5 uppercase letters (covers NYSE, NASDAQ, etc.) 
# Applied with ``fullmatch`` in ``_semantic_checks``: with ``match`` the ``$``
# anchor would also accept a ticker followed by a trailing newline.
_TICKER_RE = re.compile(r"^[A-Z]{1,5}$")

# Evidence span length bounds (characters)
_MIN_EVIDENCE_LEN = 8
_MAX_EVIDENCE_LEN = 500


# ---------------------------------------------------------------------------
# Semantic validation rules
# ---------------------------------------------------------------------------


def _semantic_checks(
    result: ExtractionResult,
    document_text: str = "",
) -> tuple[list[str], list[str]]:
    """Run semantic checks on a parsed extraction.

    Args:
        result: The structurally valid extraction to inspect.
        document_text: Source text; when non-empty, every evidence span is
            looked up in it (case-insensitive substring match).

    Returns:
        A tuple of (errors, warnings). Errors are issues severe enough to
        warrant a retry; warnings are informational.
    """
    errors: list[str] = []
    warnings: list[str] = []

    # --- Top-level checks ---
    if not result.summary:
        warnings.append("empty_summary")

    if result.confidence < 0.3 and len(result.companies) > 0:
        warnings.append("low_confidence_with_companies")

    # Duplicate tickers across company entries. One error is emitted per
    # repeated occurrence; a set keeps the membership test O(1).
    tickers_seen: set[str] = set()
    for comp in result.companies:
        if comp.ticker in tickers_seen:
            errors.append(f"duplicate_ticker_{comp.ticker}")
        tickers_seen.add(comp.ticker)

    # --- Per-company checks ---
    for comp in result.companies:
        tag = comp.ticker or "unknown"

        # Ticker format: a missing ticker is an error, an unusual one only a
        # warning.
        if not comp.ticker:
            errors.append("company_missing_ticker")
        elif not _TICKER_RE.fullmatch(comp.ticker):
            warnings.append(f"invalid_ticker_format_{tag}")

        # Impact horizon must be a known value
        if comp.impact_horizon not in VALID_IMPACT_HORIZONS:
            errors.append(f"invalid_impact_horizon_{comp.impact_horizon}_for_{tag}")

        # Evidence spans
        if not comp.evidence_spans:
            warnings.append(f"no_evidence_spans_for_{tag}")
        else:
            for idx, span in enumerate(comp.evidence_spans):
                if len(span) < _MIN_EVIDENCE_LEN:
                    warnings.append(f"evidence_span_too_short_for_{tag}_{idx}")
                if len(span) > _MAX_EVIDENCE_LEN:
                    warnings.append(f"evidence_span_too_long_for_{tag}_{idx}")

        # Cross-field: high impact but no facts
        if not comp.key_facts and comp.impact_score > 0.5:
            warnings.append(f"high_impact_no_facts_for_{tag}")

        # Cross-field: very low relevance
        if comp.relevance < 0.2:
            warnings.append(f"very_low_relevance_for_{tag}")

        # Cross-field: strong sentiment but low impact
        if comp.sentiment in (Sentiment.POSITIVE, Sentiment.NEGATIVE) and comp.impact_score < 0.1:
            warnings.append(f"strong_sentiment_low_impact_for_{tag}")

    # --- Evidence grounding check (when source text is available) ---
    if document_text:
        doc_lower = document_text.lower()
        for comp in result.companies:
            for idx, span in enumerate(comp.evidence_spans):
                if span.lower() not in doc_lower:
                    warnings.append(
                        f"evidence_span_not_found_in_document_for_{comp.ticker or 'unknown'}_{idx}"
                    )

    return errors, warnings


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _inline_defs(schema: dict[str, Any]) -> dict[str, Any]:
    """Recursively inline ``$defs`` / ``$ref`` so the schema is self-contained.

    Args:
        schema: A JSON Schema dict, optionally carrying a top-level ``$defs``.

    Returns:
        A new schema without the top-level ``$defs`` and with every
        resolvable ``$ref`` replaced by its definition. ``schema`` itself is
        left untouched.
    """
    defs = schema.get("$defs", {})
    body = {key: value for key, value in schema.items() if key != "$defs"}
    return _resolve_refs(body, defs)


def _resolve_refs(node: Any, defs: dict[str, Any]) -> Any:
    """Walk the schema tree and replace ``$ref`` pointers with their definitions.

    Keywords that sit next to a ``$ref`` (for example a field-level
    ``description``) are kept and take precedence over the same keyword in
    the referenced definition, so per-field guidance survives inlining.

    Self-referencing definitions are not supported: they would recurse
    without bound.

    Args:
        node: Any JSON Schema fragment (dict, list or scalar).
        defs: Mapping of definition name to schema, as found under ``$defs``.

    Returns:
        A resolved copy of ``node``. A ``$ref`` whose target is not in
        ``defs`` is returned as-is.
    """
    if isinstance(node, dict):
        ref_path = node.get("$ref")
        if isinstance(ref_path, str):
            # e.g. "#/$defs/CompanyExtractionItem"
            ref_name = ref_path.rsplit("/", 1)[-1]
            if ref_name not in defs:
                return node  # unresolvable ref, leave as-is
            siblings = {key: value for key, value in node.items() if key != "$ref"}
            # The merged definition may itself contain refs, so resolve again.
            return _resolve_refs({**defs[ref_name], **siblings}, defs)
        return {key: _resolve_refs(value, defs) for key, value in node.items()}
    if isinstance(node, list):
        return [_resolve_refs(item, defs) for item in node]
    return node