209 lines
7.5 KiB
Python
209 lines
7.5 KiB
Python
"""Replay dataset tests for deterministic extraction validation.

Loads archived document fixtures and validates that their expected
extraction outputs still pass the current schema and semantic checks.
This catches schema regressions, prompt contract changes, and
validation rule drift without requiring a live Ollama instance.

Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
"""
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from services.extractor.replay import (
|
|
FIXTURES_DIR,
|
|
compare_extraction,
|
|
load_all_fixtures,
|
|
load_fixture,
|
|
validate_all_fixtures,
|
|
validate_fixture,
|
|
)
|
|
from services.extractor.schemas import (
|
|
ExtractionResult,
|
|
get_schema_version,
|
|
validate_extraction,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Fixture loading
# ---------------------------------------------------------------------------
|
|
|
|
# Local alias for the fixtures directory imported from services.extractor.replay.
FIXTURE_DIR = FIXTURES_DIR
|
|
|
|
|
|
def _fixture_paths() -> list[Path]:
    """Return the ``*.json`` files directly inside ``FIXTURE_DIR``, sorted.

    Returns:
        A sorted list of fixture paths, or an empty list when
        ``FIXTURE_DIR`` is not an existing directory.
    """
    if FIXTURE_DIR.is_dir():
        return sorted(FIXTURE_DIR.glob("*.json"))
    return []
|
|
|
|
|
|
def test_fixtures_directory_exists():
    """Check that the fixtures directory is present and holds at least three JSON files."""
    assert FIXTURE_DIR.is_dir(), f"Missing fixtures dir: {FIXTURE_DIR}"
    found = len(_fixture_paths())
    assert found >= 3, f"Expected at least 3 fixtures, found {found}"
|
|
|
|
|
|
def test_load_all_fixtures():
    """Load every fixture and check that its core fields are populated."""
    loaded = load_all_fixtures()
    assert len(loaded) >= 3
    for fixture in loaded:
        # Each fixture must carry a non-empty id, document text, and
        # expected extraction payload.
        assert fixture.document_id
        assert fixture.document_text
        assert fixture.expected_extraction
|
|
|
|
|
|
def test_fixture_ids_unique():
    """Every fixture has a unique document_id.

    On failure the assertion message lists only the IDs that occur more
    than once, so the offending fixtures can be spotted without scanning
    the full list of loaded IDs.
    """
    # Function-scope import keeps this block self-contained.
    from collections import Counter

    id_counts = Counter(fixture.document_id for fixture in load_all_fixtures())
    duplicates = [doc_id for doc_id, count in id_counts.items() if count > 1]
    assert not duplicates, f"Duplicate fixture IDs: {duplicates}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Schema validation — the core deterministic test
# ---------------------------------------------------------------------------
|
|
|
|
def test_all_expected_extractions_pass_schema():
    """Validate every fixture's expected_extraction against the current schema.

    This is the main regression gate: a failure here means either the
    fixture is out of date or a schema change is breaking.
    """
    results = validate_all_fixtures()
    assert len(results) >= 3

    invalid = [res for res in results if not res.schema_valid]
    if not invalid:
        return

    def _describe(res):
        # Prefer the validation report's errors; otherwise fall back to the
        # result's own error text, or "unknown" when that is empty too.
        errs = res.validation_report.errors if res.validation_report else [res.error or "unknown"]
        return f" {res.fixture_id}: {errs}"

    pytest.fail(
        f"{len(invalid)} fixture(s) failed schema validation:\n"
        + "\n".join(_describe(res) for res in invalid)
    )
|
|
|
|
|
|
@pytest.mark.parametrize("fixture_path", _fixture_paths(), ids=lambda p: p.stem)
def test_individual_fixture_schema_valid(fixture_path: Path):
    """Validate one fixture file and check it reports the current schema version.

    Parametrized over every path returned by ``_fixture_paths()``, so each
    fixture shows up as its own test case named after the file stem.
    """
    fixture = load_fixture(fixture_path)
    outcome = validate_fixture(fixture)
    assert outcome.schema_valid, (
        f"Fixture {fixture.document_id} failed: "
        f"{outcome.validation_report.errors if outcome.validation_report else outcome.error}"
    )
    assert outcome.schema_version == get_schema_version()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Expected extraction structural checks
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.mark.parametrize("fixture_path", _fixture_paths(), ids=lambda p: p.stem)
def test_expected_extraction_roundtrips(fixture_path: Path):
    """Dump the expected result in JSON mode and validate it back into a model.

    The summary and the number of companies must survive the round trip.
    """
    original = load_fixture(fixture_path).expected_result
    restored = ExtractionResult.model_validate(original.model_dump(mode="json"))
    assert restored.summary == original.summary
    assert len(restored.companies) == len(original.companies)
|
|
|
|
|
|
def test_low_quality_fixture_has_empty_companies():
    """The single "low-quality" fixture expects no companies and confidence <= 0.3."""
    matches = [fx for fx in load_all_fixtures() if "low-quality" in fx.document_id]
    # Exactly one fixture id is expected to contain the "low-quality" marker.
    assert len(matches) == 1
    fixture = matches[0]
    assert len(fixture.expected_result.companies) == 0
    assert fixture.expected_result.confidence <= 0.3
|
|
|
|
|
|
def test_multi_company_fixture_has_multiple_tickers():
    """The single "multi-company" fixture expects at least three company entries."""
    matches = [fx for fx in load_all_fixtures() if "multi-company" in fx.document_id]
    # Exactly one fixture id is expected to contain the "multi-company" marker.
    assert len(matches) == 1
    expected_companies = matches[0].expected_result.companies
    tickers = [company.ticker for company in expected_companies]
    assert len(tickers) >= 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Evidence grounding checks
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.mark.parametrize("fixture_path", _fixture_paths(), ids=lambda p: p.stem)
def test_evidence_spans_grounded_in_document(fixture_path: Path):
    """Validating the expected extraction against its document yields no grounding warnings.

    A grounding warning is any entry of ``report.warnings`` that contains
    the substring ``evidence_span_not_found``.
    """
    fixture = load_fixture(fixture_path)
    report = validate_extraction(
        fixture.expected_extraction,
        document_text=fixture.document_text,
    )
    ungrounded = list(filter(lambda w: "evidence_span_not_found" in w, report.warnings))
    assert not ungrounded, (
        f"Fixture {fixture.document_id} has ungrounded evidence: {ungrounded}"
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Comparison logic tests (using synthetic data, no Ollama needed)
# ---------------------------------------------------------------------------
|
|
|
|
def test_compare_extraction_perfect_match():
    """Comparing a fixture against its own expected result matches on every flag."""
    fixture = load_all_fixtures()[0]
    # The "actual" result is the fixture's own expected result, i.e. identical.
    outcome = compare_extraction(fixture, fixture.expected_result)
    assert outcome.companies_match
    assert outcome.sentiment_match
    assert outcome.catalyst_match
    assert outcome.actual_schema_valid
|
|
|
|
|
|
def test_compare_extraction_company_mismatch():
    """An actual result with no companies is flagged as a company mismatch."""
    # Use the first fixture whose expected result lists at least one company.
    with_companies = [fx for fx in load_all_fixtures() if fx.expected_result.companies]
    fixture = with_companies[0]

    # Synthetic "actual" extraction that deliberately contains no companies.
    payload = {
        "summary": "Different",
        "companies": [],
        "macro_themes": [],
        "novelty_score": 0.5,
        "confidence": 0.5,
        "extraction_warnings": [],
    }
    empty_actual = ExtractionResult(**payload)

    outcome = compare_extraction(fixture, empty_actual)
    assert not outcome.companies_match
    assert any("company_mismatch" in w for w in outcome.warnings)
|
|
|
|
|
|
def test_compare_extraction_sentiment_mismatch():
    """Comparison detects sentiment drift.

    Takes the first fixture whose expected result has companies, flips the
    sentiment of its first company, and checks that the comparison still
    matches on companies but no longer matches on sentiment.
    """
    fixtures = load_all_fixtures()
    fixture = [f for f in fixtures if f.expected_result.companies][0]

    # Copy each company entry so that flipping the sentiment does not mutate
    # the fixture's own expected_extraction (assumed to be a plain dict of
    # JSON-like data, as the original .copy()/indexing implies).
    companies = [dict(c) for c in fixture.expected_extraction["companies"]]
    first = companies[0]
    first["sentiment"] = "positive" if first["sentiment"] == "negative" else "negative"

    # One shallow copy of the payload with the modified companies list swapped in.
    actual_data = {**fixture.expected_extraction, "companies": companies}

    actual = ExtractionResult.model_validate(actual_data)
    result = compare_extraction(fixture, actual)
    assert result.companies_match  # same tickers
    assert not result.sentiment_match  # different sentiment
|