fix: fill default values for missing fields in truncated LLM output
This commit is contained in:
@@ -132,6 +132,109 @@ class ValidationReport(BaseModel):
|
||||
parsed: ExtractionResult | None = None
|
||||
|
||||
|
||||
def _repair_json(raw: str) -> dict[str, Any] | None:
|
||||
"""Attempt to repair common JSON malformations from LLM output.
|
||||
|
||||
Handles:
|
||||
- Trailing commas before closing brackets
|
||||
- Unterminated strings (close them)
|
||||
- Missing closing brackets/braces
|
||||
- Control characters inside strings
|
||||
"""
|
||||
if not raw or not raw.strip():
|
||||
return None
|
||||
|
||||
text = raw.strip()
|
||||
|
||||
# Remove any non-JSON prefix (model sometimes prepends text)
|
||||
first_brace = text.find("{")
|
||||
if first_brace > 0:
|
||||
text = text[first_brace:]
|
||||
elif first_brace < 0:
|
||||
return None
|
||||
|
||||
# Remove control characters that break JSON parsing (except \n, \r, \t)
|
||||
text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]", " ", text)
|
||||
|
||||
# Try parsing as-is first
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Fix trailing commas: ,] or ,}
|
||||
text = re.sub(r",\s*([}\]])", r"\1", text)
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try closing unclosed brackets/braces
|
||||
return _repair_truncated_json(text)
|
||||
|
||||
|
||||
def _repair_truncated_json(raw: str) -> dict[str, Any] | None:
|
||||
"""Repair JSON that was truncated mid-output.
|
||||
|
||||
Walks the string tracking bracket/brace depth and string state,
|
||||
then appends the necessary closing tokens.
|
||||
"""
|
||||
if not raw or not raw.strip():
|
||||
return None
|
||||
|
||||
text = raw.strip()
|
||||
first_brace = text.find("{")
|
||||
if first_brace < 0:
|
||||
return None
|
||||
text = text[first_brace:]
|
||||
|
||||
# Track state
|
||||
in_string = False
|
||||
escape_next = False
|
||||
stack: list[str] = []
|
||||
|
||||
for ch in text:
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
if ch == "\\":
|
||||
if in_string:
|
||||
escape_next = True
|
||||
continue
|
||||
if ch == '"' and not escape_next:
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if ch == "{":
|
||||
stack.append("}")
|
||||
elif ch == "[":
|
||||
stack.append("]")
|
||||
elif ch in ("}", "]"):
|
||||
if stack and stack[-1] == ch:
|
||||
stack.pop()
|
||||
|
||||
# If we're inside a string, close it
|
||||
if in_string:
|
||||
text += '"'
|
||||
|
||||
# Remove any trailing comma
|
||||
text = re.sub(r",\s*$", "", text)
|
||||
|
||||
# Close all open brackets/braces
|
||||
while stack:
|
||||
text += stack.pop()
|
||||
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def validate_extraction(
|
||||
raw_json: str | dict[str, Any],
|
||||
*,
|
||||
@@ -158,8 +261,19 @@ def validate_extraction(
|
||||
if isinstance(raw_json, str):
|
||||
try:
|
||||
data = json.loads(raw_json)
|
||||
except json.JSONDecodeError as exc:
|
||||
return ValidationReport(valid=False, errors=[f"Invalid JSON: {exc}"])
|
||||
except json.JSONDecodeError:
|
||||
# Attempt repair before giving up
|
||||
repaired = _repair_json(raw_json)
|
||||
if repaired is not None:
|
||||
data = repaired
|
||||
else:
|
||||
# Try one more time with aggressive truncation repair
|
||||
repaired = _repair_truncated_json(raw_json)
|
||||
if repaired is not None:
|
||||
data = repaired
|
||||
warnings.append("JSON was repaired from truncated output")
|
||||
else:
|
||||
return ValidationReport(valid=False, errors=["Invalid JSON: could not parse or repair"])
|
||||
else:
|
||||
data = raw_json
|
||||
|
||||
@@ -248,6 +362,16 @@ _HORIZON_MAP: dict[str, str] = {
|
||||
|
||||
def _normalize_extraction_data(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Fix common model output issues before Pydantic validation."""
|
||||
# Fill missing top-level required fields with defaults
|
||||
data.setdefault("summary", "")
|
||||
data.setdefault("companies", [])
|
||||
data.setdefault("macro_themes", [])
|
||||
if "novelty_score" not in data:
|
||||
data["novelty_score"] = 0.5
|
||||
if "confidence" not in data:
|
||||
data["confidence"] = 0.3
|
||||
data.setdefault("extraction_warnings", ["incomplete_model_output"])
|
||||
|
||||
# Clamp novelty_score and confidence to [0, 1]
|
||||
for field in ("novelty_score", "confidence"):
|
||||
val = data.get(field)
|
||||
@@ -260,6 +384,17 @@ def _normalize_extraction_data(data: dict[str, Any]) -> dict[str, Any]:
|
||||
for comp in companies:
|
||||
if not isinstance(comp, dict):
|
||||
continue
|
||||
# Fill missing required company fields with defaults
|
||||
comp.setdefault("ticker", "")
|
||||
comp.setdefault("company_name", "")
|
||||
comp.setdefault("relevance", 0.5)
|
||||
comp.setdefault("sentiment", "neutral")
|
||||
comp.setdefault("impact_score", 0.5)
|
||||
comp.setdefault("impact_horizon", "1d_30d")
|
||||
comp.setdefault("catalyst_type", "other")
|
||||
comp.setdefault("key_facts", [])
|
||||
comp.setdefault("risks", [])
|
||||
comp.setdefault("evidence_spans", [])
|
||||
# Clamp numeric fields
|
||||
for f in ("relevance", "impact_score"):
|
||||
v = comp.get(f)
|
||||
|
||||
Reference in New Issue
Block a user