diff --git a/services/extractor/schemas.py b/services/extractor/schemas.py index 02199df..64d3678 100644 --- a/services/extractor/schemas.py +++ b/services/extractor/schemas.py @@ -132,6 +132,109 @@ class ValidationReport(BaseModel): parsed: ExtractionResult | None = None +def _repair_json(raw: str) -> dict[str, Any] | None: + """Attempt to repair common JSON malformations from LLM output. + + Handles: + - Trailing commas before closing brackets + - Unterminated strings (close them) + - Missing closing brackets/braces + - Control characters inside strings + """ + if not raw or not raw.strip(): + return None + + text = raw.strip() + + # Remove any non-JSON prefix (model sometimes prepends text) + first_brace = text.find("{") + if first_brace > 0: + text = text[first_brace:] + elif first_brace < 0: + return None + + # Remove control characters that break JSON parsing (except \n, \r, \t) + text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]", " ", text) + + # Try parsing as-is first + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + # Fix trailing commas: ,] or ,} + text = re.sub(r",\s*([}\]])", r"\1", text) + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + # Try closing unclosed brackets/braces + return _repair_truncated_json(text) + + +def _repair_truncated_json(raw: str) -> dict[str, Any] | None: + """Repair JSON that was truncated mid-output. + + Walks the string tracking bracket/brace depth and string state, + then appends the necessary closing tokens. + """ + if not raw or not raw.strip(): + return None + + text = raw.strip() + first_brace = text.find("{") + if first_brace < 0: + return None + text = text[first_brace:] + + # Track state + in_string = False + escape_next = False + stack: list[str] = [] + + for ch in text: + if escape_next: + escape_next = False + continue + if ch == "\\": + if in_string: + escape_next = True + continue + if ch == '"' and not escape_next: + in_string = not in_string + continue + if in_string: + continue + if ch == "{": + stack.append("}") + elif ch == "[": + stack.append("]") + elif ch in ("}", "]"): + if stack and stack[-1] == ch: + stack.pop() + + # If we're inside a string, close it + if in_string: + text += '"' + + # Remove any trailing comma + text = re.sub(r",\s*$", "", text) + + # Close all open brackets/braces + while stack: + text += stack.pop() + + try: + data = json.loads(text) + if isinstance(data, dict): + return data + except json.JSONDecodeError: + pass + + return None + + def validate_extraction( raw_json: str | dict[str, Any], *, @@ -158,8 +261,19 @@ def validate_extraction( if isinstance(raw_json, str): try: data = json.loads(raw_json) - except json.JSONDecodeError as exc: - return ValidationReport(valid=False, errors=[f"Invalid JSON: {exc}"]) + except json.JSONDecodeError: + # Attempt repair before giving up + repaired = _repair_json(raw_json) + if repaired is not None: + data = repaired + else: + # Try one more time with aggressive truncation repair + repaired = _repair_truncated_json(raw_json) + if repaired is not None: + data = repaired + warnings.append("JSON was repaired from truncated output") + else: + return ValidationReport(valid=False, errors=["Invalid JSON: could not parse or repair"]) else: data = raw_json @@ -248,6 +362,16 @@ _HORIZON_MAP: dict[str, str] = { def _normalize_extraction_data(data: dict[str, Any]) -> dict[str, Any]: """Fix common model output issues before Pydantic validation.""" + # Fill missing top-level required fields with defaults + data.setdefault("summary", "") + data.setdefault("companies", []) + data.setdefault("macro_themes", []) + if "novelty_score" not in data: + data["novelty_score"] = 0.5 + if "confidence" not in data: + data["confidence"] = 0.3 + data.setdefault("extraction_warnings", ["incomplete_model_output"]) + # Clamp novelty_score and confidence to [0, 1] for field in ("novelty_score", "confidence"): val = data.get(field) @@ -260,6 +384,17 @@ def _normalize_extraction_data(data: dict[str, Any]) -> dict[str, Any]: for comp in companies: if not isinstance(comp, dict): continue + # Fill missing required company fields with defaults + comp.setdefault("ticker", "") + comp.setdefault("company_name", "") + comp.setdefault("relevance", 0.5) + comp.setdefault("sentiment", "neutral") + comp.setdefault("impact_score", 0.5) + comp.setdefault("impact_horizon", "1d_30d") + comp.setdefault("catalyst_type", "other") + comp.setdefault("key_facts", []) + comp.setdefault("risks", []) + comp.setdefault("evidence_spans", []) # Clamp numeric fields for f in ("relevance", "impact_score"): v = comp.get(f)