"""Tests for extractor JSON schema definitions and validation.""" import json from services.extractor.schemas import ( SCHEMA_VERSION, VALID_IMPACT_HORIZONS, ExtractionResult, generate_json_schema, get_schema_version, validate_extraction, ) from services.shared.schemas import CatalystType, Sentiment def test_generate_json_schema_top_level_structure(): """Generated schema is a valid JSON Schema object with required fields.""" schema = generate_json_schema() assert schema["type"] == "object" assert "summary" in schema["required"] assert "companies" in schema["required"] assert "confidence" in schema["required"] assert "extraction_warnings" in schema["required"] def test_generate_json_schema_no_refs(): """Generated schema has no $ref or $defs — fully inlined.""" schema = generate_json_schema() serialized = json.dumps(schema) assert "$ref" not in serialized assert "$defs" not in serialized def test_generate_json_schema_serializable(): """Schema round-trips through JSON serialization.""" schema = generate_json_schema() text = json.dumps(schema) parsed = json.loads(text) assert parsed["type"] == "object" def test_generate_json_schema_company_properties(): """Company items include all required extraction fields.""" schema = generate_json_schema() company_schema = schema["properties"]["companies"]["items"] required = company_schema["required"] assert "ticker" in required assert "sentiment" in required assert "catalyst_type" in required assert "evidence_spans" in required def test_generate_json_schema_enum_values(): """Enum values in the schema match the Pydantic enum definitions.""" schema = generate_json_schema() company_props = schema["properties"]["companies"]["items"]["properties"] sentiment_vals = _extract_enum_values(company_props["sentiment"]) catalyst_vals = _extract_enum_values(company_props["catalyst_type"]) assert set(sentiment_vals) == {s.value for s in Sentiment} assert set(catalyst_vals) == {c.value for c in CatalystType} def test_get_schema_version(): assert get_schema_version() == SCHEMA_VERSION # --- Validation tests --- def _valid_extraction() -> dict: """Minimal valid extraction result.""" return { "summary": "Apple beat earnings expectations.", "companies": [ { "ticker": "AAPL", "company_name": "Apple Inc.", "relevance": 0.95, "sentiment": "positive", "impact_score": 0.7, "impact_horizon": "1d_30d", "catalyst_type": "earnings", "key_facts": ["Revenue up 12%"], "risks": [], "evidence_spans": ["Apple beat expectations"], } ], "macro_themes": ["ai_capex"], "novelty_score": 0.6, "confidence": 0.85, "extraction_warnings": [], } def test_validate_extraction_valid_dict(): report = validate_extraction(_valid_extraction()) assert report.valid assert report.parsed is not None assert report.parsed.companies[0].ticker == "AAPL" def test_validate_extraction_valid_json_string(): report = validate_extraction(json.dumps(_valid_extraction())) assert report.valid assert report.parsed is not None def test_validate_extraction_invalid_json(): report = validate_extraction("{bad json") assert not report.valid assert any("Invalid JSON" in e for e in report.errors) def test_validate_extraction_not_object(): report = validate_extraction("[1, 2, 3]") assert not report.valid assert any("object" in e.lower() for e in report.errors) def test_validate_extraction_missing_required_field(): data = _valid_extraction() del data["summary"] report = validate_extraction(data) # Normalization fills missing summary with "" — validation passes but warns assert report.valid assert "empty_summary" in report.warnings def test_validate_extraction_invalid_enum(): data = _valid_extraction() data["companies"][0]["sentiment"] = "super_bullish" report = validate_extraction(data) assert not report.valid def test_validate_extraction_out_of_range(): data = _valid_extraction() data["confidence"] = 1.5 report = validate_extraction(data) # Normalization clamps confidence to [0, 1] — validation passes assert report.valid assert report.parsed is not None assert report.parsed.confidence == 1.0 def test_validate_semantic_empty_summary_warning(): data = _valid_extraction() data["summary"] = "" report = validate_extraction(data) assert report.valid assert "empty_summary" in report.warnings def test_validate_semantic_low_confidence_with_companies(): data = _valid_extraction() data["confidence"] = 0.2 report = validate_extraction(data) assert report.valid assert "low_confidence_with_companies" in report.warnings def test_validate_semantic_no_evidence_spans(): data = _valid_extraction() data["companies"][0]["evidence_spans"] = [] report = validate_extraction(data) assert report.valid assert any("no_evidence_spans" in w for w in report.warnings) def test_validate_semantic_high_impact_no_facts(): data = _valid_extraction() data["companies"][0]["key_facts"] = [] data["companies"][0]["impact_score"] = 0.8 report = validate_extraction(data) assert report.valid assert any("high_impact_no_facts" in w for w in report.warnings) def test_extraction_result_model_roundtrip(): """ExtractionResult can be created and serialized back to dict.""" result = ExtractionResult( summary="Test", companies=[], macro_themes=[], novelty_score=0.5, confidence=0.5, extraction_warnings=[], ) data = result.model_dump() assert data["summary"] == "Test" reparsed = ExtractionResult.model_validate(data) assert reparsed.summary == "Test" def test_all_top_level_fields_required(): """All top-level fields appear in the schema's required list.""" schema = generate_json_schema() required = set(schema["required"]) expected = {"summary", "companies", "macro_themes", "novelty_score", "confidence", "extraction_warnings"} assert expected == required def test_all_company_fields_required(): """All company item fields appear in the schema's required list.""" schema = generate_json_schema() company_required = set(schema["properties"]["companies"]["items"]["required"]) expected = { "ticker", "company_name", "relevance", "sentiment", "impact_score", "impact_horizon", "catalyst_type", "key_facts", "risks", "evidence_spans", } assert expected == company_required # --- Semantic validation: error-level checks --- def test_validate_semantic_missing_ticker_is_error(): """A company with an empty ticker produces a semantic error, not just a warning.""" data = _valid_extraction() data["companies"][0]["ticker"] = "" report = validate_extraction(data) assert not report.valid assert any("company_missing_ticker" in e for e in report.errors) def test_validate_semantic_invalid_impact_horizon_is_error(): """An unrecognized impact_horizon is normalized to a valid default.""" data = _valid_extraction() data["companies"][0]["impact_horizon"] = "forever" report = validate_extraction(data) # Normalization maps unknown horizons to "1d_30d" — validation passes assert report.valid assert report.parsed is not None assert report.parsed.companies[0].impact_horizon == "1d_30d" def test_validate_semantic_all_valid_horizons_accepted(): """Every value in VALID_IMPACT_HORIZONS passes validation.""" for horizon in VALID_IMPACT_HORIZONS: data = _valid_extraction() data["companies"][0]["impact_horizon"] = horizon report = validate_extraction(data) assert report.valid, f"Horizon {horizon!r} should be valid" def test_validate_semantic_duplicate_ticker_is_error(): """Two company entries with the same ticker produce a semantic error.""" data = _valid_extraction() second = dict(data["companies"][0]) data["companies"].append(second) report = validate_extraction(data) assert not report.valid assert any("duplicate_ticker_AAPL" in e for e in report.errors) # --- Semantic validation: warning-level checks --- def test_validate_semantic_invalid_ticker_format_warning(): """A lowercase or overly long ticker produces a warning.""" data = _valid_extraction() data["companies"][0]["ticker"] = "aapl" report = validate_extraction(data) assert report.valid # warning, not error assert any("invalid_ticker_format" in w for w in report.warnings) def test_validate_semantic_evidence_span_too_short(): data = _valid_extraction() data["companies"][0]["evidence_spans"] = ["short"] report = validate_extraction(data) assert report.valid assert any("evidence_span_too_short" in w for w in report.warnings) def test_validate_semantic_evidence_span_too_long(): data = _valid_extraction() data["companies"][0]["evidence_spans"] = ["x" * 501] report = validate_extraction(data) assert report.valid assert any("evidence_span_too_long" in w for w in report.warnings) def test_validate_semantic_strong_sentiment_low_impact(): data = _valid_extraction() data["companies"][0]["sentiment"] = "positive" data["companies"][0]["impact_score"] = 0.05 report = validate_extraction(data) assert report.valid assert any("strong_sentiment_low_impact" in w for w in report.warnings) # --- Evidence grounding --- def test_validate_evidence_grounding_found(): """Evidence spans present in document_text produce no grounding warnings.""" data = _valid_extraction() doc_text = "Apple beat expectations with record revenue." report = validate_extraction(data, document_text=doc_text) assert report.valid assert not any("evidence_span_not_found" in w for w in report.warnings) def test_validate_evidence_grounding_not_found(): """Evidence spans NOT in document_text produce a grounding warning.""" data = _valid_extraction() doc_text = "Completely unrelated document about weather." report = validate_extraction(data, document_text=doc_text) assert report.valid assert any("evidence_span_not_found" in w for w in report.warnings) # --- Helpers --- def _extract_enum_values(prop: dict) -> list: """Extract enum values from a JSON schema property, handling anyOf patterns.""" if "enum" in prop: return prop["enum"] for option in prop.get("anyOf", []): if "enum" in option: return option["enum"] return []