Files
stonks-oracle/tests/test_event_classifier.py
T
Celes Renata c501ccea40 fix: default model to qwen3.5:9b + improve event classifier prompt
- Migration 026 and OllamaConfig now default to qwen3.5:9b instead of
  llama3.1:8b. Existing deployments keep their current model (qwen3.5:9b-fast)
  since the migration uses WHERE NOT EXISTS on slug.

- Event classifier system prompt expanded with macro-vs-company filtering:
  explicitly instructs the model to NOT classify single-company news
  (lawsuits, earnings, management changes, debt crises) as macro events.
  Tells the model to use severity=low and confidence<0.3 for company-specific articles.
  Reserves 'critical' severity for multi-country/global market events.
  Prevents over-tagging of event_types by requiring each type to be directly described in the article.

- Updated test_system_prompt_is_concise threshold to accommodate the
  expanded prompt (300 → 1000 chars).
2026-04-17 02:53:38 +00:00

417 lines
15 KiB
Python

"""Tests for the event classifier module.
Covers GlobalEvent dataclass, JSON schema generation, prompt building,
response parsing/normalization, and the classify_global_event function.
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""
from __future__ import annotations
import json
import uuid
from dataclasses import fields
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from services.extractor.event_classifier import (
GlobalEvent,
PROMPT_VERSION,
SCHEMA_VERSION,
_normalize_duration,
_normalize_event_types,
_normalize_severity,
_parse_classification_response,
build_event_classification_prompt,
classify_global_event,
get_event_json_schema,
persist_global_event,
)
from services.shared.schemas import ModelMetadata
# ---------------------------------------------------------------------------
# GlobalEvent dataclass tests
# ---------------------------------------------------------------------------
class TestGlobalEvent:
    """Defaults, field set, and identity behaviour of the GlobalEvent dataclass."""

    def test_default_construction(self):
        """A GlobalEvent built with no arguments carries the expected defaults."""
        event = GlobalEvent()
        assert event.event_id  # a non-empty id is generated automatically
        assert event.event_types == []
        assert event.severity == "low"
        assert event.affected_regions == []
        assert event.affected_sectors == []
        assert event.affected_commodities == []
        assert event.summary == ""
        assert event.key_facts == []
        assert event.estimated_duration == "short_term"
        assert event.confidence == 0.5
        assert event.source_document_id == ""
        assert isinstance(event.model_metadata, ModelMetadata)

    def test_all_fields_present(self):
        """Verify GlobalEvent has exactly the design-specified fields (no more, no fewer)."""
        field_names = {f.name for f in fields(GlobalEvent)}
        expected = {
            "event_id", "event_types", "severity", "affected_regions",
            "affected_sectors", "affected_commodities", "summary",
            "key_facts", "estimated_duration", "confidence",
            "source_document_id", "model_metadata",
        }
        assert field_names == expected

    def test_custom_construction(self):
        """Every value passed to the constructor is stored unchanged on the instance."""
        kwargs = dict(
            event_id="test-id",
            event_types=["trade_barrier", "cost_increase"],
            severity="high",
            affected_regions=["US", "CN"],
            affected_sectors=["Industrials"],
            affected_commodities=["steel"],
            summary="Trade war escalation",
            key_facts=["25% tariff announced"],
            estimated_duration="medium_term",
            confidence=0.85,
            source_document_id="doc-123",
        )
        event = GlobalEvent(**kwargs)
        # Check all supplied fields, not just a sample, so a mis-wired field is caught.
        for name, value in kwargs.items():
            assert getattr(event, name) == value, name

    def test_unique_event_ids(self):
        """Two default-constructed events never share an id."""
        e1 = GlobalEvent()
        e2 = GlobalEvent()
        assert e1.event_id != e2.event_id

    def test_default_lists_not_shared(self):
        """List defaults are per-instance: mutating one event must not leak into another."""
        e1 = GlobalEvent()
        e2 = GlobalEvent()
        e1.event_types.append("trade_barrier")
        e1.key_facts.append("fact")
        assert e2.event_types == []
        assert e2.key_facts == []
# ---------------------------------------------------------------------------
# JSON schema tests
# ---------------------------------------------------------------------------
class TestEventJsonSchema:
    """Structure of the JSON schema returned by get_event_json_schema()."""

    def test_schema_is_valid_json_schema(self):
        """Top level is an object schema exposing properties and a required list."""
        schema = get_event_json_schema()
        assert schema["type"] == "object"
        for key in ("properties", "required"):
            assert key in schema

    def test_schema_has_all_required_fields(self):
        """Every model-produced field (i.e. not id/source/metadata) is required."""
        required_fields = set(get_event_json_schema()["required"])
        assert required_fields == {
            "event_types", "severity", "affected_regions",
            "affected_sectors", "affected_commodities", "summary",
            "key_facts", "estimated_duration", "confidence",
        }

    def test_schema_event_types_has_enum(self):
        """event_types items are restricted to an enum containing known impact types."""
        item_schema = get_event_json_schema()["properties"]["event_types"]["items"]
        assert "enum" in item_schema
        allowed_types = item_schema["enum"]
        assert "supply_disruption" in allowed_types
        assert "geopolitical_risk" in allowed_types

    def test_schema_severity_has_enum(self):
        """severity is limited to the four-level scale."""
        props = get_event_json_schema()["properties"]
        assert set(props["severity"]["enum"]) == {"low", "moderate", "high", "critical"}

    def test_schema_duration_has_enum(self):
        """estimated_duration is limited to the three horizon buckets."""
        props = get_event_json_schema()["properties"]
        assert set(props["estimated_duration"]["enum"]) == {
            "short_term", "medium_term", "long_term",
        }

    def test_schema_confidence_bounds(self):
        """confidence is bounded to the closed interval [0, 1]."""
        confidence_schema = get_event_json_schema()["properties"]["confidence"]
        assert confidence_schema["minimum"] == 0.0
        assert confidence_schema["maximum"] == 1.0

    def test_schema_no_additional_properties(self):
        """Extra keys are disallowed (additionalProperties is literally False)."""
        assert get_event_json_schema().get("additionalProperties") is False
# ---------------------------------------------------------------------------
# Prompt builder tests
# ---------------------------------------------------------------------------
class TestBuildEventClassificationPrompt:
    """The prompt builder returns system/user messages with the expected content."""

    def test_returns_system_and_user(self):
        """Both chat roles are present in the returned mapping."""
        prompt = build_event_classification_prompt("Some article text")
        for role in ("system", "user"):
            assert role in prompt

    def test_user_prompt_contains_article_text(self):
        """The article body is embedded verbatim in the user message."""
        article = "Tariffs announced on steel imports"
        user_message = build_event_classification_prompt(article)["user"]
        assert article in user_message

    def test_user_prompt_contains_anti_hallucination_rules(self):
        """The user message warns the model against inferring or inventing facts."""
        user_message = build_event_classification_prompt("text")["user"]
        assert "Do NOT infer" in user_message
        assert "fabricate" in user_message

    def test_system_prompt_is_concise(self):
        """The system message mentions JSON and stays under the size budget."""
        system_message = build_event_classification_prompt("text")["system"]
        assert "JSON" in system_message
        # Budget raised from 300 to 1000 chars when macro-vs-company filtering
        # rules were added to the system prompt.
        assert len(system_message) < 1000

    def test_user_prompt_lists_impact_types(self):
        """Allowed impact types are spelled out for the model."""
        user_message = build_event_classification_prompt("text")["user"]
        assert "supply_disruption" in user_message
        assert "geopolitical_risk" in user_message
# ---------------------------------------------------------------------------
# Normalization tests
# ---------------------------------------------------------------------------
class TestNormalization:
    """Coercion of raw model values into the allowed vocabularies, with fallbacks."""

    def test_normalize_event_types_valid(self):
        """Known types pass through in their original order."""
        known = ["trade_barrier", "cost_increase"]
        assert _normalize_event_types(known) == ["trade_barrier", "cost_increase"]

    def test_normalize_event_types_filters_invalid(self):
        """Unknown types are dropped; surviving order is preserved."""
        mixed = ["trade_barrier", "invalid_type", "cost_increase"]
        assert _normalize_event_types(mixed) == ["trade_barrier", "cost_increase"]

    def test_normalize_event_types_empty_fallback(self):
        """When nothing valid remains, a single default type is substituted."""
        for nothing_valid in ([], ["bogus"]):
            assert _normalize_event_types(nothing_valid) == ["geopolitical_risk"]

    def test_normalize_severity_valid(self):
        """Valid severities are accepted case-insensitively and lower-cased."""
        for raw, expected in (("high", "high"), ("CRITICAL", "critical")):
            assert _normalize_severity(raw) == expected

    def test_normalize_severity_invalid_fallback(self):
        """An unknown severity falls back to 'low'."""
        assert _normalize_severity("extreme") == "low"

    def test_normalize_duration_valid(self):
        """A valid duration is returned unchanged."""
        assert _normalize_duration("medium_term") == "medium_term"

    def test_normalize_duration_invalid_fallback(self):
        """An unknown duration falls back to 'short_term'."""
        assert _normalize_duration("forever") == "short_term"
# ---------------------------------------------------------------------------
# Parse classification response tests
# ---------------------------------------------------------------------------
class TestParseClassificationResponse:
    """Parsing and normalization of the model's raw JSON output into a GlobalEvent."""

    def _make_raw_json(self, **overrides) -> str:
        """Serialize a valid classification payload, replacing any keys given in *overrides*."""
        payload = {
            "event_types": ["trade_barrier"],
            "severity": "high",
            "affected_regions": ["US", "CN"],
            "affected_sectors": ["Industrials"],
            "affected_commodities": ["steel"],
            "summary": "New tariffs on steel imports",
            "key_facts": ["25% tariff effective immediately"],
            "estimated_duration": "medium_term",
            "confidence": 0.8,
            **overrides,
        }
        return json.dumps(payload)

    def test_basic_parse(self):
        """Fields, source id, and model metadata are populated from a valid payload."""
        event = _parse_classification_response(
            self._make_raw_json(), "doc-1", "llama3.1:8b",
        )
        assert event.event_types == ["trade_barrier"]
        assert event.severity == "high"
        assert event.affected_regions == ["US", "CN"]
        assert event.summary == "New tariffs on steel imports"
        assert event.source_document_id == "doc-1"
        metadata = event.model_metadata
        assert metadata.model_name == "llama3.1:8b"
        assert metadata.prompt_version == PROMPT_VERSION

    def test_multiple_event_types_preserved(self):
        """Requirement 2.4: multiple impact types not collapsed."""
        wanted = ["trade_barrier", "cost_increase", "supply_disruption"]
        event = _parse_classification_response(
            self._make_raw_json(event_types=wanted), "doc-1", "model",
        )
        assert len(event.event_types) == 3
        for impact_type in wanted:
            assert impact_type in event.event_types

    def test_confidence_clamped(self):
        """Out-of-range confidence is clamped into [0, 1]."""
        for raw_confidence, clamped in ((1.5, 1.0), (-0.3, 0.0)):
            event = _parse_classification_response(
                self._make_raw_json(confidence=raw_confidence), "doc-1", "model",
            )
            assert event.confidence == clamped

    def test_invalid_severity_normalized(self):
        """An unknown severity is normalized to 'low'."""
        event = _parse_classification_response(
            self._make_raw_json(severity="extreme"), "doc-1", "model",
        )
        assert event.severity == "low"

    def test_invalid_duration_normalized(self):
        """An unknown duration is normalized to 'short_term'."""
        event = _parse_classification_response(
            self._make_raw_json(estimated_duration="permanent"), "doc-1", "model",
        )
        assert event.estimated_duration == "short_term"

    def test_event_id_is_uuid(self):
        """The parsed event receives an id that parses as a UUID."""
        event = _parse_classification_response(
            self._make_raw_json(), "doc-1", "model",
        )
        uuid.UUID(event.event_id)  # raises if event_id is not a well-formed UUID
# ---------------------------------------------------------------------------
# classify_global_event tests
# ---------------------------------------------------------------------------
class TestClassifyGlobalEvent:
    """classify_global_event driven by a mocked OllamaClient.

    The mock exposes only the private attributes these tests configure:
    ``_config.model``, the retry/backoff settings, and ``_call_ollama``.
    """

    @staticmethod
    def _make_attempt(raw_output: str, error: str | None = None) -> MagicMock:
        """Return a stand-in for a single ``_call_ollama`` result."""
        attempt = MagicMock()
        attempt.raw_output = raw_output
        attempt.error = error
        return attempt

    def _make_mock_client(self, raw_output: str, error: str | None = None) -> MagicMock:
        """Create a mock OllamaClient whose every ``_call_ollama`` returns the same attempt.

        Retry/backoff settings are set to small values (presumably so that retry
        tests finish quickly — confirm against the retry loop's use of them).
        """
        client = MagicMock()
        client._config = MagicMock()
        client._config.model = "llama3.1:8b"
        client._max_retries = 2
        client._base_delay = 0.01
        client._max_delay = 0.1
        client._backoff_multiplier = 2.0
        client._call_ollama = AsyncMock(
            return_value=self._make_attempt(raw_output, error),
        )
        return client

    @pytest.mark.asyncio
    async def test_successful_classification(self):
        """A valid first response is parsed and returned without retrying."""
        raw = json.dumps({
            "event_types": ["commodity_shock"],
            "severity": "critical",
            "affected_regions": ["Global"],
            "affected_sectors": ["Energy"],
            "affected_commodities": ["crude_oil"],
            "summary": "OPEC cuts production",
            "key_facts": ["2M barrel/day cut"],
            "estimated_duration": "medium_term",
            "confidence": 0.9,
        })
        client = self._make_mock_client(raw)
        event = await classify_global_event(
            "OPEC announced production cuts...",
            "doc-123",
            client,
        )
        assert event.event_types == ["commodity_shock"]
        assert event.severity == "critical"
        assert event.confidence == 0.9
        assert event.source_document_id == "doc-123"
        client._call_ollama.assert_called_once()

    @pytest.mark.asyncio
    async def test_retries_on_error(self):
        """Requirement 2.3: retries on invalid response."""
        good_raw = json.dumps({
            "event_types": ["geopolitical_risk"],
            "severity": "high",
            "affected_regions": ["UA", "RU"],
            "affected_sectors": ["Energy"],
            "affected_commodities": ["natural_gas"],
            "summary": "Conflict escalation",
            "key_facts": ["Military action reported"],
            "estimated_duration": "long_term",
            "confidence": 0.7,
        })
        client = self._make_mock_client("")
        # First call fails, second succeeds.
        client._call_ollama = AsyncMock(side_effect=[
            self._make_attempt("", error="timeout"),
            self._make_attempt(good_raw),
        ])
        event = await classify_global_event("text", "doc-456", client)
        assert event.severity == "high"
        assert event.source_document_id == "doc-456"
        assert client._call_ollama.call_count == 2

    @pytest.mark.asyncio
    async def test_raises_after_exhausted_retries(self):
        """Every attempt failing raises ValueError after initial call + all retries."""
        client = self._make_mock_client("", error="timeout")
        with pytest.raises(ValueError, match="Event classification failed"):
            await classify_global_event("text", "doc-789", client)
        # Must sit outside the `with` block: nothing after the raising await
        # inside it would ever execute.
        assert client._call_ollama.call_count == client._max_retries + 1

    @pytest.mark.asyncio
    async def test_minio_persistence_called(self):
        """When a MinIO client is supplied, the prompt and the result are both stored."""
        raw = json.dumps({
            "event_types": ["regulatory_pressure"],
            "severity": "moderate",
            "affected_regions": ["EU"],
            "affected_sectors": ["Information Technology"],
            "affected_commodities": [],
            "summary": "New AI regulation",
            "key_facts": ["EU AI Act enforcement begins"],
            "estimated_duration": "long_term",
            "confidence": 0.75,
        })
        client = self._make_mock_client(raw)
        minio = MagicMock()
        minio.put_object = MagicMock()
        event = await classify_global_event(
            "text", "doc-abc", client, minio_client=minio,
        )
        assert event.severity == "moderate"
        # put_object called for prompt + result
        assert minio.put_object.call_count == 2

    @pytest.mark.asyncio
    async def test_pg_persistence_called(self):
        """When a DB pool is supplied, the event is inserted into global_events."""
        raw = json.dumps({
            "event_types": ["currency_impact"],
            "severity": "low",
            "affected_regions": ["JP"],
            "affected_sectors": ["Financials"],
            "affected_commodities": [],
            "summary": "Yen weakens",
            "key_facts": ["USD/JPY hits 160"],
            "estimated_duration": "short_term",
            "confidence": 0.6,
        })
        client = self._make_mock_client(raw)
        pool = MagicMock()
        pool.fetchval = AsyncMock(return_value=uuid.uuid4())
        event = await classify_global_event(
            "text", "doc-def", client, pool=pool,
        )
        assert event.event_types == ["currency_impact"]
        pool.fetchval.assert_called_once()
        # The first positional argument is the SQL statement.
        sql = pool.fetchval.call_args.args[0]
        assert "global_events" in sql