Files
stonks-oracle/tests/test_replay_extraction.py
T

209 lines
7.5 KiB
Python

"""Replay dataset tests for deterministic extraction validation.
Loads archived document fixtures and validates that their expected
extraction outputs still pass the current schema and semantic checks.
This catches schema regressions, prompt contract changes, and
validation rule drift without requiring a live Ollama instance.
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
"""
from __future__ import annotations
from pathlib import Path
import pytest
from services.extractor.replay import (
FIXTURES_DIR,
compare_extraction,
load_all_fixtures,
load_fixture,
validate_all_fixtures,
validate_fixture,
)
from services.extractor.schemas import (
ExtractionResult,
get_schema_version,
validate_extraction,
)
# ---------------------------------------------------------------------------
# Fixture loading
# ---------------------------------------------------------------------------
# Module-local alias for the replay fixtures directory exported by
# services.extractor.replay; the path-based helpers below glob it directly.
FIXTURE_DIR: Path = FIXTURES_DIR
def _fixture_paths() -> list[Path]:
    """Return the sorted ``*.json`` files in FIXTURE_DIR.

    Returns an empty list (rather than raising) when the directory does not
    exist, so module-level ``parametrize`` calls still collect cleanly.
    """
    if FIXTURE_DIR.is_dir():
        return sorted(FIXTURE_DIR.glob("*.json"))
    return []
def test_fixtures_directory_exists():
    """FIXTURE_DIR exists and holds at least three ``*.json`` fixtures."""
    assert FIXTURE_DIR.is_dir(), f"Missing fixtures dir: {FIXTURE_DIR}"
    found = len(_fixture_paths())
    assert found >= 3, f"Expected at least 3 fixtures, found {found}"
def test_load_all_fixtures():
    """load_all_fixtures() yields at least three fully populated fixtures."""
    loaded = load_all_fixtures()
    assert len(loaded) >= 3
    for fixture in loaded:
        # Each fixture needs an ID, its source text, and the expected output.
        assert fixture.document_id
        assert fixture.document_text
        assert fixture.expected_extraction
def test_fixture_ids_unique():
    """Every fixture has a unique document_id.

    On failure the message lists only the IDs that occur more than once,
    together with their occurrence counts, instead of the full ID list.
    """
    from collections import Counter

    counts = Counter(f.document_id for f in load_all_fixtures())
    duplicates = {doc_id: n for doc_id, n in counts.items() if n > 1}
    assert not duplicates, f"Duplicate fixture IDs: {duplicates}"
# ---------------------------------------------------------------------------
# Schema validation — the core deterministic test
# ---------------------------------------------------------------------------
def test_all_expected_extractions_pass_schema():
    """Regression gate: every fixture's expected_extraction passes the current schema.

    A failure here means either the fixture is stale and needs updating, or
    the schema change is breaking for previously valid outputs.
    """
    results = validate_all_fixtures()
    assert len(results) >= 3

    lines = []
    for outcome in results:
        if outcome.schema_valid:
            continue
        # Prefer the structured report; fall back to the raw error string.
        if outcome.validation_report:
            errs = outcome.validation_report.errors
        else:
            errs = [outcome.error or "unknown"]
        lines.append(f" {outcome.fixture_id}: {errs}")

    if lines:
        pytest.fail(
            f"{len(lines)} fixture(s) failed schema validation:\n" + "\n".join(lines)
        )
@pytest.mark.parametrize("fixture_path", _fixture_paths(), ids=lambda p: p.stem)
def test_individual_fixture_schema_valid(fixture_path: Path):
    """A single fixture validates on its own and reports the current schema version."""
    fixture = load_fixture(fixture_path)
    outcome = validate_fixture(fixture)
    # The message is only built if the assertion fails.
    assert outcome.schema_valid, (
        f"Fixture {fixture.document_id} failed: "
        f"{outcome.validation_report.errors if outcome.validation_report else outcome.error}"
    )
    assert outcome.schema_version == get_schema_version()
# ---------------------------------------------------------------------------
# Expected extraction structural checks
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("fixture_path", _fixture_paths(), ids=lambda p: p.stem)
def test_expected_extraction_roundtrips(fixture_path: Path):
    """Expected extraction survives a JSON dump -> validate -> dump round trip.

    Besides the quick summary/company-count checks, the re-dumped JSON must
    equal the first dump, so any field dropped or altered by serialization
    is caught rather than only ``summary`` and the number of companies.
    """
    fixture = load_fixture(fixture_path)
    parsed = fixture.expected_result
    dumped = parsed.model_dump(mode="json")
    reparsed = ExtractionResult.model_validate(dumped)
    assert reparsed.summary == parsed.summary
    assert len(reparsed.companies) == len(parsed.companies)
    # Full-fidelity check: serializing the reparsed model reproduces the dump.
    assert reparsed.model_dump(mode="json") == dumped, (
        f"Fixture {fixture.document_id} changed after a JSON round trip"
    )
def test_low_quality_fixture_has_empty_companies():
    """The single garbled low-quality fixture yields no companies and low confidence."""
    matches = [
        fx for fx in load_all_fixtures() if "low-quality" in fx.document_id
    ]
    assert len(matches) == 1
    (fixture,) = matches
    expected = fixture.expected_result
    assert len(expected.companies) == 0
    assert expected.confidence <= 0.3
def test_multi_company_fixture_has_multiple_tickers():
    """The multi-company fixture references at least three distinct tickers.

    Tickers are de-duplicated before counting, so the same company listed
    several times does not satisfy the check.
    """
    fixtures = load_all_fixtures()
    multi = [f for f in fixtures if "multi-company" in f.document_id]
    assert len(multi) == 1
    fixture = multi[0]
    tickers = {c.ticker for c in fixture.expected_result.companies}
    assert len(tickers) >= 3, (
        f"Fixture {fixture.document_id} has only {len(tickers)} distinct tickers: {tickers}"
    )
# ---------------------------------------------------------------------------
# Evidence grounding checks
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("fixture_path", _fixture_paths(), ids=lambda p: p.stem)
def test_evidence_spans_grounded_in_document(fixture_path: Path):
    """No evidence span in an expected extraction is missing from its document text."""
    fixture = load_fixture(fixture_path)
    report_warnings = validate_extraction(
        fixture.expected_extraction, document_text=fixture.document_text
    ).warnings
    # Only grounding problems matter here; other warnings are ignored.
    ungrounded = list(
        filter(lambda msg: "evidence_span_not_found" in msg, report_warnings)
    )
    assert not ungrounded, (
        f"Fixture {fixture.document_id} has ungrounded evidence: {ungrounded}"
    )
# ---------------------------------------------------------------------------
# Comparison logic tests (using synthetic data, no Ollama needed)
# ---------------------------------------------------------------------------
def test_compare_extraction_perfect_match():
    """Comparing a fixture with its own expected result reports a full match."""
    first = load_all_fixtures()[0]
    # Feed the expected result back in as the "actual" output: a perfect replay.
    outcome = compare_extraction(first, first.expected_result)
    assert outcome.companies_match
    assert outcome.sentiment_match
    assert outcome.catalyst_match
    assert outcome.actual_schema_valid
def test_compare_extraction_company_mismatch():
    """An actual result with no companies is flagged as a company mismatch."""
    with_companies = [
        fx for fx in load_all_fixtures() if fx.expected_result.companies
    ]
    fixture = with_companies[0]
    # Minimal valid result that deliberately lists no companies at all.
    empty_actual = ExtractionResult(
        summary="Different",
        companies=[],
        macro_themes=[],
        novelty_score=0.5,
        confidence=0.5,
        extraction_warnings=[],
    )
    outcome = compare_extraction(fixture, empty_actual)
    assert not outcome.companies_match
    assert any("company_mismatch" in w for w in outcome.warnings)
def test_compare_extraction_sentiment_mismatch():
    """Comparison detects sentiment drift while the tickers still match.

    The expected extraction is copied (a new top-level dict plus a copy of
    each company dict) before editing, so the loaded fixture data is not
    mutated by this test.
    """
    fixtures = load_all_fixtures()
    fixture = [f for f in fixtures if f.expected_result.companies][0]
    expected = fixture.expected_extraction
    companies = [dict(c) for c in expected["companies"]]
    # Flip the first company's sentiment to a value guaranteed to differ.
    current = companies[0]["sentiment"]
    companies[0]["sentiment"] = "positive" if current == "negative" else "negative"
    actual = ExtractionResult.model_validate({**expected, "companies": companies})
    result = compare_extraction(fixture, actual)
    assert result.companies_match  # same tickers
    assert not result.sentiment_match  # different sentiment