fix: fill default values for missing fields in truncated LLM output

This commit is contained in:
Celes Renata
2026-04-15 03:08:10 +00:00
parent 00044af993
commit b8a2cdc52a
+137 -2
View File
@@ -132,6 +132,109 @@ class ValidationReport(BaseModel):
parsed: ExtractionResult | None = None
def _repair_json(raw: str) -> dict[str, Any] | None:
"""Attempt to repair common JSON malformations from LLM output.
Handles:
- Trailing commas before closing brackets
- Unterminated strings (close them)
- Missing closing brackets/braces
- Control characters inside strings
"""
if not raw or not raw.strip():
return None
text = raw.strip()
# Remove any non-JSON prefix (model sometimes prepends text)
first_brace = text.find("{")
if first_brace > 0:
text = text[first_brace:]
elif first_brace < 0:
return None
# Remove control characters that break JSON parsing (except \n, \r, \t)
text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]", " ", text)
# Try parsing as-is first
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Fix trailing commas: ,] or ,}
text = re.sub(r",\s*([}\]])", r"\1", text)
try:
return json.loads(text)
except json.JSONDecodeError:
pass
# Try closing unclosed brackets/braces
return _repair_truncated_json(text)
def _repair_truncated_json(raw: str) -> dict[str, Any] | None:
"""Repair JSON that was truncated mid-output.
Walks the string tracking bracket/brace depth and string state,
then appends the necessary closing tokens.
"""
if not raw or not raw.strip():
return None
text = raw.strip()
first_brace = text.find("{")
if first_brace < 0:
return None
text = text[first_brace:]
# Track state
in_string = False
escape_next = False
stack: list[str] = []
for ch in text:
if escape_next:
escape_next = False
continue
if ch == "\\":
if in_string:
escape_next = True
continue
if ch == '"' and not escape_next:
in_string = not in_string
continue
if in_string:
continue
if ch == "{":
stack.append("}")
elif ch == "[":
stack.append("]")
elif ch in ("}", "]"):
if stack and stack[-1] == ch:
stack.pop()
# If we're inside a string, close it
if in_string:
text += '"'
# Remove any trailing comma
text = re.sub(r",\s*$", "", text)
# Close all open brackets/braces
while stack:
text += stack.pop()
try:
data = json.loads(text)
if isinstance(data, dict):
return data
except json.JSONDecodeError:
pass
return None
def validate_extraction(
raw_json: str | dict[str, Any],
*,
@@ -158,8 +261,19 @@ def validate_extraction(
if isinstance(raw_json, str):
try:
data = json.loads(raw_json)
except json.JSONDecodeError as exc:
return ValidationReport(valid=False, errors=[f"Invalid JSON: {exc}"])
except json.JSONDecodeError:
# Attempt repair before giving up
repaired = _repair_json(raw_json)
if repaired is not None:
data = repaired
else:
# Try one more time with aggressive truncation repair
repaired = _repair_truncated_json(raw_json)
if repaired is not None:
data = repaired
warnings.append("JSON was repaired from truncated output")
else:
return ValidationReport(valid=False, errors=["Invalid JSON: could not parse or repair"])
else:
data = raw_json
@@ -248,6 +362,16 @@ _HORIZON_MAP: dict[str, str] = {
def _normalize_extraction_data(data: dict[str, Any]) -> dict[str, Any]:
"""Fix common model output issues before Pydantic validation."""
# Fill missing top-level required fields with defaults
data.setdefault("summary", "")
data.setdefault("companies", [])
data.setdefault("macro_themes", [])
if "novelty_score" not in data:
data["novelty_score"] = 0.5
if "confidence" not in data:
data["confidence"] = 0.3
data.setdefault("extraction_warnings", ["incomplete_model_output"])
# Clamp novelty_score and confidence to [0, 1]
for field in ("novelty_score", "confidence"):
val = data.get(field)
@@ -260,6 +384,17 @@ def _normalize_extraction_data(data: dict[str, Any]) -> dict[str, Any]:
for comp in companies:
if not isinstance(comp, dict):
continue
# Fill missing required company fields with defaults
comp.setdefault("ticker", "")
comp.setdefault("company_name", "")
comp.setdefault("relevance", 0.5)
comp.setdefault("sentiment", "neutral")
comp.setdefault("impact_score", 0.5)
comp.setdefault("impact_horizon", "1d_30d")
comp.setdefault("catalyst_type", "other")
comp.setdefault("key_facts", [])
comp.setdefault("risks", [])
comp.setdefault("evidence_spans", [])
# Clamp numeric fields
for f in ("relevance", "impact_score"):
v = comp.get(f)