557 lines
20 KiB
Python
557 lines
20 KiB
Python
"""Event classifier module for macro news articles.

Classifies global/geopolitical news articles into structured GlobalEvent
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
existing OllamaClient for inference and mirrors its retry/backoff settings.

Persists classification prompts, raw outputs, and final events to MinIO
and PostgreSQL for audit and downstream interpolation.

Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
import asyncpg
|
|
from minio import Minio
|
|
|
|
from services.shared.schemas import (
|
|
EstimatedDuration,
|
|
ImpactType,
|
|
ModelMetadata,
|
|
SeverityLevel,
|
|
)
|
|
from services.shared.storage import upload_artifact
|
|
|
|
logger = logging.getLogger("event_classifier")

# Version tags recorded with every prompt artifact and persisted event.
PROMPT_VERSION = "event-classification-v1"
SCHEMA_VERSION = "1.0.0"

# Accepted string values of each shared enum; used to validate and
# normalize the enum-typed fields the LLM returns.
_VALID_IMPACT_TYPES = frozenset([member.value for member in ImpactType])
_VALID_SEVERITY_LEVELS = frozenset([member.value for member in SeverityLevel])
_VALID_DURATIONS = frozenset([member.value for member in EstimatedDuration])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GlobalEvent dataclass
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class GlobalEvent:
    """Structured classification of a macro news event.

    Produced by the event classifier from Ollama structured output (see
    ``_parse_classification_response``). Enum-like fields hold plain strings
    (the ``.value`` of the corresponding shared enum) rather than enum members.

    Attributes:
        event_id: Random UUID4 string; generated when not supplied.
        event_types: ImpactType values that apply to the event.
        severity: SeverityLevel value; defaults to "low".
        affected_regions: Country codes or region names as returned by the model.
        affected_sectors: Sector names as returned by the model.
        affected_commodities: Commodity identifiers; empty when none apply.
        summary: Short free-text summary of the event.
        key_facts: Facts extracted from the article text.
        estimated_duration: EstimatedDuration value; defaults to "short_term".
        confidence: Model confidence; the parser clamps it to [0, 1].
        source_document_id: ID of the source document (persisted as a UUID).
        model_metadata: Provider, model name, prompt and schema versions.
    """

    # Fresh UUID4 per instance (default_factory, not a shared default).
    event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    event_types: list[str] = field(default_factory=list)
    severity: str = "low"
    affected_regions: list[str] = field(default_factory=list)
    affected_sectors: list[str] = field(default_factory=list)
    affected_commodities: list[str] = field(default_factory=list)
    summary: str = ""
    key_facts: list[str] = field(default_factory=list)
    estimated_duration: str = "short_term"
    confidence: float = 0.5
    # Empty until the parser fills it in; persisted via a ::uuid cast.
    source_document_id: str = ""
    # NOTE(review): relies on ModelMetadata being constructible with no
    # arguments — confirm against services.shared.schemas.
    model_metadata: ModelMetadata = field(default_factory=ModelMetadata)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# JSON schema for Ollama structured output
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class _EventClassificationResult:
    """Marker for the shape of the Ollama event classification response.

    Deliberately not a Pydantic model: the JSON schema handed to Ollama is
    assembled directly as a dict by ``get_event_json_schema`` so it stays
    self-contained and Ollama-friendly (no ``$ref`` entries).
    """
|
|
|
|
|
|
def get_event_json_schema() -> dict[str, Any]:
    """Build the JSON schema that constrains Ollama's event classification output.

    Every declared property is required and extra properties are rejected, so
    the model has to emit each field explicitly. Enum-typed fields take their
    allowed values (sorted, for a stable schema) from the shared enum sets.

    Returns:
        A freshly built, ``$ref``-free JSON schema dict.
    """

    def string_list(description: str) -> dict[str, Any]:
        # Free-form array of strings carrying a guidance description.
        return {
            "type": "array",
            "items": {"type": "string"},
            "description": description,
        }

    properties: dict[str, Any] = {}
    properties["event_types"] = {
        "type": "array",
        "items": {"type": "string", "enum": sorted(_VALID_IMPACT_TYPES)},
        "description": (
            "One or more impact types this event represents. "
            "Include ALL applicable types — do not collapse to a single category."
        ),
    }
    properties["severity"] = {
        "type": "string",
        "enum": sorted(_VALID_SEVERITY_LEVELS),
        "description": "Overall severity of the event: low, moderate, high, or critical.",
    }
    properties["affected_regions"] = string_list(
        "ISO 3166-1 alpha-2 country codes or region names affected. "
        "Use standard codes like US, CN, EU, GB, JP. "
        "Only include regions explicitly mentioned or clearly implied."
    )
    properties["affected_sectors"] = string_list(
        "GICS sector identifiers or sector names affected. "
        "Examples: Energy, Materials, Industrials, Consumer Discretionary, "
        "Consumer Staples, Health Care, Financials, Information Technology, "
        "Communication Services, Utilities, Real Estate."
    )
    properties["affected_commodities"] = string_list(
        "Commodity identifiers affected, if applicable. "
        "Examples: crude_oil, natural_gas, gold, copper, wheat, lithium, "
        "semiconductors. Empty list if no commodities are directly affected."
    )
    properties["summary"] = {
        "type": "string",
        "description": "A concise 1-3 sentence summary of the event and its market implications.",
    }
    properties["key_facts"] = string_list(
        "Key facts explicitly stated in the article. "
        "Do NOT infer, speculate, or fabricate facts. "
        "Each fact must be directly supported by the text."
    )
    properties["estimated_duration"] = {
        "type": "string",
        "enum": sorted(_VALID_DURATIONS),
        "description": (
            "Expected duration of market impact: "
            "short_term (days to weeks), medium_term (weeks to months), "
            "long_term (months to years)."
        ),
    }
    properties["confidence"] = {
        "type": "number",
        "minimum": 0.0,
        "maximum": 1.0,
        "description": (
            "Your confidence in this classification. "
            "Lower if the article is ambiguous, speculative, or lacks concrete details."
        ),
    }

    return {
        "type": "object",
        # Every declared property is mandatory, listed in declaration order.
        "required": list(properties),
        "properties": properties,
        "additionalProperties": False,
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Prompt builder
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_SYSTEM_PROMPT = """\
You classify global news into structured macro event JSON. \
Return ONLY a single JSON object. No markdown, no explanation. \
Every field is required. Keep key_facts to 3-5 items. Keep summary under 3 sentences."""

_ANTI_HALLUCINATION_RULES = """\
RULES:
- Only extract facts EXPLICITLY stated in the text. Do NOT fabricate.
- If vague or speculative, set confidence below 0.4.
- Distinguish announced policy from rumored policy."""

# Default cap on article characters sent to the model; longer articles are
# cut to keep inference time bounded.
_MAX_ARTICLE_CHARS = 6000

# Appended to a truncated article so the model knows text was removed.
_TRUNCATION_MARKER = "\n[... truncated ...]"


def build_event_classification_prompt(
    text: str,
    *,
    max_chars: int | None = _MAX_ARTICLE_CHARS,
) -> dict[str, str]:
    """Build system and user prompts for Ollama event classification.

    Args:
        text: Normalized text content of the macro news article.
        max_chars: Maximum number of article characters to include. Longer
            text is cut to this length and followed by a truncation marker.
            ``None`` disables truncation. Defaults to 6000.

    Returns:
        Dict with 'system' and 'user' prompt strings.

    Raises:
        ValueError: If ``max_chars`` is negative (a negative slice would
            silently drop text from the end instead of capping the length).
    """
    if max_chars is not None:
        if max_chars < 0:
            raise ValueError(f"max_chars must be non-negative, got {max_chars}")
        if len(text) > max_chars:
            text = text[:max_chars] + _TRUNCATION_MARKER

    user_prompt = f"""\
Classify this global news article as a macro event. Fill every field.

{_ANTI_HALLUCINATION_RULES}

Classify the event by:
- event_types: ALL applicable impact types (supply_disruption, demand_shift, cost_increase, \
regulatory_pressure, currency_impact, commodity_shock, trade_barrier, geopolitical_risk)
- severity: low, moderate, high, or critical
- affected_regions: ISO country codes or region names
- affected_sectors: GICS sector names
- affected_commodities: commodity identifiers (empty list if none)
- summary: 1-3 sentence summary of the event and market implications
- key_facts: facts explicitly stated in the text (NO fabrication)
- estimated_duration: short_term, medium_term, or long_term
- confidence: 0.0-1.0 your confidence in this classification

--- ARTICLE TEXT ---
{text}
--- END ARTICLE TEXT ---"""

    return {
        "system": _SYSTEM_PROMPT,
        "user": user_prompt,
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Classification response parsing and normalization
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _normalize_event_types(raw: list[Any] | str | None) -> list[str]:
    """Normalize model-provided event types to valid ImpactType values.

    Each item is lower-cased and trimmed. Runs of spaces/hyphens become a
    single underscore, so "Supply Disruption" matches "supply_disruption".
    Unknown values and duplicates are dropped; first-seen order is kept.

    Args:
        raw: The ``event_types`` value from the model output. Normally a
            list; ``None`` and a bare string are tolerated.

    Returns:
        The valid impact types found, or ``["geopolitical_risk"]`` when none
        are valid.
    """
    if raw is None:
        items: list[Any] = []
    elif isinstance(raw, str):
        # A bare string would otherwise be iterated character by character.
        items = [raw]
    else:
        items = list(raw)

    canonical = ("_".join(str(item).lower().replace("-", " ").split()) for item in items)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    result = [val for val in dict.fromkeys(canonical) if val in _VALID_IMPACT_TYPES]
    return result or ["geopolitical_risk"]
|
|
|
|
|
|
def _normalize_severity(raw: str) -> str:
    """Map a model-provided severity onto a valid SeverityLevel value.

    Case and surrounding whitespace are ignored. Any value outside the
    SeverityLevel set falls back to ``"low"``.
    """
    candidate = str(raw).lower().strip()
    if candidate not in _VALID_SEVERITY_LEVELS:
        return "low"
    return candidate
|
|
|
|
|
|
def _normalize_duration(raw: str) -> str:
    """Normalize a model-provided estimated_duration to an EstimatedDuration value.

    Case, surrounding whitespace and separator style are ignored. Runs of
    spaces/hyphens become one underscore, so "Short-Term" and "short term"
    both map to "short_term". Anything still invalid falls back to
    ``"short_term"``.
    """
    val = "_".join(str(raw).lower().replace("-", " ").split())
    return val if val in _VALID_DURATIONS else "short_term"
|
|
|
|
|
|
def _coerce_str_list(value: Any) -> list[str]:
    """Coerce a model-provided list field into a list of non-blank, stripped strings.

    ``None`` becomes ``[]`` and a bare string becomes a one-item list (rather
    than being split into characters). ``None`` and blank items are dropped.

    Raises:
        TypeError: If ``value`` is neither None, a string, a list nor a tuple.
    """
    if value is None:
        return []
    if isinstance(value, str):
        value = [value]
    elif not isinstance(value, (list, tuple)):
        raise TypeError(f"expected a list of strings, got {type(value).__name__}")
    stripped = (str(item).strip() for item in value if item is not None)
    return [text for text in stripped if text]


def _parse_classification_response(
    raw_json: str,
    document_id: str,
    model_name: str,
) -> GlobalEvent:
    """Parse raw Ollama JSON output into a GlobalEvent.

    Normalizes enum values, clamps confidence to [0, 1] and cleans list fields.

    Args:
        raw_json: Raw JSON text returned by Ollama.
        document_id: ID of the source document, stored on the event.
        model_name: Ollama model name, recorded in the event's metadata.

    Returns:
        The parsed GlobalEvent with a freshly generated ``event_id``.

    Raises:
        ValueError: If ``raw_json`` is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError``), is not a JSON object, or has neither a
            summary nor any key facts.
        TypeError: If a list-valued field holds a non-list value.
    """
    import math  # local import: only needed for the NaN check below

    data = json.loads(raw_json)
    if not isinstance(data, dict):
        raise ValueError(
            f"Classification for document {document_id} is not a JSON object: "
            f"got {type(data).__name__}"
        )

    raw_confidence = data.get("confidence", 0.5)
    if isinstance(raw_confidence, bool) or not isinstance(raw_confidence, (int, float)):
        # bool is an int subclass; true/false is not a confidence score.
        confidence = 0.5
    elif isinstance(raw_confidence, float) and math.isnan(raw_confidence):
        # NaN survives min/max clamping as 1.0, i.e. maximum confidence.
        confidence = 0.5
    else:
        # Clamp before converting so huge JSON integers cannot overflow float().
        confidence = float(max(0, min(1, raw_confidence)))

    # A JSON null must not turn into the literal string "None".
    raw_summary = data.get("summary")
    summary = "" if raw_summary is None else str(raw_summary).strip()
    key_facts = _coerce_str_list(data.get("key_facts"))

    # Reject empty classifications — the LLM produced no useful output
    if not summary and not key_facts:
        raise ValueError(
            f"Empty classification for document {document_id}: "
            "no summary and no key facts"
        )

    return GlobalEvent(
        event_id=str(uuid.uuid4()),
        event_types=_normalize_event_types(_coerce_str_list(data.get("event_types"))),
        severity=_normalize_severity(data.get("severity", "low")),
        affected_regions=_coerce_str_list(data.get("affected_regions")),
        affected_sectors=_coerce_str_list(data.get("affected_sectors")),
        affected_commodities=_coerce_str_list(data.get("affected_commodities")),
        summary=summary,
        key_facts=key_facts,
        estimated_duration=_normalize_duration(data.get("estimated_duration", "short_term")),
        confidence=confidence,
        source_document_id=document_id,
        model_metadata=ModelMetadata(
            provider="ollama",
            model_name=model_name,
            prompt_version=PROMPT_VERSION,
            schema_version=SCHEMA_VERSION,
        ),
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MinIO persistence helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _upload_classification_prompt(
    minio_client: Minio,
    document_id: str,
    prompt_data: dict[str, str],
    model_name: str,
    timestamp: datetime | None = None,
) -> str:
    """Store the classification prompt and its metadata in ``stonks-llm-prompts``.

    The artifact is pretty-printed JSON under a date-partitioned key:
    ``event_classification/macro/<year>/<MM>/<DD>/<document_id>/prompt.json``.

    Args:
        minio_client: MinIO client used for the upload.
        document_id: Source document ID; part of the object key.
        prompt_data: Mapping with 'system' and 'user' prompt strings.
        model_name: Ollama model name recorded in the artifact.
        timestamp: Date used for the key partition; defaults to now (UTC).

    Returns:
        The value returned by ``upload_artifact`` for the stored object.
    """
    when = timestamp or datetime.now(timezone.utc)

    record = {
        "prompt_version": PROMPT_VERSION,
        "schema_version": SCHEMA_VERSION,
        "model": model_name,
        "system_prompt": prompt_data["system"],
        "user_prompt": prompt_data["user"],
        "json_schema": get_event_json_schema(),
    }
    body = json.dumps(record, indent=2).encode()

    day_dir = f"{when.year}/{when.month:02d}/{when.day:02d}"
    object_key = f"event_classification/macro/{day_dir}/{document_id}/prompt.json"

    return upload_artifact(
        minio_client,
        "stonks-llm-prompts",
        object_key,
        body,
        content_type="application/json",
    )
|
|
|
|
|
|
def _upload_classification_result(
    minio_client: Minio,
    document_id: str,
    raw_output: str,
    event: GlobalEvent | None,
    success: bool,
    error: str | None,
    timestamp: datetime | None = None,
) -> str:
    """Store the raw classification output (and parsed event, if any) in ``stonks-llm-results``.

    The artifact is pretty-printed JSON under
    ``event_classification/macro/<year>/<MM>/<DD>/<document_id>/result.json``.
    ``parsed_event`` is null when no event is given. Otherwise it holds the
    event's classification fields; model metadata and source ID are omitted.

    Args:
        minio_client: MinIO client used for the upload.
        document_id: Source document ID; part of the object key.
        raw_output: Raw text returned by Ollama.
        event: Parsed event, or None when classification failed.
        success: Whether classification succeeded.
        error: Last error description, if any.
        timestamp: Date used for the key partition; defaults to now (UTC).

    Returns:
        The value returned by ``upload_artifact`` for the stored object.
    """
    when = timestamp or datetime.now(timezone.utc)

    parsed_event: dict[str, Any] | None = None
    if event:
        parsed_event = {
            name: getattr(event, name)
            for name in (
                "event_id",
                "event_types",
                "severity",
                "affected_regions",
                "affected_sectors",
                "affected_commodities",
                "summary",
                "key_facts",
                "estimated_duration",
                "confidence",
            )
        }

    record = {
        "document_id": document_id,
        "success": success,
        "error": error,
        "raw_output": raw_output,
        "parsed_event": parsed_event,
    }
    body = json.dumps(record, indent=2).encode()

    day_dir = f"{when.year}/{when.month:02d}/{when.day:02d}"
    object_key = f"event_classification/macro/{day_dir}/{document_id}/result.json"

    return upload_artifact(
        minio_client,
        "stonks-llm-results",
        object_key,
        body,
        content_type="application/json",
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PostgreSQL persistence
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def persist_global_event(
    pool: asyncpg.Pool,
    event: GlobalEvent,
) -> str:
    """Insert ``event`` as a single row of the ``global_events`` table.

    ``key_facts`` is serialized with ``json.dumps`` and cast to ``jsonb``.
    ``event_id`` and ``source_document_id`` are cast to ``uuid`` in SQL. Model
    provenance is flattened from ``event.model_metadata`` into four columns.

    Args:
        pool: asyncpg connection pool.
        event: The classified event to store.

    Returns:
        The ``id`` returned by the INSERT, as a string.
    """
    sql = """INSERT INTO global_events
           (id, event_types, severity, affected_regions, affected_sectors,
            affected_commodities, summary, key_facts, estimated_duration,
            confidence, source_document_id, model_provider, model_name,
            prompt_version, schema_version)
           VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10,
                   $11::uuid, $12, $13, $14, $15)
           RETURNING id"""

    meta = event.model_metadata
    params = (
        event.event_id,
        event.event_types,
        event.severity,
        event.affected_regions,
        event.affected_sectors,
        event.affected_commodities,
        event.summary,
        json.dumps(event.key_facts),
        event.estimated_duration,
        event.confidence,
        event.source_document_id,
        meta.provider,
        meta.model_name,
        meta.prompt_version,
        meta.schema_version,
    )
    row_id = await pool.fetchval(sql, *params)

    logger.info(
        "Persisted global event %s for doc %s (severity=%s, types=%s)",
        row_id,
        event.source_document_id,
        event.severity,
        event.event_types,
    )
    return str(row_id)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main classification function
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def _persist_classification_success(
    event: GlobalEvent,
    raw_output: str,
    document_id: str,
    ts: datetime,
    *,
    pool: asyncpg.Pool | None,
    minio_client: Minio | None,
) -> None:
    """Best-effort storage of a successful classification in MinIO and PostgreSQL.

    Each store is skipped when its client is None. Failures are logged and
    swallowed so that a storage outage never discards a valid classification.
    """
    if minio_client:
        try:
            _upload_classification_result(
                minio_client, document_id, raw_output,
                event, success=True, error=None, timestamp=ts,
            )
        except Exception:
            logger.exception(
                "Failed to upload classification result for doc %s", document_id,
            )

    if pool:
        try:
            await persist_global_event(pool, event)
        except Exception:
            logger.exception(
                "Failed to persist global event for doc %s", document_id,
            )


async def classify_global_event(
    normalized_text: str,
    document_id: str,
    ollama_client: Any,
    *,
    pool: asyncpg.Pool | None = None,
    minio_client: Minio | None = None,
) -> GlobalEvent:
    """Classify a macro news article into a GlobalEvent using Ollama.

    Builds the event-classification prompt and JSON schema, then calls the
    client's ``_call_ollama`` up to ``_max_retries + 1`` times (always at
    least once). Between failed attempts it sleeps with exponential backoff
    capped at ``_max_delay``. An attempt fails when Ollama reports an error,
    returns empty output, or returns output that does not parse into a
    non-empty classification.

    When the respective clients are provided, the prompt, the raw output
    (on success or final failure) and the parsed event are persisted to
    MinIO and PostgreSQL. Persistence errors are logged, never raised.

    Args:
        normalized_text: Cleaned text content of the macro article.
        document_id: UUID of the source document.
        ollama_client: An OllamaClient instance (from services.extractor.client).
        pool: Optional asyncpg pool for PostgreSQL persistence.
        minio_client: Optional MinIO client for artifact persistence.

    Returns:
        A GlobalEvent with the classification result.

    Raises:
        ValueError: If classification fails after all retries.
    """
    ts = datetime.now(timezone.utc)
    prompts = build_event_classification_prompt(normalized_text)
    json_schema = get_event_json_schema()
    # NOTE: relies on OllamaClient private members (_config, _call_ollama,
    # _max_retries, _base_delay, _backoff_multiplier, _max_delay).
    model_name = ollama_client._config.model

    # Persist prompt to MinIO
    if minio_client:
        try:
            _upload_classification_prompt(
                minio_client, document_id, prompts, model_name, timestamp=ts,
            )
        except Exception:
            logger.exception("Failed to upload classification prompt for doc %s", document_id)

    # Guard against a negative setting so at least one attempt is made.
    max_retries = max(0, ollama_client._max_retries)
    total_attempts = max_retries + 1
    last_error: str | None = None
    raw_output = ""

    for attempt_num in range(total_attempts):
        attempt = await ollama_client._call_ollama(prompts, json_schema)
        raw_output = attempt.raw_output

        if attempt.error is None and raw_output:
            try:
                event = _parse_classification_response(
                    raw_output, document_id, model_name,
                )
            except (ValueError, KeyError, TypeError, AttributeError) as exc:
                # ValueError covers json.JSONDecodeError and the parser's
                # empty-classification rejection. AttributeError covers a
                # non-object JSON top level. All of these are retryable.
                last_error = f"parse_error: {exc}"
                logger.warning(
                    "Classification parse error for doc %s attempt %d: %s",
                    document_id, attempt_num + 1, exc,
                )
            else:
                await _persist_classification_success(
                    event, raw_output, document_id, ts,
                    pool=pool, minio_client=minio_client,
                )
                return event
        else:
            last_error = attempt.error or "empty_response"

        # Retry with capped exponential backoff
        if attempt_num < max_retries:
            delay = min(
                ollama_client._base_delay
                * ollama_client._backoff_multiplier ** attempt_num,
                ollama_client._max_delay,
            )
            logger.warning(
                "Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
                attempt_num + 1, total_attempts, document_id, last_error, delay,
            )
            await asyncio.sleep(delay)

    # All retries exhausted — persist failure and raise
    if minio_client:
        try:
            _upload_classification_result(
                minio_client, document_id, raw_output,
                event=None, success=False, error=last_error, timestamp=ts,
            )
        except Exception:
            logger.exception(
                "Failed to upload failed classification result for doc %s", document_id,
            )

    raise ValueError(
        f"Event classification failed for document {document_id} "
        f"after {total_attempts} attempts: {last_error}"
    )
|