c501ccea40
- Migration 026 and OllamaConfig now default to qwen3.5:9b instead of llama3.1:8b. Existing deployments keep their current model (qwen3.5:9b-fast) because the migration uses WHERE NOT EXISTS on slug. - The event classifier system prompt is expanded with macro-vs-company filtering: it explicitly instructs the model NOT to classify single-company news (lawsuits, earnings, management changes, debt crises) as macro events, and to set severity=low and confidence<0.3 for company-specific articles. 'critical' severity is reserved for multi-country/global market events. Over-tagging of event_types is prevented by requiring that each tagged type be directly described in the article. - Updated the test_system_prompt_is_concise threshold to accommodate the expanded prompt (300 → 1000 chars).
417 lines
15 KiB
Python
417 lines
15 KiB
Python
"""Tests for the event classifier module.

Covers the GlobalEvent dataclass, JSON schema generation, prompt building,
response parsing/normalization, and the classify_global_event function.

Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from dataclasses import fields
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from services.extractor.event_classifier import (
|
|
GlobalEvent,
|
|
PROMPT_VERSION,
|
|
SCHEMA_VERSION,
|
|
_normalize_duration,
|
|
_normalize_event_types,
|
|
_normalize_severity,
|
|
_parse_classification_response,
|
|
build_event_classification_prompt,
|
|
classify_global_event,
|
|
get_event_json_schema,
|
|
persist_global_event,
|
|
)
|
|
from services.shared.schemas import ModelMetadata
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# GlobalEvent dataclass tests
# ---------------------------------------------------------------------------


class TestGlobalEvent:
    """Construction, defaults and field layout of the ``GlobalEvent`` dataclass."""

    def test_default_construction(self):
        """A no-argument GlobalEvent gets a generated UUID id and neutral defaults."""
        event = GlobalEvent()
        # The id must be a parseable UUID, not merely a non-empty value.
        # str() lets this accept either a UUID string or a uuid.UUID instance.
        uuid.UUID(str(event.event_id))
        assert event.event_types == []
        assert event.severity == "low"
        assert event.affected_regions == []
        assert event.affected_sectors == []
        assert event.affected_commodities == []
        assert event.summary == ""
        assert event.key_facts == []
        assert event.estimated_duration == "short_term"
        assert event.confidence == 0.5
        assert event.source_document_id == ""
        assert isinstance(event.model_metadata, ModelMetadata)

    def test_all_fields_present(self):
        """Verify all design-specified fields exist on GlobalEvent."""
        field_names = {f.name for f in fields(GlobalEvent)}
        expected = {
            "event_id", "event_types", "severity", "affected_regions",
            "affected_sectors", "affected_commodities", "summary",
            "key_facts", "estimated_duration", "confidence",
            "source_document_id", "model_metadata",
        }
        assert expected == field_names

    def test_custom_construction(self):
        """Every explicitly supplied field is stored exactly as given."""
        values = {
            "event_id": "test-id",
            "event_types": ["trade_barrier", "cost_increase"],
            "severity": "high",
            "affected_regions": ["US", "CN"],
            "affected_sectors": ["Industrials"],
            "affected_commodities": ["steel"],
            "summary": "Trade war escalation",
            "key_facts": ["25% tariff announced"],
            "estimated_duration": "medium_term",
            "confidence": 0.85,
            "source_document_id": "doc-123",
        }
        event = GlobalEvent(**values)
        # Check every supplied field, not just a sample, so a regression in
        # any one of them is caught; the field name is the failure message.
        for name, expected in values.items():
            assert getattr(event, name) == expected, name
        # model_metadata was not supplied, so it falls back to its default.
        assert isinstance(event.model_metadata, ModelMetadata)

    def test_unique_event_ids(self):
        """Two default-constructed events never share an id."""
        e1 = GlobalEvent()
        e2 = GlobalEvent()
        assert e1.event_id != e2.event_id


# ---------------------------------------------------------------------------
# JSON schema tests
# ---------------------------------------------------------------------------


class TestEventJsonSchema:
    """Shape of the JSON schema used to constrain the classifier's output."""

    def test_schema_is_valid_json_schema(self):
        schema = get_event_json_schema()
        assert schema["type"] == "object"
        assert "properties" in schema
        assert "required" in schema

    def test_schema_has_all_required_fields(self):
        schema = get_event_json_schema()
        required = set(schema["required"])
        expected = {
            "event_types", "severity", "affected_regions",
            "affected_sectors", "affected_commodities", "summary",
            "key_facts", "estimated_duration", "confidence",
        }
        assert expected == required

    def test_required_fields_declared_in_properties(self):
        """Every required key must also be declared under ``properties``.

        The schema sets ``additionalProperties: false``, so a required key
        missing from ``properties`` would make the schema unsatisfiable.
        """
        schema = get_event_json_schema()
        undeclared = set(schema["required"]) - set(schema["properties"])
        assert not undeclared, f"required but not declared: {sorted(undeclared)}"

    def test_schema_event_types_has_enum(self):
        schema = get_event_json_schema()
        items = schema["properties"]["event_types"]["items"]
        assert "enum" in items
        assert "supply_disruption" in items["enum"]
        assert "geopolitical_risk" in items["enum"]

    def test_schema_severity_has_enum(self):
        schema = get_event_json_schema()
        severity = schema["properties"]["severity"]
        assert set(severity["enum"]) == {"low", "moderate", "high", "critical"}

    def test_schema_duration_has_enum(self):
        schema = get_event_json_schema()
        duration = schema["properties"]["estimated_duration"]
        assert set(duration["enum"]) == {"short_term", "medium_term", "long_term"}

    def test_schema_enums_survive_normalization(self):
        """Values the schema allows must not be rewritten by the normalizers.

        This guards against the schema and the post-parse normalizers drifting
        apart. Event types are checked one at a time so list-level behaviour
        (such as the empty-list fallback) does not interfere.
        """
        props = get_event_json_schema()["properties"]
        for event_type in props["event_types"]["items"]["enum"]:
            assert _normalize_event_types([event_type]) == [event_type]
        for severity in props["severity"]["enum"]:
            assert _normalize_severity(severity) == severity
        for duration in props["estimated_duration"]["enum"]:
            assert _normalize_duration(duration) == duration

    def test_schema_confidence_bounds(self):
        schema = get_event_json_schema()
        conf = schema["properties"]["confidence"]
        assert conf["minimum"] == 0.0
        assert conf["maximum"] == 1.0

    def test_schema_no_additional_properties(self):
        schema = get_event_json_schema()
        assert schema.get("additionalProperties") is False


# ---------------------------------------------------------------------------
# Prompt builder tests
# ---------------------------------------------------------------------------


class TestBuildEventClassificationPrompt:
    """Checks on the system/user prompt pair sent to the classifier LLM."""

    # Upper bound on the system prompt length. Raised from 300 to make room
    # for the macro-vs-company filtering rules.
    _SYSTEM_PROMPT_MAX_CHARS = 1000

    def test_returns_system_and_user(self):
        prompt = build_event_classification_prompt("Some article text")
        for role in ("system", "user"):
            assert role in prompt

    def test_user_prompt_contains_article_text(self):
        article = "Tariffs announced on steel imports"
        user_prompt = build_event_classification_prompt(article)["user"]
        assert article in user_prompt

    def test_user_prompt_contains_anti_hallucination_rules(self):
        user_prompt = build_event_classification_prompt("text")["user"]
        for marker in ("Do NOT infer", "fabricate"):
            assert marker in user_prompt

    def test_system_prompt_is_concise(self):
        system_prompt = build_event_classification_prompt("text")["system"]
        assert "JSON" in system_prompt
        assert len(system_prompt) < self._SYSTEM_PROMPT_MAX_CHARS

    def test_user_prompt_lists_impact_types(self):
        user_prompt = build_event_classification_prompt("text")["user"]
        for impact_type in ("supply_disruption", "geopolitical_risk"):
            assert impact_type in user_prompt


# ---------------------------------------------------------------------------
# Normalization tests
# ---------------------------------------------------------------------------


class TestNormalization:
    """Whitelisting and fallback behaviour of the private normalizer helpers."""

    def test_normalize_event_types_valid(self):
        recognised = ["trade_barrier", "cost_increase"]
        normalized = _normalize_event_types(recognised)
        assert normalized == ["trade_barrier", "cost_increase"]

    def test_normalize_event_types_filters_invalid(self):
        mixed = ["trade_barrier", "invalid_type", "cost_increase"]
        normalized = _normalize_event_types(mixed)
        assert normalized == ["trade_barrier", "cost_increase"]

    def test_normalize_event_types_empty_fallback(self):
        # Both an empty list and a list of only unknown types fall back.
        for unusable in ([], ["bogus"]):
            assert _normalize_event_types(unusable) == ["geopolitical_risk"]

    def test_normalize_severity_valid(self):
        expectations = {"high": "high", "CRITICAL": "critical"}
        for raw, wanted in expectations.items():
            assert _normalize_severity(raw) == wanted

    def test_normalize_severity_invalid_fallback(self):
        normalized = _normalize_severity("extreme")
        assert normalized == "low"

    def test_normalize_duration_valid(self):
        normalized = _normalize_duration("medium_term")
        assert normalized == "medium_term"

    def test_normalize_duration_invalid_fallback(self):
        normalized = _normalize_duration("forever")
        assert normalized == "short_term"


# ---------------------------------------------------------------------------
# Parse classification response tests
# ---------------------------------------------------------------------------


class TestParseClassificationResponse:
    """Parsing of raw LLM JSON output into a normalized ``GlobalEvent``."""

    def _make_raw_json(self, **overrides) -> str:
        """Return a valid classifier JSON payload with *overrides* applied."""
        data = {
            "event_types": ["trade_barrier"],
            "severity": "high",
            "affected_regions": ["US", "CN"],
            "affected_sectors": ["Industrials"],
            "affected_commodities": ["steel"],
            "summary": "New tariffs on steel imports",
            "key_facts": ["25% tariff effective immediately"],
            "estimated_duration": "medium_term",
            "confidence": 0.8,
        }
        data.update(overrides)
        return json.dumps(data)

    def _parse(self, **overrides):
        """Parse a ``_make_raw_json`` payload for a fixed document and model."""
        return _parse_classification_response(
            self._make_raw_json(**overrides), "doc-1", "model",
        )

    def test_basic_parse(self):
        event = _parse_classification_response(
            self._make_raw_json(), "doc-1", "llama3.1:8b",
        )
        assert event.event_types == ["trade_barrier"]
        assert event.severity == "high"
        assert event.affected_regions == ["US", "CN"]
        assert event.summary == "New tariffs on steel imports"
        assert event.key_facts == ["25% tariff effective immediately"]
        # Already-valid values must pass through normalization and clamping
        # unchanged.
        assert event.estimated_duration == "medium_term"
        assert event.confidence == 0.8
        assert event.source_document_id == "doc-1"
        assert event.model_metadata.model_name == "llama3.1:8b"
        assert event.model_metadata.prompt_version == PROMPT_VERSION

    def test_multiple_event_types_preserved(self):
        """Requirement 2.4: multiple impact types not collapsed."""
        types = ["trade_barrier", "cost_increase", "supply_disruption"]
        event = self._parse(event_types=types)
        # Order-preserving equality: all types kept, none merged or dropped.
        assert event.event_types == types

    def test_confidence_clamped(self):
        """Out-of-range confidence is clamped into [0.0, 1.0]."""
        for raw_confidence, expected in ((1.5, 1.0), (-0.3, 0.0)):
            event = self._parse(confidence=raw_confidence)
            assert event.confidence == expected

    def test_invalid_severity_normalized(self):
        event = self._parse(severity="extreme")
        assert event.severity == "low"

    def test_invalid_duration_normalized(self):
        event = self._parse(estimated_duration="permanent")
        assert event.estimated_duration == "short_term"

    def test_event_id_is_uuid(self):
        event = self._parse()
        uuid.UUID(event.event_id)  # Should not raise


# ---------------------------------------------------------------------------
# classify_global_event tests
# ---------------------------------------------------------------------------


class TestClassifyGlobalEvent:
    """Tests of ``classify_global_event`` driven by a mocked Ollama client."""

    def _make_mock_client(self, raw_output: str, error: str | None = None):
        """Create a mock OllamaClient with configurable response.

        ``_call_ollama`` is an ``AsyncMock`` that always resolves to one
        attempt carrying *raw_output* and *error*. Tests that need a sequence
        of attempts replace it with a ``side_effect`` list. The retry/backoff
        attributes are set explicitly; presumably the classifier reads them,
        which is why the delays are tiny.
        """
        client = MagicMock()
        client._config = MagicMock()
        client._config.model = "llama3.1:8b"
        client._max_retries = 2
        client._base_delay = 0.01
        client._max_delay = 0.1
        client._backoff_multiplier = 2.0

        attempt = MagicMock()
        attempt.raw_output = raw_output
        attempt.error = error
        client._call_ollama = AsyncMock(return_value=attempt)
        return client

    @pytest.mark.asyncio
    async def test_successful_classification(self):
        raw = json.dumps({
            "event_types": ["commodity_shock"],
            "severity": "critical",
            "affected_regions": ["Global"],
            "affected_sectors": ["Energy"],
            "affected_commodities": ["crude_oil"],
            "summary": "OPEC cuts production",
            "key_facts": ["2M barrel/day cut"],
            "estimated_duration": "medium_term",
            "confidence": 0.9,
        })
        client = self._make_mock_client(raw)

        event = await classify_global_event(
            "OPEC announced production cuts...",
            "doc-123",
            client,
        )

        assert event.event_types == ["commodity_shock"]
        assert event.severity == "critical"
        assert event.confidence == 0.9
        assert event.source_document_id == "doc-123"
        # Awaited, not merely called: an un-awaited coroutine would never run.
        client._call_ollama.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_retries_on_error(self):
        """Requirement 2.3: retries on invalid response."""
        good_raw = json.dumps({
            "event_types": ["geopolitical_risk"],
            "severity": "high",
            "affected_regions": ["UA", "RU"],
            "affected_sectors": ["Energy"],
            "affected_commodities": ["natural_gas"],
            "summary": "Conflict escalation",
            "key_facts": ["Military action reported"],
            "estimated_duration": "long_term",
            "confidence": 0.7,
        })

        fail_attempt = MagicMock()
        fail_attempt.raw_output = ""
        fail_attempt.error = "timeout"

        success_attempt = MagicMock()
        success_attempt.raw_output = good_raw
        success_attempt.error = None

        client = self._make_mock_client("")
        client._call_ollama = AsyncMock(side_effect=[fail_attempt, success_attempt])

        event = await classify_global_event("text", "doc-456", client)
        assert event.severity == "high"
        assert client._call_ollama.await_count == 2

    @pytest.mark.asyncio
    async def test_raises_after_exhausted_retries(self):
        fail_attempt = MagicMock()
        fail_attempt.raw_output = ""
        fail_attempt.error = "timeout"

        client = self._make_mock_client("")
        client._call_ollama = AsyncMock(return_value=fail_attempt)

        with pytest.raises(ValueError, match="Event classification failed"):
            await classify_global_event("text", "doc-789", client)

        # One initial attempt plus one per configured retry. Derived from the
        # mock's setting rather than hard-coded, so the two cannot drift apart.
        assert client._call_ollama.await_count == client._max_retries + 1

    @pytest.mark.asyncio
    async def test_minio_persistence_called(self):
        raw = json.dumps({
            "event_types": ["regulatory_pressure"],
            "severity": "moderate",
            "affected_regions": ["EU"],
            "affected_sectors": ["Information Technology"],
            "affected_commodities": [],
            "summary": "New AI regulation",
            "key_facts": ["EU AI Act enforcement begins"],
            "estimated_duration": "long_term",
            "confidence": 0.75,
        })
        client = self._make_mock_client(raw)
        minio = MagicMock()
        minio.put_object = MagicMock()

        event = await classify_global_event(
            "text", "doc-abc", client, minio_client=minio,
        )

        assert event.severity == "moderate"
        # put_object called for prompt + result
        assert minio.put_object.call_count == 2

    @pytest.mark.asyncio
    async def test_pg_persistence_called(self):
        raw = json.dumps({
            "event_types": ["currency_impact"],
            "severity": "low",
            "affected_regions": ["JP"],
            "affected_sectors": ["Financials"],
            "affected_commodities": [],
            "summary": "Yen weakens",
            "key_facts": ["USD/JPY hits 160"],
            "estimated_duration": "short_term",
            "confidence": 0.6,
        })
        client = self._make_mock_client(raw)
        pool = MagicMock()
        pool.fetchval = AsyncMock(return_value=uuid.uuid4())

        event = await classify_global_event(
            "text", "doc-def", client, pool=pool,
        )

        assert event.event_types == ["currency_impact"]
        pool.fetchval.assert_awaited_once()
        # The first positional argument is the SQL; it must target global_events.
        sql = pool.fetchval.call_args.args[0]
        assert "global_events" in sql