Files

251 lines
8.0 KiB
Python

"""Replay dataset loader and runner for deterministic extraction testing.
Loads archived document fixtures from JSON files, validates their expected
extraction outputs against the current schema, and provides a runner that
can compare live Ollama extraction results against expected baselines.
This enables:
- Schema regression testing: verify expected outputs still pass validation
- Prompt regression testing: detect drift when prompts or schemas change
- End-to-end replay: run fixtures through a live Ollama instance and compare
  the extraction output against the expected baseline
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from services.extractor.schemas import (
ExtractionResult,
ValidationReport,
get_schema_version,
validate_extraction,
)
# Named logger shared by every replay helper in this module.
logger = logging.getLogger("extractor_replay")

# Default fixture location: three directory levels above this file's own
# directory entry, then tests/replay_fixtures.
FIXTURES_DIR = (
    Path(__file__)
    .resolve()
    .parent.parent.parent
    .joinpath("tests", "replay_fixtures")
)
@dataclass
class ReplayFixture:
    """One archived document plus its expected extraction, as read from disk.

    Attributes:
        document_id: Value of the fixture's ``document_id`` key; reused as
            the ``fixture_id`` of validation and comparison results.
        document_type: Value of the fixture's ``document_type`` key.
        document_text: Raw document text; handed to ``validate_extraction``
            as ``document_text``.
        known_tickers: Value of the optional ``known_tickers`` key (the
            loader substitutes an empty list when the key is absent).
        expected_extraction: Raw expected extraction mapping, exactly as
            stored in the fixture file.
        metadata: Value of the optional ``metadata`` key (the loader
            substitutes an empty dict when the key is absent).
        source_path: String form of the path the fixture was loaded from;
            empty when the instance was not built by the loader.
    """

    document_id: str
    document_type: str
    document_text: str
    known_tickers: list[str]
    expected_extraction: dict[str, Any]
    metadata: dict[str, str]
    source_path: str = ""

    @property
    def expected_result(self) -> ExtractionResult:
        """Return ``expected_extraction`` parsed into an ExtractionResult.

        The model is re-validated on every access; nothing is cached.
        """
        raw_expected = self.expected_extraction
        return ExtractionResult.model_validate(raw_expected)
@dataclass
class ReplayValidationResult:
    """Outcome of checking one fixture's expected output against the schema.

    Attributes:
        fixture_id: ``document_id`` of the fixture that was validated.
        schema_valid: Mirrors the validation report's ``valid`` flag; stays
            False when validation raised.
        validation_report: Report returned by ``validate_extraction``, or
            None when no report was produced.
        schema_version: Schema version string recorded at validation time.
        error: ``str()`` of the exception raised during validation, if any.
    """

    fixture_id: str

    # Verdict and the report it was derived from.
    schema_valid: bool = False
    validation_report: ValidationReport | None = None

    # Context for diagnosing a failure.
    schema_version: str = ""
    error: str | None = None
@dataclass
class ReplayComparisonResult:
    """Side-by-side summary of a live extraction versus a fixture baseline.

    Every ``expected_*`` field is derived from the fixture's expected
    output and every ``actual_*`` field from the live result. The
    ``*_match`` flags record whether the corresponding pair agreed.
    """

    fixture_id: str

    # Tickers detected on each side, stored as sorted lists.
    expected_companies: list[str] = field(default_factory=list)
    actual_companies: list[str] = field(default_factory=list)
    companies_match: bool = False

    # Ticker -> sentiment on each side.
    expected_sentiment_map: dict[str, str] = field(default_factory=dict)
    actual_sentiment_map: dict[str, str] = field(default_factory=dict)
    sentiment_match: bool = False

    # Ticker -> catalyst type on each side.
    expected_catalyst_map: dict[str, str] = field(default_factory=dict)
    actual_catalyst_map: dict[str, str] = field(default_factory=dict)
    catalyst_match: bool = False

    # Whether the live result itself passed validation, plus any warnings
    # gathered while comparing.
    actual_schema_valid: bool = False
    warnings: list[str] = field(default_factory=list)
def load_fixture(path: Path) -> ReplayFixture:
    """Load a single replay fixture from a JSON file.

    The file is always decoded as UTF-8 so that a fixture loads the same
    way regardless of the platform's locale-dependent default encoding.

    Args:
        path: Path to the fixture JSON file.

    Returns:
        A ReplayFixture with all fields populated. ``known_tickers`` and
        ``metadata`` fall back to an empty list / dict when the key is
        absent or explicitly ``null``.

    Raises:
        ValueError: If the top-level JSON value is not an object, or the
            fixture is missing required fields.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    # Reject non-object payloads up front. Without this, a top-level list or
    # scalar fails below with AttributeError/TypeError, which the bulk loader
    # does not treat as a skippable "invalid fixture" error.
    if not isinstance(data, dict):
        raise ValueError(
            f"Fixture {path.name} must contain a JSON object, "
            f"got {type(data).__name__}"
        )

    required = {"document_id", "document_type", "document_text", "expected_extraction"}
    missing = required - set(data)
    if missing:
        raise ValueError(f"Fixture {path.name} missing required fields: {missing}")

    return ReplayFixture(
        document_id=data["document_id"],
        document_type=data["document_type"],
        document_text=data["document_text"],
        # `or` also covers an explicit JSON null, keeping the declared types.
        known_tickers=data.get("known_tickers") or [],
        expected_extraction=data["expected_extraction"],
        metadata=data.get("metadata") or {},
        source_path=str(path),
    )
def load_all_fixtures(fixtures_dir: Path | None = None) -> list[ReplayFixture]:
    """Load all replay fixtures from the fixtures directory.

    Every ``*.json`` file directly inside the directory is loaded. Files
    that fail to load with a ValueError (including invalid JSON) are
    logged at warning level and skipped rather than aborting the load.

    Args:
        fixtures_dir: Override path to fixtures directory.
            Defaults to tests/replay_fixtures/.

    Returns:
        List of loaded ReplayFixture objects, sorted by document_id.
        Empty when the directory does not exist.
    """
    directory = fixtures_dir or FIXTURES_DIR
    if not directory.is_dir():
        logger.warning("Fixtures directory not found: %s", directory)
        return []

    fixtures: list[ReplayFixture] = []
    # Visit files in path order so warnings are emitted deterministically.
    for path in sorted(directory.glob("*.json")):
        try:
            fixtures.append(load_fixture(path))
        except (ValueError, json.JSONDecodeError) as exc:
            logger.warning("Skipping invalid fixture %s: %s", path.name, exc)

    # File-name order is not guaranteed to equal document_id order, so sort
    # explicitly to honour the documented contract. The sort is stable, so
    # fixtures sharing an id keep their file-name order; str() keeps the
    # comparison safe if a fixture stored a non-string id.
    fixtures.sort(key=lambda fixture: str(fixture.document_id))

    logger.info("Loaded %d replay fixtures from %s", len(fixtures), directory)
    return fixtures
def validate_fixture(fixture: ReplayFixture) -> ReplayValidationResult:
    """Check a fixture's expected extraction against the current schema.

    This is the deterministic regression check: the stored expected output
    is run through ``validate_extraction`` together with the fixture's
    document text. A failure means either the fixture is stale or the
    schema has regressed.

    Args:
        fixture: The replay fixture to validate.

    Returns:
        A ReplayValidationResult carrying the report and its ``valid``
        flag, or the stringified exception if validation itself raised.
    """
    outcome = ReplayValidationResult(
        fixture_id=fixture.document_id,
        schema_version=get_schema_version(),
    )
    try:
        outcome.validation_report = validate_extraction(
            fixture.expected_extraction,
            document_text=fixture.document_text,
        )
        outcome.schema_valid = outcome.validation_report.valid
    except Exception as failure:  # noqa: BLE001
        # Any validator crash is recorded as a failed fixture rather than
        # propagated, so one bad fixture cannot abort a whole run.
        outcome.error = str(failure)
        outcome.schema_valid = False
    return outcome
def validate_all_fixtures(
    fixtures_dir: Path | None = None,
) -> list[ReplayValidationResult]:
    """Load every fixture and validate each against the current schema.

    Args:
        fixtures_dir: Override path to fixtures directory.

    Returns:
        One validation result per loaded fixture, in load order.
    """
    results: list[ReplayValidationResult] = []
    for loaded in load_all_fixtures(fixtures_dir):
        results.append(validate_fixture(loaded))
    return results
def compare_extraction(
    fixture: ReplayFixture,
    actual_result: ExtractionResult,
) -> ReplayComparisonResult:
    """Compare a live extraction result against the fixture's expected output.

    Checks structural alignment (same companies detected, same sentiments,
    same catalyst types) rather than exact string equality, since LLM
    outputs vary in wording across runs.

    Args:
        fixture: The replay fixture with expected output.
        actual_result: The ExtractionResult from a live extraction.

    Returns:
        A ReplayComparisonResult with match details. Its ``warnings`` list
        holds a copy of the validation warnings for the live result, plus a
        ``company_mismatch`` entry when the ticker sets differ.
    """
    expected = fixture.expected_result
    comparison = ReplayComparisonResult(fixture_id=fixture.document_id)

    # Company tickers: stored as sorted lists for readable output, but
    # compared as sets so duplicates and ordering do not affect the verdict.
    comparison.expected_companies = sorted(c.ticker for c in expected.companies)
    comparison.actual_companies = sorted(c.ticker for c in actual_result.companies)
    comparison.companies_match = (
        set(comparison.expected_companies) == set(comparison.actual_companies)
    )

    # Sentiment by ticker
    comparison.expected_sentiment_map = {
        c.ticker: c.sentiment for c in expected.companies
    }
    comparison.actual_sentiment_map = {
        c.ticker: c.sentiment for c in actual_result.companies
    }
    comparison.sentiment_match = (
        comparison.expected_sentiment_map == comparison.actual_sentiment_map
    )

    # Catalyst type by ticker
    comparison.expected_catalyst_map = {
        c.ticker: c.catalyst_type for c in expected.companies
    }
    comparison.actual_catalyst_map = {
        c.ticker: c.catalyst_type for c in actual_result.companies
    }
    comparison.catalyst_match = (
        comparison.expected_catalyst_map == comparison.actual_catalyst_map
    )

    # Schema validity of the live result, checked against the same text.
    actual_report = validate_extraction(
        actual_result.model_dump(mode="json"),
        document_text=fixture.document_text,
    )
    comparison.actual_schema_valid = actual_report.valid
    if actual_report.warnings:
        # Copy instead of aliasing: the mismatch note appended below must
        # not leak into the validation report's own warnings list.
        comparison.warnings = list(actual_report.warnings)

    if not comparison.companies_match:
        comparison.warnings.append(
            f"company_mismatch: expected={comparison.expected_companies} actual={comparison.actual_companies}"
        )
    return comparison