3ff910433f
When the LLM returns an empty summary and no key facts, raise ValueError so the retry logic kicks in instead of persisting an empty event. Also strip whitespace from the summary and filter out empty key_facts entries. Cleaned up 17 empty events from the database.
559 lines
20 KiB
Python
559 lines
20 KiB
Python
"""Event classifier module for macro news articles.

Classifies global/geopolitical news articles into structured GlobalEvent
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
existing OllamaClient for inference and retry logic.

Persists classification prompts, raw outputs, and final events to MinIO
and PostgreSQL for audit and downstream interpolation.

Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
import asyncpg
|
|
from minio import Minio
|
|
|
|
from services.shared.schemas import (
|
|
EstimatedDuration,
|
|
ImpactType,
|
|
ModelMetadata,
|
|
SeverityLevel,
|
|
)
|
|
from services.shared.storage import upload_artifact
|
|
|
|
# Module logger; the name is a fixed string rather than derived from __name__.
logger = logging.getLogger("event_classifier")

# Version tags stored with every uploaded prompt and every persisted event.
PROMPT_VERSION = "event-classification-v1"
SCHEMA_VERSION = "1.0.0"

# Allowed enum values, used both to build the JSON schema and to
# normalize whatever the model actually returns.
_VALID_IMPACT_TYPES = frozenset({member.value for member in ImpactType})
_VALID_SEVERITY_LEVELS = frozenset({member.value for member in SeverityLevel})
_VALID_DURATIONS = frozenset({member.value for member in EstimatedDuration})
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GlobalEvent dataclass
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _new_event_id() -> str:
    """Return a freshly generated UUID4 rendered as a string."""
    return str(uuid.uuid4())


@dataclass
class GlobalEvent:
    """Structured classification of a single macro news event.

    Instances are built by the event classifier from Ollama's structured
    output. Every field has a default, so a bare ``GlobalEvent()`` is valid.
    """

    # Unique identifier of the event; a random UUID4 string by default.
    event_id: str = field(default_factory=_new_event_id)
    # Impact type values that apply to the event.
    event_types: list[str] = field(default_factory=list)
    # Severity level value; defaults to "low".
    severity: str = "low"
    # Country codes or region names affected by the event.
    affected_regions: list[str] = field(default_factory=list)
    # Sector identifiers or names affected by the event.
    affected_sectors: list[str] = field(default_factory=list)
    # Commodity identifiers affected by the event (may be empty).
    affected_commodities: list[str] = field(default_factory=list)
    # Short free-text summary of the event.
    summary: str = ""
    # Facts extracted from the article text.
    key_facts: list[str] = field(default_factory=list)
    # Estimated duration value; defaults to "short_term".
    estimated_duration: str = "short_term"
    # Classification confidence; defaults to 0.5.
    confidence: float = 0.5
    # Identifier of the document the event was classified from.
    source_document_id: str = ""
    # Provider / model / prompt / schema provenance for the classification.
    model_metadata: ModelMetadata = field(default_factory=ModelMetadata)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# JSON schema for Ollama structured output
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class _EventClassificationResult:
    """Marker for the shape of the Ollama event classification response.

    Deliberately empty and not a Pydantic model: the JSON schema dict is
    built directly so that it stays self-contained and Ollama-friendly
    (no $refs).
    """
|
|
|
|
|
|
def get_event_json_schema() -> dict[str, Any]:
    """Build the JSON schema used for Ollama structured event classification.

    Every property is listed as required and additional properties are
    rejected, so the model has to fill each field explicitly.

    Returns:
        A new schema dict on every call, with enum lists sorted alphabetically.
    """

    def string_list(description: str, allowed: list[str] | None = None) -> dict[str, Any]:
        # Array-of-strings property, optionally restricted to an enum.
        element: dict[str, Any] = {"type": "string"}
        if allowed is not None:
            element["enum"] = allowed
        return {"type": "array", "items": element, "description": description}

    def choice(allowed: list[str], description: str) -> dict[str, Any]:
        # Single string property restricted to an enum.
        return {"type": "string", "enum": allowed, "description": description}

    properties: dict[str, Any] = {
        "event_types": string_list(
            "One or more impact types this event represents. "
            "Include ALL applicable types — do not collapse to a single category.",
            sorted(_VALID_IMPACT_TYPES),
        ),
        "severity": choice(
            sorted(_VALID_SEVERITY_LEVELS),
            "Overall severity of the event: low, moderate, high, or critical.",
        ),
        "affected_regions": string_list(
            "ISO 3166-1 alpha-2 country codes or region names affected. "
            "Use standard codes like US, CN, EU, GB, JP. "
            "Only include regions explicitly mentioned or clearly implied."
        ),
        "affected_sectors": string_list(
            "GICS sector identifiers or sector names affected. "
            "Examples: Energy, Materials, Industrials, Consumer Discretionary, "
            "Consumer Staples, Health Care, Financials, Information Technology, "
            "Communication Services, Utilities, Real Estate."
        ),
        "affected_commodities": string_list(
            "Commodity identifiers affected, if applicable. "
            "Examples: crude_oil, natural_gas, gold, copper, wheat, lithium, "
            "semiconductors. Empty list if no commodities are directly affected."
        ),
        "summary": {
            "type": "string",
            "description": "A concise 1-3 sentence summary of the event and its market implications.",
        },
        "key_facts": string_list(
            "Key facts explicitly stated in the article. "
            "Do NOT infer, speculate, or fabricate facts. "
            "Each fact must be directly supported by the text."
        ),
        "estimated_duration": choice(
            sorted(_VALID_DURATIONS),
            "Expected duration of market impact: "
            "short_term (days to weeks), medium_term (weeks to months), "
            "long_term (months to years).",
        ),
        "confidence": {
            "type": "number",
            "minimum": 0.0,
            "maximum": 1.0,
            "description": (
                "Your confidence in this classification. "
                "Lower if the article is ambiguous, speculative, or lacks concrete details."
            ),
        },
    }

    return {
        "type": "object",
        # Every property is required; dict insertion order gives the same
        # ordering as the property definitions above.
        "required": list(properties),
        "properties": properties,
        "additionalProperties": False,
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Prompt builder
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# System prompt: one single-line instruction (no embedded newlines).
_SYSTEM_PROMPT = (
    "You classify global news articles into structured macro event intelligence. "
    "Read the article carefully and extract the event classification. "
    "Return ONLY valid JSON matching the schema. No commentary, no markdown, no explanation."
)

# Extraction rules embedded verbatim in the user prompt, one rule per line.
_ANTI_HALLUCINATION_RULES = "\n".join(
    (
        "CRITICAL RULES — read carefully:",
        "1. Only extract information EXPLICITLY stated in the article text.",
        "2. Do NOT infer, speculate, or fabricate facts, regions, sectors, or commodities.",
        "3. If the article mentions multiple distinct impact types, include ALL of them in event_types.",
        "4. For affected_regions, only include regions explicitly mentioned or clearly implied by the event.",
        "5. For affected_sectors, only include sectors with a clear causal link to the event.",
        "6. For affected_commodities, only include commodities directly referenced or obviously impacted.",
        "7. For key_facts, each fact must be directly supported by a specific passage in the text.",
        "8. If the article is vague or speculative, set confidence LOW (below 0.4).",
        "9. Do NOT treat journalist speculation or opinion as confirmed fact.",
        "10. Distinguish between announced policy and proposed/rumored policy.",
    )
)
|
|
|
|
|
|
def build_event_classification_prompt(text: str) -> dict[str, str]:
    """Assemble the system and user prompts for Ollama event classification.

    The user prompt has four sections separated by blank lines: the task
    statement, the anti-hallucination rules, a per-field guide, and the
    article text wrapped in begin/end markers.

    Args:
        text: Normalized text content of the macro news article.

    Returns:
        Dict with 'system' and 'user' prompt strings.
    """
    task_statement = "Classify this global news article as a macro event. Fill every field."

    field_guide = "\n".join(
        (
            "Classify the event by:",
            "- event_types: ALL applicable impact types (supply_disruption, demand_shift, cost_increase, "
            "regulatory_pressure, currency_impact, commodity_shock, trade_barrier, geopolitical_risk)",
            "- severity: low, moderate, high, or critical",
            "- affected_regions: ISO country codes or region names",
            "- affected_sectors: GICS sector names",
            "- affected_commodities: commodity identifiers (empty list if none)",
            "- summary: 1-3 sentence summary of the event and market implications",
            "- key_facts: facts explicitly stated in the text (NO fabrication)",
            "- estimated_duration: short_term, medium_term, or long_term",
            "- confidence: 0.0-1.0 your confidence in this classification",
        )
    )

    # The article is interpolated verbatim between the two marker lines.
    article_section = f"--- ARTICLE TEXT ---\n{text}\n--- END ARTICLE TEXT ---"

    sections = (
        task_statement,
        f"{_ANTI_HALLUCINATION_RULES}",
        field_guide,
        article_section,
    )
    return {"system": _SYSTEM_PROMPT, "user": "\n\n".join(sections)}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Classification response parsing and normalization
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _normalize_event_types(raw: list[Any]) -> list[str]:
    """Keep only entries of *raw* that are valid ImpactType values.

    Each entry is stringified, lower-cased and stripped before the check.
    Order and duplicates are kept as given.

    Returns:
        The recognized values, or ``["geopolitical_risk"]`` when none match.
    """
    candidates = (str(entry).lower().strip() for entry in raw)
    recognized = [name for name in candidates if name in _VALID_IMPACT_TYPES]
    return recognized or ["geopolitical_risk"]
|
|
|
|
|
|
def _normalize_severity(raw: str) -> str:
    """Map *raw* onto a valid SeverityLevel value.

    The input is stringified, lower-cased and stripped; anything that is not
    a known severity falls back to ``"low"``.
    """
    candidate = str(raw).lower().strip()
    if candidate in _VALID_SEVERITY_LEVELS:
        return candidate
    return "low"
|
|
|
|
|
|
def _normalize_duration(raw: str) -> str:
    """Map *raw* onto a valid EstimatedDuration value.

    The input is stringified, lower-cased and stripped; anything that is not
    a known duration falls back to ``"short_term"``.
    """
    candidate = str(raw).lower().strip()
    if candidate in _VALID_DURATIONS:
        return candidate
    return "short_term"
|
|
|
|
|
|
def _clean_string_list(value: Any) -> list[str]:
    """Coerce a JSON value into a list of non-empty, stripped strings.

    ``None`` (JSON null) yields an empty list, a bare string is treated as a
    single-item list instead of being split into characters, and null or
    whitespace-only entries are dropped. Non-iterable values raise TypeError.
    """
    if value is None:
        return []
    if isinstance(value, str):
        value = [value]
    cleaned: list[str] = []
    for item in value:
        if item is None:
            # str(None) would smuggle the literal text "None" into the event.
            continue
        text = str(item).strip()
        if text:
            cleaned.append(text)
    return cleaned


def _clamp_confidence(raw: Any, default: float = 0.5) -> float:
    """Clamp a numeric confidence into [0.0, 1.0]; use *default* otherwise."""
    import math

    if not isinstance(raw, (int, float)):
        return default
    value = float(raw)
    if math.isnan(value):
        # json.loads accepts NaN, and min()/max() would turn it into 1.0.
        return default
    return max(0.0, min(1.0, value))


def _parse_classification_response(
    raw_json: str,
    document_id: str,
    model_name: str,
) -> GlobalEvent:
    """Parse raw Ollama JSON output into a GlobalEvent.

    Normalizes enum values, clamps the confidence into [0.0, 1.0], strips
    whitespace from text fields and drops null/blank list entries.

    Args:
        raw_json: JSON text returned by the model.
        document_id: Identifier of the source document; stored on the event
            and used in error messages.
        model_name: Model name recorded in the event's model metadata.

    Returns:
        A GlobalEvent with a fresh ``event_id``.

    Raises:
        json.JSONDecodeError: If *raw_json* is not valid JSON.
        TypeError: If the payload is not a JSON object, or a list field holds
            a non-iterable value.
        ValueError: If the classification has neither a summary nor any key
            facts (the model produced no useful output).
    """
    data = json.loads(raw_json)
    if not isinstance(data, dict):
        # A list/str/number has no .get(); fail with an exception type the
        # caller's retry logic already handles instead of AttributeError.
        raise TypeError(
            f"Expected a JSON object for document {document_id}, "
            f"got {type(data).__name__}"
        )

    # JSON null must count as "no summary", not as the text "None".
    raw_summary = data.get("summary")
    summary = "" if raw_summary is None else str(raw_summary).strip()
    key_facts = _clean_string_list(data.get("key_facts"))

    # Reject empty classifications — the LLM produced no useful output
    if not summary and not key_facts:
        raise ValueError(
            f"Empty classification for document {document_id}: "
            "no summary and no key facts"
        )

    return GlobalEvent(
        event_id=str(uuid.uuid4()),
        # A null event_types is treated like a missing one (default applies).
        event_types=_normalize_event_types(data.get("event_types") or []),
        severity=_normalize_severity(data.get("severity", "low")),
        affected_regions=_clean_string_list(data.get("affected_regions")),
        affected_sectors=_clean_string_list(data.get("affected_sectors")),
        affected_commodities=_clean_string_list(data.get("affected_commodities")),
        summary=summary,
        key_facts=key_facts,
        estimated_duration=_normalize_duration(data.get("estimated_duration", "short_term")),
        confidence=_clamp_confidence(data.get("confidence", 0.5)),
        source_document_id=document_id,
        model_metadata=ModelMetadata(
            provider="ollama",
            model_name=model_name,
            prompt_version=PROMPT_VERSION,
            schema_version=SCHEMA_VERSION,
        ),
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MinIO persistence helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _upload_classification_prompt(
    minio_client: Minio,
    document_id: str,
    prompt_data: dict[str, str],
    model_name: str,
    timestamp: datetime | None = None,
) -> str:
    """Store the classification prompt and its metadata in stonks-llm-prompts.

    Args:
        minio_client: MinIO client passed through to ``upload_artifact``.
        document_id: Source document identifier; part of the object path.
        prompt_data: Dict holding the 'system' and 'user' prompt strings.
        model_name: Model name recorded in the payload.
        timestamp: Moment used for the dated path; current UTC time if omitted.

    Returns:
        Whatever ``upload_artifact`` returns for the stored object.
    """
    when = timestamp or datetime.now(timezone.utc)

    document = {
        "prompt_version": PROMPT_VERSION,
        "schema_version": SCHEMA_VERSION,
        "model": model_name,
        "system_prompt": prompt_data["system"],
        "user_prompt": prompt_data["user"],
        "json_schema": get_event_json_schema(),
    }
    body = json.dumps(document, indent=2).encode()

    # Objects are grouped by UTC-dated folder, then by document.
    dated_folder = f"{when.year}/{when.month:02d}/{when.day:02d}"
    object_path = f"event_classification/macro/{dated_folder}/{document_id}/prompt.json"

    return upload_artifact(
        minio_client,
        "stonks-llm-prompts",
        object_path,
        body,
        content_type="application/json",
    )
|
|
|
|
|
|
def _upload_classification_result(
    minio_client: Minio,
    document_id: str,
    raw_output: str,
    event: GlobalEvent | None,
    success: bool,
    error: str | None,
    timestamp: datetime | None = None,
) -> str:
    """Store the raw classification output in stonks-llm-results.

    Args:
        minio_client: MinIO client passed through to ``upload_artifact``.
        document_id: Source document identifier; part of the object path.
        raw_output: Raw model output, stored verbatim.
        event: Parsed event, or None when classification failed.
        success: Whether classification succeeded.
        error: Error description for failures, otherwise None.
        timestamp: Moment used for the dated path; current UTC time if omitted.

    Returns:
        Whatever ``upload_artifact`` returns for the stored object.
    """
    when = timestamp or datetime.now(timezone.utc)

    # Only the classification fields are stored; source id and model
    # metadata are not part of this artifact.
    if event:
        parsed_event = {
            "event_id": event.event_id,
            "event_types": event.event_types,
            "severity": event.severity,
            "affected_regions": event.affected_regions,
            "affected_sectors": event.affected_sectors,
            "affected_commodities": event.affected_commodities,
            "summary": event.summary,
            "key_facts": event.key_facts,
            "estimated_duration": event.estimated_duration,
            "confidence": event.confidence,
        }
    else:
        parsed_event = None

    record = {
        "document_id": document_id,
        "success": success,
        "error": error,
        "raw_output": raw_output,
        "parsed_event": parsed_event,
    }
    body = json.dumps(record, indent=2).encode()

    # Objects are grouped by UTC-dated folder, then by document.
    dated_folder = f"{when.year}/{when.month:02d}/{when.day:02d}"
    object_path = f"event_classification/macro/{dated_folder}/{document_id}/result.json"

    return upload_artifact(
        minio_client,
        "stonks-llm-results",
        object_path,
        body,
        content_type="application/json",
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PostgreSQL persistence
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def persist_global_event(
    pool: asyncpg.Pool,
    event: GlobalEvent,
) -> str:
    """Insert *event* into the global_events PostgreSQL table.

    Args:
        pool: asyncpg pool used to run the INSERT.
        event: The event to store; ``key_facts`` is sent as a JSON string
            and cast to jsonb by the statement.

    Returns:
        The id returned by the INSERT, as a string.
    """
    insert_sql = """INSERT INTO global_events
           (id, event_types, severity, affected_regions, affected_sectors,
            affected_commodities, summary, key_facts, estimated_duration,
            confidence, source_document_id, model_provider, model_name,
            prompt_version, schema_version)
           VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10,
                   $11::uuid, $12, $13, $14, $15)
           RETURNING id"""

    provenance = event.model_metadata
    # Positional parameters, in the same order as $1..$15 above.
    params = (
        event.event_id,
        event.event_types,
        event.severity,
        event.affected_regions,
        event.affected_sectors,
        event.affected_commodities,
        event.summary,
        json.dumps(event.key_facts),
        event.estimated_duration,
        event.confidence,
        event.source_document_id,
        provenance.provider,
        provenance.model_name,
        provenance.prompt_version,
        provenance.schema_version,
    )

    inserted_id = await pool.fetchval(insert_sql, *params)
    logger.info(
        "Persisted global event %s for doc %s (severity=%s, types=%s)",
        inserted_id, event.source_document_id, event.severity, event.event_types,
    )
    return str(inserted_id)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main classification function
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def classify_global_event(
    normalized_text: str,
    document_id: str,
    ollama_client: Any,
    *,
    pool: asyncpg.Pool | None = None,
    minio_client: Minio | None = None,
) -> GlobalEvent:
    """Classify a macro news article into a GlobalEvent using Ollama.

    Builds the event classification prompt and JSON schema, calls
    ``ollama_client._call_ollama`` and parses the reply. A failed call, an
    empty reply, or a reply that cannot be turned into a usable event
    (malformed JSON, wrong payload shape, or an empty classification) counts
    as a failed attempt and is retried with exponential backoff, using the
    client's ``_max_retries``, ``_base_delay``, ``_backoff_multiplier`` and
    ``_max_delay`` settings.

    Persistence is best-effort: the prompt, the raw output and the final
    event are written to MinIO and PostgreSQL when the respective clients
    are provided, and a persistence failure is logged rather than raised.

    Args:
        normalized_text: Cleaned text content of the macro article.
        document_id: UUID of the source document.
        ollama_client: An OllamaClient instance (from services.extractor.client).
        pool: Optional asyncpg pool for PostgreSQL persistence.
        minio_client: Optional MinIO client for artifact persistence.

    Returns:
        A GlobalEvent with the classification result.

    Raises:
        ValueError: If classification fails after all retries.
    """
    ts = datetime.now(timezone.utc)
    prompts = build_event_classification_prompt(normalized_text)
    json_schema = get_event_json_schema()
    model_name = ollama_client._config.model

    # Persist prompt to MinIO
    if minio_client:
        try:
            _upload_classification_prompt(
                minio_client, document_id, prompts, model_name, timestamp=ts,
            )
        except Exception:
            logger.exception("Failed to upload classification prompt for doc %s", document_id)

    # Call Ollama using the client's internal _call_ollama method
    # We reuse the retry logic pattern from OllamaClient.extract()
    max_retries = ollama_client._max_retries
    last_error: str | None = None
    raw_output = ""

    for attempt_num in range(max_retries + 1):
        attempt = await ollama_client._call_ollama(prompts, json_schema)
        raw_output = attempt.raw_output
        event: GlobalEvent | None = None

        if attempt.error is None and raw_output:
            # Only parsing is guarded here. ValueError covers both
            # json.JSONDecodeError (a subclass) and the parser's
            # "empty classification" rejection, so an empty answer is
            # retried instead of escaping the loop on the first attempt.
            # AttributeError covers payloads that are not JSON objects.
            try:
                event = _parse_classification_response(
                    raw_output, document_id, model_name,
                )
            except (ValueError, KeyError, TypeError, AttributeError) as exc:
                last_error = f"parse_error: {exc}"
                logger.warning(
                    "Classification parse error for doc %s attempt %d: %s",
                    document_id, attempt_num + 1, exc,
                )
        else:
            last_error = attempt.error or "empty_response"

        if event is not None:
            # Persist result to MinIO
            if minio_client:
                try:
                    _upload_classification_result(
                        minio_client, document_id, raw_output,
                        event, success=True, error=None, timestamp=ts,
                    )
                except Exception:
                    logger.exception(
                        "Failed to upload classification result for doc %s", document_id,
                    )

            # Persist to PostgreSQL
            if pool:
                try:
                    await persist_global_event(pool, event)
                except Exception:
                    logger.exception(
                        "Failed to persist global event for doc %s", document_id,
                    )

            return event

        # Retry with backoff
        if attempt_num < max_retries:
            delay = ollama_client._base_delay * (
                ollama_client._backoff_multiplier ** attempt_num
            )
            delay = min(delay, ollama_client._max_delay)
            logger.warning(
                "Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
                attempt_num + 1, max_retries + 1, document_id, last_error, delay,
            )
            await asyncio.sleep(delay)

    # All retries exhausted — persist failure and raise
    if minio_client:
        try:
            _upload_classification_result(
                minio_client, document_id, raw_output,
                event=None, success=False, error=last_error, timestamp=ts,
            )
        except Exception:
            logger.exception(
                "Failed to upload failed classification result for doc %s", document_id,
            )

    raise ValueError(
        f"Event classification failed for document {document_id} "
        f"after {max_retries + 1} attempts: {last_error}"
    )
|