"""Event classifier module for macro news articles.

Classifies global/geopolitical news articles into structured GlobalEvent
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
existing OllamaClient for inference and retry logic. Persists classification
prompts, raw outputs, and final events to MinIO and PostgreSQL for audit and
downstream interpolation.

Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""

from __future__ import annotations

import asyncio
import json
import logging
import math
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any

import asyncpg
from minio import Minio

from services.shared.schemas import (
    EstimatedDuration,
    ImpactType,
    ModelMetadata,
    SeverityLevel,
)
from services.shared.storage import upload_artifact

logger = logging.getLogger("event_classifier")

PROMPT_VERSION = "event-classification-v1"
SCHEMA_VERSION = "1.0.0"

# Valid enum value sets for normalization
_VALID_IMPACT_TYPES = frozenset(e.value for e in ImpactType)
_VALID_SEVERITY_LEVELS = frozenset(e.value for e in SeverityLevel)
_VALID_DURATIONS = frozenset(e.value for e in EstimatedDuration)


# ---------------------------------------------------------------------------
# GlobalEvent dataclass
# ---------------------------------------------------------------------------


@dataclass
class GlobalEvent:
    """Structured classification of a macro news event.

    Produced by the event classifier from Ollama structured output.
    """

    # Random UUID string; persist_global_event writes it as the row `id`.
    event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # ImpactType values (filtered by _normalize_event_types).
    event_types: list[str] = field(default_factory=list)
    # SeverityLevel value (filtered by _normalize_severity).
    severity: str = "low"
    affected_regions: list[str] = field(default_factory=list)
    affected_sectors: list[str] = field(default_factory=list)
    affected_commodities: list[str] = field(default_factory=list)
    summary: str = ""
    key_facts: list[str] = field(default_factory=list)
    # EstimatedDuration value (filtered by _normalize_duration).
    estimated_duration: str = "short_term"
    # Model-reported confidence, clamped to [0.0, 1.0] when parsed.
    confidence: float = 0.5
    # ID of the source document; inserted as ::uuid by persist_global_event.
    source_document_id: str = ""
    model_metadata: ModelMetadata = field(default_factory=ModelMetadata)


# ---------------------------------------------------------------------------
# JSON schema for Ollama structured output
# ---------------------------------------------------------------------------


class _EventClassificationResult:
    """Schema definition for the Ollama event classification response.

    Not a Pydantic model — we build the JSON schema dict directly to keep
    it self-contained and Ollama-friendly (no $refs). See
    get_event_json_schema().
    """


def get_event_json_schema() -> dict[str, Any]:
    """Return the JSON schema for Ollama structured event classification output.

    The schema forces the model to produce all required fields explicitly.
    Enum lists are sorted so the schema is deterministic across runs (it is
    also uploaded alongside each prompt for audit).
    """
    return {
        "type": "object",
        "required": [
            "event_types",
            "severity",
            "affected_regions",
            "affected_sectors",
            "affected_commodities",
            "summary",
            "key_facts",
            "estimated_duration",
            "confidence",
        ],
        "properties": {
            "event_types": {
                "type": "array",
                "items": {
                    "type": "string",
                    "enum": sorted(_VALID_IMPACT_TYPES),
                },
                "description": (
                    "One or more impact types this event represents. "
                    "Include ALL applicable types — do not collapse to a single category."
                ),
            },
            "severity": {
                "type": "string",
                "enum": sorted(_VALID_SEVERITY_LEVELS),
                "description": "Overall severity of the event: low, moderate, high, or critical.",
            },
            "affected_regions": {
                "type": "array",
                "items": {"type": "string"},
                "description": (
                    "ISO 3166-1 alpha-2 country codes or region names affected. "
                    "Use standard codes like US, CN, EU, GB, JP. "
                    "Only include regions explicitly mentioned or clearly implied."
                ),
            },
            "affected_sectors": {
                "type": "array",
                "items": {"type": "string"},
                "description": (
                    "GICS sector identifiers or sector names affected. "
                    "Examples: Energy, Materials, Industrials, Consumer Discretionary, "
                    "Consumer Staples, Health Care, Financials, Information Technology, "
                    "Communication Services, Utilities, Real Estate."
                ),
            },
            "affected_commodities": {
                "type": "array",
                "items": {"type": "string"},
                "description": (
                    "Commodity identifiers affected, if applicable. "
                    "Examples: crude_oil, natural_gas, gold, copper, wheat, lithium, "
                    "semiconductors. Empty list if no commodities are directly affected."
                ),
            },
            "summary": {
                "type": "string",
                "description": "A concise 1-3 sentence summary of the event and its market implications.",
            },
            "key_facts": {
                "type": "array",
                "items": {"type": "string"},
                "description": (
                    "Key facts explicitly stated in the article. "
                    "Do NOT infer, speculate, or fabricate facts. "
                    "Each fact must be directly supported by the text."
                ),
            },
            "estimated_duration": {
                "type": "string",
                "enum": sorted(_VALID_DURATIONS),
                "description": (
                    "Expected duration of market impact: "
                    "short_term (days to weeks), medium_term (weeks to months), "
                    "long_term (months to years)."
                ),
            },
            "confidence": {
                "type": "number",
                "minimum": 0.0,
                "maximum": 1.0,
                "description": (
                    "Your confidence in this classification. "
                    "Lower if the article is ambiguous, speculative, or lacks concrete details."
                ),
            },
        },
        "additionalProperties": False,
    }


# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
# NOTE: any change to the prompt text below must come with a PROMPT_VERSION
# bump, because the version string is persisted alongside every event.

_SYSTEM_PROMPT = """\
You classify global news articles into structured macro event intelligence. \
Read the article carefully and extract the event classification. \
Return ONLY valid JSON matching the schema.
No commentary, no markdown, no explanation."""

_ANTI_HALLUCINATION_RULES = """\
CRITICAL RULES — read carefully:
1. Only extract information EXPLICITLY stated in the article text.
2. Do NOT infer, speculate, or fabricate facts, regions, sectors, or commodities.
3. If the article mentions multiple distinct impact types, include ALL of them in event_types.
4. For affected_regions, only include regions explicitly mentioned or clearly implied by the event.
5. For affected_sectors, only include sectors with a clear causal link to the event.
6. For affected_commodities, only include commodities directly referenced or obviously impacted.
7. For key_facts, each fact must be directly supported by a specific passage in the text.
8. If the article is vague or speculative, set confidence LOW (below 0.4).
9. Do NOT treat journalist speculation or opinion as confirmed fact.
10. Distinguish between announced policy and proposed/rumored policy."""


def build_event_classification_prompt(text: str) -> dict[str, str]:
    """Build system and user prompts for Ollama event classification.

    Args:
        text: Normalized text content of the macro news article. It is
            embedded verbatim between the ARTICLE TEXT delimiters.

    Returns:
        Dict with 'system' and 'user' prompt strings.
    """
    user_prompt = f"""\
Classify this global news article as a macro event. Fill every field.

{_ANTI_HALLUCINATION_RULES}

Classify the event by:
- event_types: ALL applicable impact types (supply_disruption, demand_shift, cost_increase, \
regulatory_pressure, currency_impact, commodity_shock, trade_barrier, geopolitical_risk)
- severity: low, moderate, high, or critical
- affected_regions: ISO country codes or region names
- affected_sectors: GICS sector names
- affected_commodities: commodity identifiers (empty list if none)
- summary: 1-3 sentence summary of the event and market implications
- key_facts: facts explicitly stated in the text (NO fabrication)
- estimated_duration: short_term, medium_term, or long_term
- confidence: 0.0-1.0 your confidence in this classification

--- ARTICLE TEXT ---
{text}
--- END ARTICLE TEXT ---"""

    return {
        "system": _SYSTEM_PROMPT,
        "user": user_prompt,
    }


# ---------------------------------------------------------------------------
# Classification response parsing and normalization
# ---------------------------------------------------------------------------


def _coerce_str_list(raw: Any) -> list[str]:
    """Coerce a model-supplied list field into a list of non-empty strings.

    ``None`` becomes ``[]``; a bare scalar (e.g. a single string) becomes a
    one-element list instead of being iterated character by character.
    Items are stringified and stripped; ``None`` and blank items are dropped.
    """
    if raw is None:
        return []
    if not isinstance(raw, (list, tuple)):
        raw = [raw]
    cleaned = (str(item).strip() for item in raw if item is not None)
    return [item for item in cleaned if item]


def _normalize_event_types(raw: Any) -> list[str]:
    """Normalize and filter event_types to valid ImpactType values.

    Values are lower-cased, invalid ones dropped and duplicates removed
    (first occurrence wins). Falls back to ``["geopolitical_risk"]`` when
    nothing valid remains.
    """
    # dict preserves insertion order, giving an order-stable dedupe.
    valid: dict[str, None] = {}
    for item in _coerce_str_list(raw):
        val = item.lower()
        if val in _VALID_IMPACT_TYPES:
            valid.setdefault(val, None)
    return list(valid) or ["geopolitical_risk"]


def _normalize_severity(raw: str) -> str:
    """Normalize severity to a valid SeverityLevel value (default "low")."""
    val = str(raw).lower().strip()
    return val if val in _VALID_SEVERITY_LEVELS else "low"


def _normalize_duration(raw: str) -> str:
    """Normalize estimated_duration to a valid EstimatedDuration value (default "short_term")."""
    val = str(raw).lower().strip()
    return val if val in _VALID_DURATIONS else "short_term"


def _normalize_confidence(raw: Any) -> float:
    """Clamp a model-supplied confidence into [0.0, 1.0].

    Returns 0.5 for non-numeric values, booleans, and non-finite numbers
    (``json.loads`` accepts ``NaN``/``Infinity``, and ``NaN`` would otherwise
    slip through ``min``/``max`` as 1.0).
    """
    # bool is an int subclass, but True/False are not meaningful confidences.
    if isinstance(raw, bool) or not isinstance(raw, (int, float)):
        return 0.5
    try:
        value = float(raw)
    except OverflowError:  # arbitrarily large JSON integer
        return 0.5
    if not math.isfinite(value):
        return 0.5
    return max(0.0, min(1.0, value))


def _parse_classification_response(
    raw_json: str,
    document_id: str,
    model_name: str,
) -> GlobalEvent:
    """Parse raw Ollama JSON output into a GlobalEvent.

    Normalizes enum values, coerces list fields, and clamps numeric fields.

    Raises:
        json.JSONDecodeError: If ``raw_json`` is not valid JSON.
        TypeError: If the JSON value is not an object.
    """
    data = json.loads(raw_json)
    if not isinstance(data, dict):
        raise TypeError(f"expected a JSON object, got {type(data).__name__}")

    summary = data.get("summary")

    return GlobalEvent(
        event_types=_normalize_event_types(data.get("event_types")),
        severity=_normalize_severity(data.get("severity", "low")),
        affected_regions=_coerce_str_list(data.get("affected_regions")),
        affected_sectors=_coerce_str_list(data.get("affected_sectors")),
        affected_commodities=_coerce_str_list(data.get("affected_commodities")),
        summary="" if summary is None else str(summary),
        key_facts=_coerce_str_list(data.get("key_facts")),
        estimated_duration=_normalize_duration(data.get("estimated_duration", "short_term")),
        confidence=_normalize_confidence(data.get("confidence")),
        source_document_id=document_id,
        model_metadata=ModelMetadata(
            provider="ollama",
            model_name=model_name,
            prompt_version=PROMPT_VERSION,
            schema_version=SCHEMA_VERSION,
        ),
    )


# ---------------------------------------------------------------------------
# MinIO persistence helpers
# ---------------------------------------------------------------------------


def _upload_classification_prompt(
    minio_client: Minio,
    document_id: str,
    prompt_data: dict[str, str],
    model_name: str,
    timestamp: datetime | None = None,
) -> str:
    """Upload classification prompt and metadata to stonks-llm-prompts.

    The object path is partitioned by the UTC date of ``timestamp`` (now if
    omitted) and the document ID. Returns whatever ``upload_artifact`` returns.
    """
    ts = timestamp or datetime.now(timezone.utc)
    payload = json.dumps({
        "prompt_version": PROMPT_VERSION,
        "schema_version": SCHEMA_VERSION,
        "model": model_name,
        "system_prompt": prompt_data["system"],
        "user_prompt": prompt_data["user"],
        "json_schema": get_event_json_schema(),
    }, indent=2).encode()
    path = (
        f"event_classification/macro/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
        f"{document_id}/prompt.json"
    )
    return upload_artifact(
        minio_client, "stonks-llm-prompts", path, payload,
        content_type="application/json",
    )


def _upload_classification_result(
    minio_client: Minio,
    document_id: str,
    raw_output: str,
    event: GlobalEvent | None,
    success: bool,
    error: str | None,
    timestamp: datetime | None = None,
) -> str:
    """Upload raw classification output to stonks-llm-results.

    ``parsed_event`` in the payload is ``null`` when ``event`` is None (the
    failure path). Returns whatever ``upload_artifact`` returns.
    """
    ts = timestamp or datetime.now(timezone.utc)
    payload = json.dumps({
        "document_id": document_id,
        "success": success,
        "error": error,
        "raw_output": raw_output,
        "parsed_event": {
            "event_id": event.event_id,
            "event_types": event.event_types,
            "severity": event.severity,
            "affected_regions": event.affected_regions,
            "affected_sectors": event.affected_sectors,
            "affected_commodities": event.affected_commodities,
            "summary": event.summary,
            "key_facts": event.key_facts,
            "estimated_duration": event.estimated_duration,
            "confidence": event.confidence,
        } if event else None,
    }, indent=2).encode()
    path = (
        f"event_classification/macro/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
        f"{document_id}/result.json"
    )
    return upload_artifact(
        minio_client, "stonks-llm-results", path, payload,
        content_type="application/json",
    )


# ---------------------------------------------------------------------------
# PostgreSQL persistence
# ---------------------------------------------------------------------------


async def persist_global_event(
    pool: asyncpg.Pool,
    event: GlobalEvent,
) -> str:
    """Persist a GlobalEvent record to the global_events PostgreSQL table.

    ``key_facts`` is stored as JSONB; ``event_id`` and ``source_document_id``
    are cast to uuid, so both must be valid UUID strings.

    Returns the event row UUID (as a string).
    """
    row_id = await pool.fetchval(
        """INSERT INTO global_events
               (id, event_types, severity, affected_regions, affected_sectors,
                affected_commodities, summary, key_facts, estimated_duration,
                confidence, source_document_id, model_provider, model_name,
                prompt_version, schema_version)
           VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10,
                   $11::uuid, $12, $13, $14, $15)
           RETURNING id""",
        event.event_id,
        event.event_types,
        event.severity,
        event.affected_regions,
        event.affected_sectors,
        event.affected_commodities,
        event.summary,
        json.dumps(event.key_facts),
        event.estimated_duration,
        event.confidence,
        event.source_document_id,
        event.model_metadata.provider,
        event.model_metadata.model_name,
        event.model_metadata.prompt_version,
        event.model_metadata.schema_version,
    )
    logger.info(
        "Persisted global event %s for doc %s (severity=%s, types=%s)",
        row_id, event.source_document_id, event.severity, event.event_types,
    )
    return str(row_id)


# ---------------------------------------------------------------------------
# Main classification function
# ---------------------------------------------------------------------------


def _retry_delay(ollama_client: Any, attempt_num: int) -> float:
    """Exponential backoff delay in seconds before the retry following ``attempt_num``.

    Uses the OllamaClient's own backoff settings, capped at its max delay.
    """
    delay = ollama_client._base_delay * (
        ollama_client._backoff_multiplier ** attempt_num
    )
    return min(delay, ollama_client._max_delay)


async def _persist_classification_success(
    event: GlobalEvent,
    raw_output: str,
    document_id: str,
    ts: datetime,
    *,
    pool: asyncpg.Pool | None,
    minio_client: Minio | None,
) -> None:
    """Best-effort persistence of a successful classification.

    Uploads the raw/parsed result to MinIO, then inserts the event into
    PostgreSQL. Failures are logged and swallowed so a storage outage does
    not discard a successful classification.
    """
    if minio_client is not None:
        # NOTE(review): upload_artifact is called synchronously here, so it
        # blocks the event loop; consider asyncio.to_thread if uploads are slow.
        try:
            _upload_classification_result(
                minio_client, document_id, raw_output, event,
                success=True, error=None, timestamp=ts,
            )
        except Exception:
            logger.exception(
                "Failed to upload classification result for doc %s", document_id,
            )

    if pool is not None:
        try:
            await persist_global_event(pool, event)
        except Exception:
            logger.exception(
                "Failed to persist global event for doc %s", document_id,
            )


async def classify_global_event(
    normalized_text: str,
    document_id: str,
    ollama_client: Any,
    *,
    pool: asyncpg.Pool | None = None,
    minio_client: Minio | None = None,
) -> GlobalEvent:
    """Classify a macro news article into a GlobalEvent using Ollama.

    Uses the existing OllamaClient's streaming infrastructure with a
    dedicated event classification prompt and JSON schema. Follows the same
    retry policy as document extraction: ``_max_retries`` retries after the
    first attempt, with exponential backoff between attempts.

    Persists prompt, raw output, and final event to MinIO and PostgreSQL
    when the respective clients are provided. Persistence is best-effort:
    storage errors are logged, not raised.

    Args:
        normalized_text: Cleaned text content of the macro article.
        document_id: UUID of the source document.
        ollama_client: An OllamaClient instance (from services.extractor.client).
        pool: Optional asyncpg pool for PostgreSQL persistence.
        minio_client: Optional MinIO client for artifact persistence.

    Returns:
        A GlobalEvent with the classification result.

    Raises:
        ValueError: If classification fails after all retries.
    """
    # One timestamp for the whole run so prompt and result land under the
    # same date-partitioned MinIO prefix.
    ts = datetime.now(timezone.utc)
    prompts = build_event_classification_prompt(normalized_text)
    json_schema = get_event_json_schema()
    # NOTE: relies on OllamaClient private attributes (_config, _max_retries,
    # _call_ollama and backoff settings); keep in sync with that class.
    model_name = ollama_client._config.model

    # Persist prompt to MinIO
    if minio_client is not None:
        try:
            _upload_classification_prompt(
                minio_client, document_id, prompts, model_name, timestamp=ts,
            )
        except Exception:
            logger.exception("Failed to upload classification prompt for doc %s", document_id)

    # Clamp so a misconfigured negative value still yields one attempt
    # instead of raising without ever calling the model.
    max_retries = max(0, ollama_client._max_retries)
    last_error: str | None = None
    raw_output = ""

    for attempt_num in range(max_retries + 1):
        attempt = await ollama_client._call_ollama(prompts, json_schema)
        raw_output = attempt.raw_output

        if attempt.error is not None or not raw_output:
            last_error = attempt.error or "empty_response"
        else:
            try:
                event = _parse_classification_response(
                    raw_output, document_id, model_name,
                )
            except (ValueError, KeyError, TypeError) as exc:
                # ValueError covers json.JSONDecodeError.
                last_error = f"parse_error: {exc}"
                logger.warning(
                    "Classification parse error for doc %s attempt %d: %s",
                    document_id, attempt_num + 1, exc,
                )
            else:
                await _persist_classification_success(
                    event, raw_output, document_id, ts,
                    pool=pool, minio_client=minio_client,
                )
                return event

        # Retry with backoff (no sleep after the final attempt)
        if attempt_num < max_retries:
            delay = _retry_delay(ollama_client, attempt_num)
            logger.warning(
                "Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
                attempt_num + 1, max_retries + 1, document_id, last_error, delay,
            )
            await asyncio.sleep(delay)

    # All retries exhausted — persist failure (last attempt's output) and raise
    if minio_client is not None:
        try:
            _upload_classification_result(
                minio_client, document_id, raw_output,
                event=None, success=False, error=last_error, timestamp=ts,
            )
        except Exception:
            logger.exception(
                "Failed to upload failed classification result for doc %s", document_id,
            )

    raise ValueError(
        f"Event classification failed for document {document_id} "
        f"after {max_retries + 1} attempts: {last_error}"
    )