"""Tests for the event classifier module. Covers GlobalEvent dataclass, JSON schema generation, prompt building, response parsing/normalization, and the classify_global_event function. Requirements: 2.1, 2.2, 2.3, 2.4, 2.5 """ from __future__ import annotations import json import uuid from dataclasses import fields from unittest.mock import AsyncMock, MagicMock import pytest from services.extractor.event_classifier import ( PROMPT_VERSION, GlobalEvent, _normalize_duration, _normalize_event_types, _normalize_severity, _parse_classification_response, build_event_classification_prompt, classify_global_event, get_event_json_schema, ) from services.shared.schemas import ModelMetadata # --------------------------------------------------------------------------- # GlobalEvent dataclass tests # --------------------------------------------------------------------------- class TestGlobalEvent: def test_default_construction(self): event = GlobalEvent() assert event.event_id # UUID generated assert event.event_types == [] assert event.severity == "low" assert event.affected_regions == [] assert event.affected_sectors == [] assert event.affected_commodities == [] assert event.summary == "" assert event.key_facts == [] assert event.estimated_duration == "short_term" assert event.confidence == 0.5 assert event.source_document_id == "" assert isinstance(event.model_metadata, ModelMetadata) def test_all_fields_present(self): """Verify all design-specified fields exist on GlobalEvent.""" field_names = {f.name for f in fields(GlobalEvent)} expected = { "event_id", "event_types", "severity", "affected_regions", "affected_sectors", "affected_commodities", "summary", "key_facts", "estimated_duration", "confidence", "source_document_id", "model_metadata", } assert expected == field_names def test_custom_construction(self): event = GlobalEvent( event_id="test-id", event_types=["trade_barrier", "cost_increase"], severity="high", affected_regions=["US", "CN"], affected_sectors=["Industrials"], affected_commodities=["steel"], summary="Trade war escalation", key_facts=["25% tariff announced"], estimated_duration="medium_term", confidence=0.85, source_document_id="doc-123", ) assert event.event_types == ["trade_barrier", "cost_increase"] assert event.severity == "high" assert event.confidence == 0.85 def test_unique_event_ids(self): e1 = GlobalEvent() e2 = GlobalEvent() assert e1.event_id != e2.event_id # --------------------------------------------------------------------------- # JSON schema tests # --------------------------------------------------------------------------- class TestEventJsonSchema: def test_schema_is_valid_json_schema(self): schema = get_event_json_schema() assert schema["type"] == "object" assert "properties" in schema assert "required" in schema def test_schema_has_all_required_fields(self): schema = get_event_json_schema() required = set(schema["required"]) expected = { "event_types", "severity", "affected_regions", "affected_sectors", "affected_commodities", "summary", "key_facts", "estimated_duration", "confidence", } assert expected == required def test_schema_event_types_has_enum(self): schema = get_event_json_schema() items = schema["properties"]["event_types"]["items"] assert "enum" in items assert "supply_disruption" in items["enum"] assert "geopolitical_risk" in items["enum"] def test_schema_severity_has_enum(self): schema = get_event_json_schema() severity = schema["properties"]["severity"] assert set(severity["enum"]) == {"low", "moderate", "high", "critical"} def test_schema_duration_has_enum(self): schema = get_event_json_schema() duration = schema["properties"]["estimated_duration"] assert set(duration["enum"]) == {"short_term", "medium_term", "long_term"} def test_schema_confidence_bounds(self): schema = get_event_json_schema() conf = schema["properties"]["confidence"] assert conf["minimum"] == 0.0 assert conf["maximum"] == 1.0 def test_schema_no_additional_properties(self): schema = get_event_json_schema() assert schema.get("additionalProperties") is False # --------------------------------------------------------------------------- # Prompt builder tests # --------------------------------------------------------------------------- class TestBuildEventClassificationPrompt: def test_returns_system_and_user(self): result = build_event_classification_prompt("Some article text") assert "system" in result assert "user" in result def test_user_prompt_contains_article_text(self): result = build_event_classification_prompt("Tariffs announced on steel imports") assert "Tariffs announced on steel imports" in result["user"] def test_user_prompt_contains_anti_hallucination_rules(self): result = build_event_classification_prompt("text") assert "Do NOT infer" in result["user"] assert "fabricate" in result["user"] def test_system_prompt_is_concise(self): result = build_event_classification_prompt("text") assert "JSON" in result["system"] assert len(result["system"]) < 1000 # expanded to include macro-vs-company filtering rules def test_user_prompt_lists_impact_types(self): result = build_event_classification_prompt("text") assert "supply_disruption" in result["user"] assert "geopolitical_risk" in result["user"] # --------------------------------------------------------------------------- # Normalization tests # --------------------------------------------------------------------------- class TestNormalization: def test_normalize_event_types_valid(self): assert _normalize_event_types(["trade_barrier", "cost_increase"]) == [ "trade_barrier", "cost_increase", ] def test_normalize_event_types_filters_invalid(self): result = _normalize_event_types(["trade_barrier", "invalid_type", "cost_increase"]) assert result == ["trade_barrier", "cost_increase"] def test_normalize_event_types_empty_fallback(self): assert _normalize_event_types([]) == ["geopolitical_risk"] assert _normalize_event_types(["bogus"]) == ["geopolitical_risk"] def test_normalize_severity_valid(self): assert _normalize_severity("high") == "high" assert _normalize_severity("CRITICAL") == "critical" def test_normalize_severity_invalid_fallback(self): assert _normalize_severity("extreme") == "low" def test_normalize_duration_valid(self): assert _normalize_duration("medium_term") == "medium_term" def test_normalize_duration_invalid_fallback(self): assert _normalize_duration("forever") == "short_term" # --------------------------------------------------------------------------- # Parse classification response tests # --------------------------------------------------------------------------- class TestParseClassificationResponse: def _make_raw_json(self, **overrides) -> str: data = { "event_types": ["trade_barrier"], "severity": "high", "affected_regions": ["US", "CN"], "affected_sectors": ["Industrials"], "affected_commodities": ["steel"], "summary": "New tariffs on steel imports", "key_facts": ["25% tariff effective immediately"], "estimated_duration": "medium_term", "confidence": 0.8, } data.update(overrides) return json.dumps(data) def test_basic_parse(self): event = _parse_classification_response( self._make_raw_json(), "doc-1", "llama3.1:8b", ) assert event.event_types == ["trade_barrier"] assert event.severity == "high" assert event.affected_regions == ["US", "CN"] assert event.summary == "New tariffs on steel imports" assert event.source_document_id == "doc-1" assert event.model_metadata.model_name == "llama3.1:8b" assert event.model_metadata.prompt_version == PROMPT_VERSION def test_multiple_event_types_preserved(self): """Requirement 2.4: multiple impact types not collapsed.""" raw = self._make_raw_json( event_types=["trade_barrier", "cost_increase", "supply_disruption"], ) event = _parse_classification_response(raw, "doc-1", "model") assert len(event.event_types) == 3 assert "trade_barrier" in event.event_types assert "cost_increase" in event.event_types assert "supply_disruption" in event.event_types def test_confidence_clamped(self): raw = self._make_raw_json(confidence=1.5) event = _parse_classification_response(raw, "doc-1", "model") assert event.confidence == 1.0 raw = self._make_raw_json(confidence=-0.3) event = _parse_classification_response(raw, "doc-1", "model") assert event.confidence == 0.0 def test_invalid_severity_normalized(self): raw = self._make_raw_json(severity="extreme") event = _parse_classification_response(raw, "doc-1", "model") assert event.severity == "low" def test_invalid_duration_normalized(self): raw = self._make_raw_json(estimated_duration="permanent") event = _parse_classification_response(raw, "doc-1", "model") assert event.estimated_duration == "short_term" def test_event_id_is_uuid(self): event = _parse_classification_response( self._make_raw_json(), "doc-1", "model", ) uuid.UUID(event.event_id) # Should not raise # --------------------------------------------------------------------------- # classify_global_event tests # --------------------------------------------------------------------------- class TestClassifyGlobalEvent: def _make_mock_client(self, raw_output: str, error: str | None = None): """Create a mock LLMClient with configurable response.""" client = MagicMock() client._config = MagicMock() client._config.model = "llama3.1:8b" client._config.max_retries = 2 client._config.retry_base_delay = 0.01 client._config.retry_max_delay = 0.1 client._config.retry_backoff_multiplier = 2.0 attempt = MagicMock() attempt.raw_output = raw_output attempt.error = error client.call_llm = AsyncMock(return_value=attempt) return client @pytest.mark.asyncio async def test_successful_classification(self): raw = json.dumps({ "event_types": ["commodity_shock"], "severity": "critical", "affected_regions": ["Global"], "affected_sectors": ["Energy"], "affected_commodities": ["crude_oil"], "summary": "OPEC cuts production", "key_facts": ["2M barrel/day cut"], "estimated_duration": "medium_term", "confidence": 0.9, }) client = self._make_mock_client(raw) event = await classify_global_event( "OPEC announced production cuts...", "doc-123", client, ) assert event.event_types == ["commodity_shock"] assert event.severity == "critical" assert event.confidence == 0.9 assert event.source_document_id == "doc-123" client.call_llm.assert_called_once() @pytest.mark.asyncio async def test_retries_on_error(self): """Requirement 2.3: retries on invalid response.""" good_raw = json.dumps({ "event_types": ["geopolitical_risk"], "severity": "high", "affected_regions": ["UA", "RU"], "affected_sectors": ["Energy"], "affected_commodities": ["natural_gas"], "summary": "Conflict escalation", "key_facts": ["Military action reported"], "estimated_duration": "long_term", "confidence": 0.7, }) fail_attempt = MagicMock() fail_attempt.raw_output = "" fail_attempt.error = "timeout" success_attempt = MagicMock() success_attempt.raw_output = good_raw success_attempt.error = None client = self._make_mock_client("") client.call_llm = AsyncMock(side_effect=[fail_attempt, success_attempt]) event = await classify_global_event("text", "doc-456", client) assert event.severity == "high" assert client.call_llm.call_count == 2 @pytest.mark.asyncio async def test_raises_after_exhausted_retries(self): fail_attempt = MagicMock() fail_attempt.raw_output = "" fail_attempt.error = "timeout" client = self._make_mock_client("") client.call_llm = AsyncMock(return_value=fail_attempt) with pytest.raises(ValueError, match="Event classification failed"): await classify_global_event("text", "doc-789", client) assert client.call_llm.call_count == 3 # initial + 2 retries @pytest.mark.asyncio async def test_minio_persistence_called(self): raw = json.dumps({ "event_types": ["regulatory_pressure"], "severity": "moderate", "affected_regions": ["EU"], "affected_sectors": ["Information Technology"], "affected_commodities": [], "summary": "New AI regulation", "key_facts": ["EU AI Act enforcement begins"], "estimated_duration": "long_term", "confidence": 0.75, }) client = self._make_mock_client(raw) minio = MagicMock() minio.put_object = MagicMock() event = await classify_global_event( "text", "doc-abc", client, minio_client=minio, ) assert event.severity == "moderate" # put_object called for prompt + result assert minio.put_object.call_count == 2 @pytest.mark.asyncio async def test_pg_persistence_called(self): raw = json.dumps({ "event_types": ["currency_impact"], "severity": "low", "affected_regions": ["JP"], "affected_sectors": ["Financials"], "affected_commodities": [], "summary": "Yen weakens", "key_facts": ["USD/JPY hits 160"], "estimated_duration": "short_term", "confidence": 0.6, }) client = self._make_mock_client(raw) pool = MagicMock() pool.fetchval = AsyncMock(return_value=uuid.uuid4()) event = await classify_global_event( "text", "doc-def", client, pool=pool, ) assert event.event_types == ["currency_impact"] pool.fetchval.assert_called_once() # Verify the SQL contains global_events call_args = pool.fetchval.call_args assert "global_events" in call_args[0][0]