201 lines
6.5 KiB
Python
201 lines
6.5 KiB
Python
"""Tests for the extraction worker persistence logic.

Validates that persist_extraction correctly uploads artifacts to MinIO
and persists intelligence/impact records to PostgreSQL.

Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 9.1, 9.2
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from services.extractor.client import ExtractionAttempt, ExtractionResponse
|
|
from services.extractor.schemas import ExtractionResult, ValidationReport
|
|
from services.extractor.worker import persist_extraction
|
|
|
|
|
|
def _make_valid_result() -> ExtractionResult:
    """Build a minimal valid ExtractionResult.

    The payload describes a single company (ticker ``AAPL``) with a
    positive earnings catalyst, one macro theme and no extraction
    warnings. It is validated through ``ExtractionResult.model_validate``
    so the fixture fails loudly if the schema and this payload drift apart.

    Returns:
        The validated ``ExtractionResult`` instance.
    """
    company = {
        "ticker": "AAPL",
        "company_name": "Apple Inc.",
        "relevance": 0.95,
        "sentiment": "positive",
        "impact_score": 0.7,
        "impact_horizon": "1d_30d",
        "catalyst_type": "earnings",
        "key_facts": ["Revenue up 12%"],
        "risks": [],
        "evidence_spans": ["Apple beat expectations"],
    }
    payload = {
        "summary": "Apple beat earnings expectations.",
        "companies": [company],
        "macro_themes": ["ai_capex"],
        "novelty_score": 0.6,
        "confidence": 0.85,
        "extraction_warnings": [],
    }
    return ExtractionResult.model_validate(payload)
|
|
|
|
|
|
def _make_success_response() -> ExtractionResponse:
    """Build a successful ExtractionResponse with one attempt.

    The single attempt carries the JSON dump of the valid result as its
    raw output together with a passing ``ValidationReport`` whose
    ``parsed`` field is that same result.

    Returns:
        An ``ExtractionResponse`` with ``success=True`` and one attempt.
    """
    result = _make_valid_result()
    validation = ValidationReport(
        valid=True,
        errors=[],
        warnings=[],
        parsed=result,
    )
    attempt = ExtractionAttempt(
        raw_output=result.model_dump_json(),
        validation=validation,
        error=None,
        duration_ms=500,
        model="test-model",
    )
    return ExtractionResponse(
        success=True,
        result=result,
        attempts=[attempt],
        prompt_metadata={
            "prompt_version": "document-intel-v1",
            "schema_version": "2.0.0",
        },
        model="test-model",
        total_duration_ms=500,
    )
|
|
|
|
|
|
def _make_failed_response() -> ExtractionResponse:
    """Build a failed ExtractionResponse with two attempts.

    The first attempt has no validation report and the error
    ``"invalid_json"``; the second has a failing ``ValidationReport`` and
    the error ``"schema_fail"``. The response itself has no result.

    Returns:
        An ``ExtractionResponse`` with ``success=False`` and two attempts.
    """
    # Attempt 1: output that never reached validation.
    unparseable_attempt = ExtractionAttempt(
        raw_output="bad json",
        validation=None,
        error="invalid_json",
        duration_ms=200,
        model="test-model",
    )
    # Attempt 2: output that was validated and rejected.
    rejected_attempt = ExtractionAttempt(
        raw_output="still bad",
        validation=ValidationReport(
            valid=False,
            errors=["schema_fail"],
            warnings=[],
        ),
        error="schema_fail",
        duration_ms=300,
        model="test-model",
    )
    return ExtractionResponse(
        success=False,
        result=None,
        attempts=[unparseable_attempt, rejected_attempt],
        prompt_metadata={
            "prompt_version": "document-intel-v1",
            "schema_version": "2.0.0",
        },
        model="test-model",
        total_duration_ms=500,
    )
|
|
|
|
|
|
def _mock_pool(intel_id: str = "intel-uuid-1", impact_id: str = "impact-uuid-1") -> AsyncMock:
    """Create a mock asyncpg pool that returns predictable UUIDs.

    Args:
        intel_id: Value returned by the first ``fetchval`` await.
        impact_id: Value returned by the second ``fetchval`` await.

    Returns:
        An ``AsyncMock`` whose ``fetchval`` yields ``intel_id``,
        ``impact_id`` and then ``"metrics-uuid-1"`` on successive awaits,
        and whose ``execute`` is a plain ``AsyncMock``.
    """
    pool = AsyncMock()
    # Order matters: intelligence insert, impact insert, metrics insert.
    # The list holds exactly three values, so it supports at most three
    # fetchval awaits.
    fetchval_results = [intel_id, impact_id, "metrics-uuid-1"]
    pool.fetchval = AsyncMock(side_effect=fetchval_results)
    pool.execute = AsyncMock()
    return pool
|
|
|
|
|
|
def _mock_minio() -> MagicMock:
    """Create a mock MinIO client.

    Returns:
        A fresh, unconfigured ``MagicMock``; tests inspect the calls
        recorded on it (for example ``put_object.call_count``).
    """
    return MagicMock()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_persist_successful_extraction() -> None:
    """Successful extraction persists all artifacts and intelligence records."""
    pool = _mock_pool()
    minio = _mock_minio()
    response = _make_success_response()
    ts = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc)

    result = await persist_extraction(
        pool=pool,
        minio_client=minio,
        document_id="doc-123",
        ticker="AAPL",
        extraction_response=response,
        company_id_map={"AAPL": "company-uuid-1"},
        source_credibility=0.8,
        timestamp=ts,
    )

    # Identifiers come straight from the mocked pool's fetchval sequence.
    assert result.success
    assert result.intelligence_id == "intel-uuid-1"
    assert result.impact_ids == ["impact-uuid-1"]

    # Every artifact reference must be populated on success.
    assert result.prompt_ref is not None
    assert "stonks-llm-prompts" in result.prompt_ref
    assert result.raw_output_ref is not None
    assert "stonks-llm-results" in result.raw_output_ref
    assert result.validation_ref is not None
    assert result.intelligence_ref is not None

    # MinIO should have 4 uploads: prompt, raw output, validation, intelligence
    assert minio.put_object.call_count == 4

    # PostgreSQL: 1 intelligence insert + 1 impact insert + 1 metrics insert + 1 status update
    assert pool.fetchval.call_count == 3
    assert pool.execute.call_count == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_persist_failed_extraction() -> None:
    """Failed extraction still persists attempt data and marks document as failed."""
    # A constant return_value (not a side_effect list) lets fetchval be
    # awaited any number of times with the same id.
    pool = AsyncMock()
    pool.fetchval = AsyncMock(return_value="intel-uuid-fail")
    pool.execute = AsyncMock()
    minio = _mock_minio()
    response = _make_failed_response()
    ts = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc)

    result = await persist_extraction(
        pool=pool,
        minio_client=minio,
        document_id="doc-456",
        ticker="AAPL",
        extraction_response=response,
        timestamp=ts,
    )

    assert not result.success
    assert result.intelligence_id == "intel-uuid-fail"
    assert result.intelligence_ref is None  # no final intelligence on failure
    assert result.prompt_ref is not None
    assert result.raw_output_ref is not None
    assert result.validation_ref is not None

    # MinIO: 3 uploads (prompt, raw output, validation — no intelligence)
    assert minio.put_object.call_count == 3

    # PostgreSQL: 1 intelligence insert + 1 metrics insert + 1 status update
    assert pool.fetchval.call_count == 2
    assert pool.execute.call_count == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_persist_skips_impact_without_company_id() -> None:
    """Impact records are skipped when company_id_map doesn't have the ticker."""
    pool = AsyncMock()
    pool.fetchval = AsyncMock(return_value="intel-uuid-2")
    pool.execute = AsyncMock()
    minio = _mock_minio()
    response = _make_success_response()

    result = await persist_extraction(
        pool=pool,
        minio_client=minio,
        document_id="doc-789",
        ticker="AAPL",
        extraction_response=response,
        company_id_map={},  # no mapping for AAPL
    )

    assert result.success
    assert result.impact_ids == []
    # 1 fetchval for intelligence + 1 for metrics, no impact insert
    assert pool.fetchval.call_count == 2
|