251 lines
8.0 KiB
Python
251 lines
8.0 KiB
Python
"""Replay dataset loader and runner for deterministic extraction testing.
|
|
|
|
Loads archived document fixtures from JSON files, validates their expected
|
|
extraction outputs against the current schema, and provides a runner that
|
|
can compare live Ollama extraction results against expected baselines.
|
|
|
|
This enables:
|
|
- Schema regression testing: verify expected outputs still pass validation
|
|
- Prompt regression testing: detect drift when prompts or schemas change
|
|
- End-to-end replay: run fixtures through a live Ollama and compare
|
|
|
|
Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from services.extractor.schemas import (
|
|
ExtractionResult,
|
|
ValidationReport,
|
|
get_schema_version,
|
|
validate_extraction,
|
|
)
|
|
|
|
# Module-level logger. The name is a fixed string, not derived from __name__.
logger = logging.getLogger("extractor_replay")

# Default fixture directory: three directory levels above this file, then
# tests/replay_fixtures.
FIXTURES_DIR = (
    Path(__file__).resolve().parent.parent.parent.joinpath("tests", "replay_fixtures")
)
|
|
|
|
|
|
@dataclass
class ReplayFixture:
    """One archived document plus the extraction output expected for it."""

    # Identifier of the document; also used as the fixture id in results.
    document_id: str
    # Document type label as stored in the fixture file.
    document_type: str
    # Full document text that is fed to validation / extraction.
    document_text: str
    # Tickers recorded alongside the document in the fixture.
    known_tickers: list[str]
    # Raw (unparsed) expected extraction payload.
    expected_extraction: dict[str, Any]
    # Free-form metadata recorded in the fixture.
    metadata: dict[str, str]
    # String form of the file the fixture came from; empty when not set.
    source_path: str = ""

    @property
    def expected_result(self) -> ExtractionResult:
        """Return ``expected_extraction`` parsed as an ``ExtractionResult``.

        The raw dict is validated again on every access; nothing is cached.
        """
        raw_expected = self.expected_extraction
        return ExtractionResult.model_validate(raw_expected)
|
|
|
|
|
|
@dataclass
class ReplayValidationResult:
    """Outcome of checking one fixture's expected output against the schema."""

    # Id of the fixture this result belongs to.
    fixture_id: str
    # True only when validation ran and reported the payload as valid.
    schema_valid: bool = False
    # Report returned by validation; stays None when validation raised.
    validation_report: ValidationReport | None = None
    # Schema version that was current when the check ran.
    schema_version: str = ""
    # String form of the exception raised during validation, if any.
    error: str | None = None
|
|
|
|
|
|
@dataclass
class ReplayComparisonResult:
    """Side-by-side comparison of a live extraction with a fixture baseline."""

    # Id of the fixture that supplied the baseline.
    fixture_id: str

    # --- companies: sorted ticker lists and whether their sets agree ---
    expected_companies: list[str] = field(default_factory=list)
    actual_companies: list[str] = field(default_factory=list)
    companies_match: bool = False

    # --- sentiment: ticker -> sentiment, and whether the maps are equal ---
    expected_sentiment_map: dict[str, str] = field(default_factory=dict)
    actual_sentiment_map: dict[str, str] = field(default_factory=dict)
    sentiment_match: bool = False

    # --- catalyst: ticker -> catalyst type, and whether the maps are equal ---
    expected_catalyst_map: dict[str, str] = field(default_factory=dict)
    actual_catalyst_map: dict[str, str] = field(default_factory=dict)
    catalyst_match: bool = False

    # Whether the live result itself passed validation.
    actual_schema_valid: bool = False
    # Validation warnings plus any mismatch notes added during comparison.
    warnings: list[str] = field(default_factory=list)
|
|
|
|
|
|
def load_fixture(path: Path) -> ReplayFixture:
    """Load a single replay fixture from a JSON file.

    Args:
        path: Path to the fixture JSON file.

    Returns:
        A ReplayFixture built from the file. ``known_tickers`` and
        ``metadata`` fall back to an empty list / dict when the keys are
        absent, and ``source_path`` is set to ``str(path)``.

    Raises:
        ValueError: If the top-level JSON value is not an object, or the
            fixture is missing required fields.
        json.JSONDecodeError: If the file is not valid JSON.
        OSError: If the file cannot be opened or read.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    # Valid JSON may still be a list, string, number or null; report that
    # as a ValueError instead of failing later with an AttributeError.
    if not isinstance(data, dict):
        raise ValueError(
            f"Fixture {path.name} must contain a JSON object, "
            f"got {type(data).__name__}"
        )

    required = {"document_id", "document_type", "document_text", "expected_extraction"}
    # Sorted so the error message is deterministic across runs.
    missing = sorted(required.difference(data))
    if missing:
        raise ValueError(f"Fixture {path.name} missing required fields: {missing}")

    return ReplayFixture(
        document_id=data["document_id"],
        document_type=data["document_type"],
        document_text=data["document_text"],
        known_tickers=data.get("known_tickers", []),
        expected_extraction=data["expected_extraction"],
        metadata=data.get("metadata", {}),
        source_path=str(path),
    )
|
|
|
|
|
|
def load_all_fixtures(fixtures_dir: Path | None = None) -> list[ReplayFixture]:
    """Load all replay fixtures from the fixtures directory.

    Loading is best-effort: a file that cannot be read or parsed is logged
    at warning level and skipped, so one bad fixture does not hide the rest.

    Args:
        fixtures_dir: Override path to fixtures directory.
            Defaults to FIXTURES_DIR (tests/replay_fixtures/).

    Returns:
        List of loaded ReplayFixture objects, in sorted file-path order of
        the ``*.json`` files directly inside the directory. An empty list
        is returned (with a warning logged) when the directory is missing.
    """
    directory = fixtures_dir or FIXTURES_DIR
    if not directory.is_dir():
        logger.warning("Fixtures directory not found: %s", directory)
        return []

    fixtures: list[ReplayFixture] = []
    for path in sorted(directory.glob("*.json")):
        try:
            fixtures.append(load_fixture(path))
        # ValueError already covers json.JSONDecodeError and
        # UnicodeDecodeError (both are subclasses). OSError covers entries
        # that match the glob but cannot be opened, e.g. a directory named
        # "x.json" or a file without read permission.
        except (OSError, ValueError) as exc:
            logger.warning("Skipping invalid fixture %s: %s", path.name, exc)

    logger.info("Loaded %d replay fixtures from %s", len(fixtures), directory)
    return fixtures
|
|
|
|
|
|
def validate_fixture(fixture: ReplayFixture) -> ReplayValidationResult:
    """Check that a fixture's expected extraction still validates.

    The stored expected output is run through ``validate_extraction``
    together with the fixture's document text. A failure means either the
    fixture has gone stale or the schema has regressed.

    Args:
        fixture: The replay fixture to validate.

    Returns:
        A ReplayValidationResult. ``schema_valid`` mirrors the report's
        ``valid`` flag; if validation raised, ``schema_valid`` is False and
        ``error`` holds the exception text.
    """
    outcome = ReplayValidationResult(
        fixture_id=fixture.document_id,
        schema_version=get_schema_version(),
    )

    try:
        report = validate_extraction(
            fixture.expected_extraction,
            document_text=fixture.document_text,
        )
        outcome.validation_report = report
        outcome.schema_valid = report.valid
    except Exception as exc:  # noqa: BLE001
        # Any failure is recorded on the result rather than propagated, so
        # callers always get one result per fixture.
        outcome.error = str(exc)
        outcome.schema_valid = False

    return outcome
|
|
|
|
|
|
def validate_all_fixtures(
    fixtures_dir: Path | None = None,
) -> list[ReplayValidationResult]:
    """Load every fixture and validate each against the current schema.

    Args:
        fixtures_dir: Override path to fixtures directory; passed through
            unchanged to ``load_all_fixtures``.

    Returns:
        One ReplayValidationResult per loaded fixture, in load order.
    """
    loaded = load_all_fixtures(fixtures_dir)
    return list(map(validate_fixture, loaded))
|
|
|
|
|
|
def compare_extraction(
    fixture: ReplayFixture,
    actual_result: ExtractionResult,
) -> ReplayComparisonResult:
    """Compare a live extraction result against the fixture's expected output.

    Checks structural alignment (same companies detected, same sentiments,
    same catalyst types) rather than exact string equality, since LLM
    outputs vary in wording across runs.

    Notes:
        * The per-ticker maps are plain dicts, so if a ticker appears more
          than once in either result only its last entry is compared.
        * Only a company mismatch appends a warning; sentiment and catalyst
          mismatches are reported through their ``*_match`` flags alone.
        * Exceptions raised while parsing ``fixture.expected_extraction``
          or by ``validate_extraction`` are not caught here.

    Args:
        fixture: The replay fixture with expected output.
        actual_result: The ExtractionResult from a live extraction.

    Returns:
        A ReplayComparisonResult with match details.
    """
    expected = fixture.expected_result
    comparison = ReplayComparisonResult(fixture_id=fixture.document_id)

    # Company ticker sets
    comparison.expected_companies = sorted(c.ticker for c in expected.companies)
    comparison.actual_companies = sorted(c.ticker for c in actual_result.companies)
    comparison.companies_match = (
        set(comparison.expected_companies) == set(comparison.actual_companies)
    )

    # Sentiment by ticker
    comparison.expected_sentiment_map = {
        c.ticker: c.sentiment for c in expected.companies
    }
    comparison.actual_sentiment_map = {
        c.ticker: c.sentiment for c in actual_result.companies
    }
    comparison.sentiment_match = (
        comparison.expected_sentiment_map == comparison.actual_sentiment_map
    )

    # Catalyst type by ticker
    comparison.expected_catalyst_map = {
        c.ticker: c.catalyst_type for c in expected.companies
    }
    comparison.actual_catalyst_map = {
        c.ticker: c.catalyst_type for c in actual_result.companies
    }
    comparison.catalyst_match = (
        comparison.expected_catalyst_map == comparison.actual_catalyst_map
    )

    # Schema validity of actual result
    actual_report = validate_extraction(
        actual_result.model_dump(mode="json"),
        document_text=fixture.document_text,
    )
    comparison.actual_schema_valid = actual_report.valid
    if actual_report.warnings:
        # Copy instead of aliasing: the append below must not mutate the
        # validation report's own warnings list.
        comparison.warnings = list(actual_report.warnings)

    if not comparison.companies_match:
        comparison.warnings.append(
            f"company_mismatch: expected={comparison.expected_companies} actual={comparison.actual_companies}"
        )

    return comparison
|