Files
stonks-oracle/services/extractor/event_classifier.py
T

705 lines
26 KiB
Python

"""Event classifier module for macro news articles.
Classifies global/geopolitical news articles into structured GlobalEvent
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
existing OllamaClient for inference and retry logic.
Persists classification prompts, raw outputs, and final events to MinIO
and PostgreSQL for audit and downstream interpolation.
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
import asyncpg
from minio import Minio
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
from services.shared.schemas import (
EstimatedDuration,
ImpactType,
ModelMetadata,
SeverityLevel,
)
from services.shared.storage import _prefixed, upload_artifact
# Module logger; the name is fixed to "event_classifier" rather than __name__.
logger = logging.getLogger("event_classifier")
# Version tags stamped into every GlobalEvent's ModelMetadata and into the
# MinIO prompt artifact, so stored outputs can be traced back to the exact
# prompt/schema revision that produced them.
PROMPT_VERSION = "event-classification-v1"
SCHEMA_VERSION = "1.0.0"
# Valid enum value sets for normalization. The same sets (sorted) are used as
# the "enum" constraints in get_event_json_schema().
_VALID_IMPACT_TYPES = frozenset(e.value for e in ImpactType)
_VALID_SEVERITY_LEVELS = frozenset(e.value for e in SeverityLevel)
_VALID_DURATIONS = frozenset(e.value for e in EstimatedDuration)
# ---------------------------------------------------------------------------
# GlobalEvent dataclass
# ---------------------------------------------------------------------------
@dataclass
class GlobalEvent:
    """Structured classification of a macro news event.

    Produced by the event classifier from Ollama structured output. Enum-like
    string fields hold plain string values (normalized against the shared
    ImpactType / SeverityLevel / EstimatedDuration enums by the parser), so
    the object can be written directly to JSON and PostgreSQL.
    """

    # UUID4 string; written to the ``id`` column of global_events on persist.
    event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # ImpactType values, e.g. "trade_barrier", "geopolitical_risk".
    event_types: list[str] = field(default_factory=list)
    # SeverityLevel value: low, moderate, high, or critical.
    severity: str = "low"
    # Country codes or region names as emitted by the model (not validated).
    affected_regions: list[str] = field(default_factory=list)
    # Sector names as emitted by the model (GICS names are requested, not enforced).
    affected_sectors: list[str] = field(default_factory=list)
    # Commodity identifiers such as "crude_oil"; empty when none apply.
    affected_commodities: list[str] = field(default_factory=list)
    # Short free-text summary of the event and its market implications.
    summary: str = ""
    # Facts extracted from the article text; stored as JSONB on persist.
    key_facts: list[str] = field(default_factory=list)
    # EstimatedDuration value: short_term, medium_term, or long_term.
    estimated_duration: str = "short_term"
    # Model-reported confidence; the parser clamps it to [0.0, 1.0].
    confidence: float = 0.5
    # UUID string of the source document (cast to ::uuid on insert).
    source_document_id: str = ""
    # Provider / model / prompt and schema versions that produced this event.
    model_metadata: ModelMetadata = field(default_factory=ModelMetadata)
# ---------------------------------------------------------------------------
# JSON schema for Ollama structured output
# ---------------------------------------------------------------------------
class _EventClassificationResult:
    """Marker for the shape of Ollama's event-classification response.

    Deliberately empty. The real schema is returned as a plain dict by
    ``get_event_json_schema`` instead of being generated from a Pydantic
    model, which keeps it self-contained and free of ``$ref`` indirections
    that Ollama handles poorly.
    """
def get_event_json_schema() -> dict[str, Any]:
    """Build the JSON schema that constrains Ollama's event-classification output.

    Every property is listed under ``required`` so the model has to emit each
    field explicitly, and ``additionalProperties`` is disabled. Enum-backed
    fields take their allowed values, sorted for a stable schema, from the
    module-level ``_VALID_*`` sets.
    """

    def string_array(description: str, allowed: list[str] | None = None) -> dict[str, Any]:
        # Array-of-strings property, optionally restricted to an enum.
        item_schema: dict[str, Any] = {"type": "string"}
        if allowed is not None:
            item_schema["enum"] = allowed
        return {"type": "array", "items": item_schema, "description": description}

    def enum_string(allowed: frozenset[str], description: str) -> dict[str, Any]:
        # Single string property restricted to a sorted enum.
        return {"type": "string", "enum": sorted(allowed), "description": description}

    properties: dict[str, Any] = {
        "event_types": string_array(
            "One or more impact types this event represents. "
            "Include ALL applicable types — do not collapse to a single category.",
            sorted(_VALID_IMPACT_TYPES),
        ),
        "severity": enum_string(
            _VALID_SEVERITY_LEVELS,
            "Overall severity of the event: low, moderate, high, or critical.",
        ),
        "affected_regions": string_array(
            "ISO 3166-1 alpha-2 country codes or region names affected. "
            "Use standard codes like US, CN, EU, GB, JP. "
            "Only include regions explicitly mentioned or clearly implied."
        ),
        "affected_sectors": string_array(
            "GICS sector identifiers or sector names affected. "
            "Examples: Energy, Materials, Industrials, Consumer Discretionary, "
            "Consumer Staples, Health Care, Financials, Information Technology, "
            "Communication Services, Utilities, Real Estate."
        ),
        "affected_commodities": string_array(
            "Commodity identifiers affected, if applicable. "
            "Examples: crude_oil, natural_gas, gold, copper, wheat, lithium, "
            "semiconductors. Empty list if no commodities are directly affected."
        ),
        "summary": {
            "type": "string",
            "description": "A concise 1-3 sentence summary of the event and its market implications.",
        },
        "key_facts": string_array(
            "Key facts explicitly stated in the article. "
            "Do NOT infer, speculate, or fabricate facts. "
            "Each fact must be directly supported by the text."
        ),
        "estimated_duration": enum_string(
            _VALID_DURATIONS,
            "Expected duration of market impact: "
            "short_term (days to weeks), medium_term (weeks to months), "
            "long_term (months to years).",
        ),
        "confidence": {
            "type": "number",
            "minimum": 0.0,
            "maximum": 1.0,
            "description": (
                "Your confidence in this classification. "
                "Lower if the article is ambiguous, speculative, or lacks concrete details."
            ),
        },
    }

    # Every declared property is mandatory, in declaration order.
    return {
        "type": "object",
        "required": list(properties),
        "properties": properties,
        "additionalProperties": False,
    }
# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
# Default system prompt for event classification. classify_global_event
# replaces it with a resolved agent variant's system_prompt when one is set.
# Trailing backslashes inside the string are line continuations, so each
# paragraph reaches the model as a single line.
_SYSTEM_PROMPT = """\
You classify MACRO-LEVEL global news into structured event JSON. \
Return ONLY a single JSON object. No markdown, no explanation. \
Every field is required. Keep key_facts to 3-5 items. Keep summary under 3 sentences.
CRITICAL: Only classify articles about MACRO events that affect entire markets, \
sectors, or economies. Examples: trade wars, interest rate changes, commodity \
supply disruptions, regulatory changes, geopolitical conflicts, natural disasters.
DO NOT classify as macro events: individual company earnings, lawsuits against \
a single company, single-company management changes, individual stock analysis, \
company-specific debt or bankruptcy, product launches by one company. \
For these, set severity to "low", confidence below 0.3, and leave \
affected_regions, affected_sectors, and affected_commodities as empty arrays."""
# Rules embedded verbatim at the top of every user prompt by
# build_event_classification_prompt. They are kept in the user prompt rather
# than the system prompt, so they still apply when a variant overrides the
# system prompt.
_ANTI_HALLUCINATION_RULES = """\
RULES:
- Only extract facts EXPLICITLY stated in the text. Do NOT fabricate.
- If vague or speculative, set confidence below 0.4.
- Distinguish announced policy from rumored policy.
- If the article is about a SINGLE COMPANY (not a sector or market), set severity to "low" and confidence below 0.3.
- Only tag event_types that are DIRECTLY described in the article. Do NOT infer secondary effects.
- severity "critical" is reserved for events affecting multiple countries or entire global markets."""
def build_event_classification_prompt(text: str, *, max_chars: int = 6000) -> dict[str, str]:
    """Build system and user prompts for Ollama event classification.

    Args:
        text: Normalized text content of the macro news article.
        max_chars: Maximum number of article characters embedded in the user
            prompt. Longer text is cut and a ``[... truncated ...]`` marker is
            appended. A value <= 0 disables truncation. Defaults to 6000, the
            previously hard-coded limit, so existing callers are unaffected.

    Returns:
        Dict with 'system' and 'user' prompt strings.
    """
    # Truncate long articles to reduce inference time.
    if max_chars > 0 and len(text) > max_chars:
        text = text[:max_chars] + "\n[... truncated ...]"
    user_prompt = f"""\
Classify this global news article as a macro event. Fill every field.
{_ANTI_HALLUCINATION_RULES}
Classify the event by:
- event_types: ALL applicable impact types (supply_disruption, demand_shift, cost_increase, \
regulatory_pressure, currency_impact, commodity_shock, trade_barrier, geopolitical_risk)
- severity: low, moderate, high, or critical
- affected_regions: ISO country codes or region names
- affected_sectors: GICS sector names
- affected_commodities: commodity identifiers (empty list if none)
- summary: 1-3 sentence summary of the event and market implications
- key_facts: facts explicitly stated in the text (NO fabrication)
- estimated_duration: short_term, medium_term, or long_term
- confidence: 0.0-1.0 your confidence in this classification
--- ARTICLE TEXT ---
{text}
--- END ARTICLE TEXT ---"""
    return {
        "system": _SYSTEM_PROMPT,
        "user": user_prompt,
    }
# ---------------------------------------------------------------------------
# Classification response parsing and normalization
# ---------------------------------------------------------------------------
def _normalize_event_types(raw: Any) -> list[str]:
    """Normalize model-emitted event types to valid ImpactType values.

    - ``None`` is treated as an empty list. A bare scalar (typically a single
      string where an array was expected) is wrapped in a list, because
      iterating a string directly would test each character and drop them all.
    - Each value is lower-cased and stripped. Runs of spaces or hyphens become
      a single underscore, so ``"Trade Barrier"`` and ``"trade-barrier"`` both
      map to ``"trade_barrier"``.
    - Unknown values are dropped. Duplicates are removed, keeping first-seen
      order.
    - Returns ``["geopolitical_risk"]`` when nothing valid remains.
    """
    if raw is None:
        items: list[Any] = []
    elif isinstance(raw, (list, tuple)):
        items = list(raw)
    else:
        items = [raw]

    # dict preserves insertion order and gives O(1) de-duplication.
    kept: dict[str, None] = {}
    for item in items:
        val = "_".join(str(item).lower().replace("-", " ").split())
        if val in _VALID_IMPACT_TYPES:
            kept.setdefault(val, None)
    return list(kept) or ["geopolitical_risk"]
def _normalize_severity(raw: str) -> str:
    """Map a raw severity value onto a SeverityLevel value, falling back to "low"."""
    candidate = str(raw).lower().strip()
    if candidate not in _VALID_SEVERITY_LEVELS:
        return "low"
    return candidate
def _normalize_duration(raw: str) -> str:
    """Normalize estimated_duration to a valid EstimatedDuration value.

    The value is lower-cased and stripped. Runs of spaces or hyphens collapse
    to a single underscore, so common variants like ``"medium-term"`` or
    ``"Long Term"`` resolve to ``medium_term`` / ``long_term`` instead of
    silently falling back. Unrecognized values map to ``"short_term"``.
    """
    val = "_".join(str(raw).lower().replace("-", " ").split())
    return val if val in _VALID_DURATIONS else "short_term"
def _coerce_confidence(value: Any, default: float = 0.5) -> float:
    """Convert a model-emitted confidence value to a float in [0.0, 1.0].

    Ints, floats and numeric strings are accepted and clamped. Booleans, NaN
    and anything else fall back to *default*. NaN needs an explicit check:
    ``json.loads`` accepts a bare ``NaN`` token, and
    ``max(0.0, min(1.0, nan))`` evaluates to 1.0, so an unchecked NaN would be
    stored as full confidence.
    """
    import math

    if isinstance(value, bool):
        # bool is an int subclass; true/false is not a meaningful confidence.
        return default
    if isinstance(value, str):
        try:
            value = float(value.strip())
        except ValueError:
            return default
    if not isinstance(value, (int, float)):
        return default
    try:
        number = float(value)
    except OverflowError:
        # Integer too large for a float: clamp by sign like any out-of-range value.
        return 1.0 if value > 0 else 0.0
    if math.isnan(number):
        return default
    return max(0.0, min(1.0, number))


def _coerce_str_list(value: Any) -> list[str]:
    """Coerce a model-emitted list field into non-empty, stripped strings.

    ``None`` (JSON null) becomes ``[]``. A bare scalar, most often a single
    string where an array was expected, is wrapped in a list. Iterating a
    string directly would yield one entry per character. ``None`` items and
    blank entries are dropped.
    """
    if value is None:
        return []
    items = value if isinstance(value, (list, tuple)) else [value]
    stripped = (str(item).strip() for item in items if item is not None)
    return [text for text in stripped if text]


def _parse_classification_response(
    raw_json: str,
    document_id: str,
    model_name: str,
) -> GlobalEvent:
    """Parse raw Ollama JSON output into a GlobalEvent.

    Strips markdown fences and repairs malformed JSON before parsing. Unwraps
    a single-object list, normalizes enum values, coerces list fields and
    clamps confidence.

    Raises:
        json.JSONDecodeError: If the cleaned output is still not valid JSON.
        ValueError: If the output is not a JSON object (or a one-object list),
            or if it has neither a summary nor any key facts.
    """
    from services.extractor.client import _repair_json, _strip_markdown_fences

    cleaned = _repair_json(_strip_markdown_fences(raw_json))
    # Raw vs cleaned output helps diagnose malformed responses (e.g. the model
    # wrapping the object in a list). This runs on every response, so it is
    # emitted at DEBUG, and the reprs are only built when DEBUG is enabled.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "Classification parse debug doc=%s raw_len=%d cleaned_len=%d raw_start=%s cleaned_start=%s",
            document_id, len(raw_json), len(cleaned),
            repr(raw_json[:300]), repr(cleaned[:300]),
        )
    data = json.loads(cleaned)
    # Model sometimes wraps the object in a single-element list — unwrap it.
    if isinstance(data, list):
        if len(data) == 1 and isinstance(data[0], dict):
            data = data[0]
        elif not data:
            raise ValueError(
                f"Empty list from model for document {document_id}. "
                f"Raw output ({len(raw_json)} chars): {raw_json[:500]}"
            )
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a JSON object, got {type(data).__name__} for document {document_id}. "
            f"Raw output ({len(raw_json)} chars): {raw_json[:500]}"
        )

    # A JSON null summary must count as empty, not as the string "None".
    raw_summary = data.get("summary")
    summary = "" if raw_summary is None else str(raw_summary).strip()
    key_facts = _coerce_str_list(data.get("key_facts"))
    # Reject empty classifications — the LLM produced no useful output.
    if not summary and not key_facts:
        raise ValueError(
            f"Empty classification for document {document_id}: "
            "no summary and no key facts"
        )

    return GlobalEvent(
        event_id=str(uuid.uuid4()),
        event_types=_normalize_event_types(_coerce_str_list(data.get("event_types"))),
        severity=_normalize_severity(data.get("severity", "low")),
        affected_regions=_coerce_str_list(data.get("affected_regions")),
        affected_sectors=_coerce_str_list(data.get("affected_sectors")),
        affected_commodities=_coerce_str_list(data.get("affected_commodities")),
        summary=summary,
        key_facts=key_facts,
        estimated_duration=_normalize_duration(data.get("estimated_duration", "short_term")),
        confidence=_coerce_confidence(data.get("confidence", 0.5)),
        source_document_id=document_id,
        model_metadata=ModelMetadata(
            provider="ollama",
            model_name=model_name,
            prompt_version=PROMPT_VERSION,
            schema_version=SCHEMA_VERSION,
        ),
    )
# ---------------------------------------------------------------------------
# MinIO persistence helpers
# ---------------------------------------------------------------------------
def _upload_classification_prompt(
    minio_client: Minio,
    document_id: str,
    prompt_data: dict[str, str],
    model_name: str,
    timestamp: datetime | None = None,
) -> str:
    """Store the classification prompt, schema and versions in stonks-llm-prompts.

    The object key is
    ``event_classification/macro/<year>/<MM>/<DD>/<document_id>/prompt.json``,
    dated by *timestamp* (current UTC time if omitted).

    Returns:
        Whatever ``upload_artifact`` returns for the stored object.
    """
    when = timestamp or datetime.now(timezone.utc)
    record = {
        "prompt_version": PROMPT_VERSION,
        "schema_version": SCHEMA_VERSION,
        "model": model_name,
        "system_prompt": prompt_data["system"],
        "user_prompt": prompt_data["user"],
        "json_schema": get_event_json_schema(),
    }
    body = json.dumps(record, indent=2).encode()
    date_dir = f"{when.year}/{when.month:02d}/{when.day:02d}"
    object_path = f"event_classification/macro/{date_dir}/{document_id}/prompt.json"
    return upload_artifact(
        minio_client,
        _prefixed("stonks-llm-prompts"),
        object_path,
        body,
        content_type="application/json",
    )
def _upload_classification_result(
    minio_client: Minio,
    document_id: str,
    raw_output: str,
    event: GlobalEvent | None,
    success: bool,
    error: str | None,
    timestamp: datetime | None = None,
) -> str:
    """Store the raw model output and the parsed event (if any) in stonks-llm-results.

    Uses the same ``event_classification/macro/<year>/<MM>/<DD>/<document_id>/``
    layout as the prompt artifact, with ``result.json`` as the file name.
    ``parsed_event`` is null when no event was produced.

    Returns:
        Whatever ``upload_artifact`` returns for the stored object.
    """
    when = timestamp or datetime.now(timezone.utc)

    parsed_event: dict[str, Any] | None = None
    if event:
        exported_fields = (
            "event_id",
            "event_types",
            "severity",
            "affected_regions",
            "affected_sectors",
            "affected_commodities",
            "summary",
            "key_facts",
            "estimated_duration",
            "confidence",
        )
        parsed_event = {name: getattr(event, name) for name in exported_fields}

    record = {
        "document_id": document_id,
        "success": success,
        "error": error,
        "raw_output": raw_output,
        "parsed_event": parsed_event,
    }
    date_dir = f"{when.year}/{when.month:02d}/{when.day:02d}"
    object_path = f"event_classification/macro/{date_dir}/{document_id}/result.json"
    return upload_artifact(
        minio_client,
        _prefixed("stonks-llm-results"),
        object_path,
        json.dumps(record, indent=2).encode(),
        content_type="application/json",
    )
# ---------------------------------------------------------------------------
# PostgreSQL persistence
# ---------------------------------------------------------------------------
# Parameterized insert for one GlobalEvent row. key_facts is bound as a JSON
# string ($8::jsonb); the event and source-document ids are cast to uuid.
_INSERT_GLOBAL_EVENT_SQL = """INSERT INTO global_events
(id, event_types, severity, affected_regions, affected_sectors,
affected_commodities, summary, key_facts, estimated_duration,
confidence, source_document_id, model_provider, model_name,
prompt_version, schema_version)
VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10,
$11::uuid, $12, $13, $14, $15)
RETURNING id"""


async def persist_global_event(
    pool: asyncpg.Pool,
    event: GlobalEvent,
) -> str:
    """Insert *event* into the global_events table.

    Returns:
        The inserted row's ``id`` (as returned by ``RETURNING id``), as a string.
    """
    meta = event.model_metadata
    params = (
        event.event_id,
        event.event_types,
        event.severity,
        event.affected_regions,
        event.affected_sectors,
        event.affected_commodities,
        event.summary,
        json.dumps(event.key_facts),
        event.estimated_duration,
        event.confidence,
        event.source_document_id,
        meta.provider,
        meta.model_name,
        meta.prompt_version,
        meta.schema_version,
    )
    row_id = await pool.fetchval(_INSERT_GLOBAL_EVENT_SQL, *params)
    logger.info(
        "Persisted global event %s for doc %s (severity=%s, types=%s)",
        row_id, event.source_document_id, event.severity, event.event_types,
    )
    return str(row_id)
# ---------------------------------------------------------------------------
# Main classification function
# ---------------------------------------------------------------------------
# Appended to a system prompt that never mentions JSON (e.g. a custom variant
# prompt), so the model does not drift into YAML / bullet-point output.
_JSON_ONLY_INSTRUCTION = (
    "\n\nReturn ONLY a single JSON object. "
    "No markdown, no explanation."
)

# Parameterized insert into agent_performance_log, shared by the success and
# failure paths of classify_global_event.
_INSERT_PERFORMANCE_LOG_SQL = """INSERT INTO agent_performance_log
(agent_id, variant_id, document_id, ticker, success,
duration_ms, confidence, retry_count,
input_tokens, output_tokens, error_message)
VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5,
$6, $7, $8, $9, $10, $11)"""


async def _resolve_event_classifier_config(
    pool: asyncpg.Pool | None,
) -> ResolvedAgentConfig | None:
    """Resolve the "event-classifier" agent config; None if no pool or on any error."""
    if pool is None:
        return None
    try:
        # NOTE(review): a new resolver is built per call. If AgentConfigResolver
        # caches per instance (ttl_seconds suggests so), that cache is never
        # reused across calls — consider sharing one resolver. Confirm.
        resolver = AgentConfigResolver(pool, ttl_seconds=60)
        return await resolver.resolve("event-classifier")
    except Exception:
        logger.warning(
            "Failed to resolve event-classifier config — using defaults",
            exc_info=True,
        )
        return None


def _build_classifier_prompts(
    normalized_text: str,
    resolved: ResolvedAgentConfig | None,
) -> tuple[dict[str, str], str]:
    """Build prompts with variant overrides applied; also return the (possibly truncated) text.

    The variant's input_token_limit is converted to a character budget at
    roughly 4 chars/token. The JSON-only instruction is enforced after every
    override. Previously, a token-limit rebuild restored the raw variant
    prompt and silently dropped the suffix.
    """
    text = normalized_text
    if resolved is not None and resolved.input_token_limit > 0:
        max_chars = resolved.input_token_limit * 4
        if len(text) > max_chars:
            text = text[:max_chars]
    prompts = build_event_classification_prompt(text)
    if resolved is not None and resolved.system_prompt:
        prompts["system"] = resolved.system_prompt
    # The default system prompt mentions JSON, so this only fires for custom prompts.
    if "json" not in prompts["system"].lower():
        prompts["system"] += _JSON_ONLY_INSTRUCTION
    return prompts, text


async def _log_classifier_performance(
    pool: asyncpg.Pool,
    resolved: ResolvedAgentConfig,
    document_id: str,
    *,
    success: bool,
    duration_ms: int,
    confidence: float,
    retry_count: int,
    input_tokens: int,
    output_tokens: int,
    error_message: str | None,
) -> None:
    """Best-effort row in agent_performance_log with variant attribution; errors are logged, not raised."""
    try:
        await pool.execute(
            _INSERT_PERFORMANCE_LOG_SQL,
            resolved.agent_id,
            resolved.variant_id,
            document_id,
            "",
            success,
            duration_ms,
            confidence,
            retry_count,
            input_tokens,
            output_tokens,
            error_message,
        )
    except Exception:
        logger.warning(
            "Failed to log event-classifier performance for doc %s"
            if success
            else "Failed to log event-classifier failure performance for doc %s",
            document_id, exc_info=True,
        )


async def classify_global_event(
    normalized_text: str,
    document_id: str,
    ollama_client: Any,
    *,
    pool: asyncpg.Pool | None = None,
    minio_client: Minio | None = None,
) -> GlobalEvent:
    """Classify a macro news article into a GlobalEvent using Ollama.

    Uses the existing OllamaClient's ``_call_ollama`` with a dedicated event
    classification prompt and JSON schema. Retries with the client's
    exponential-backoff settings (base delay * multiplier**attempt, capped at
    max delay).

    Resolves runtime config for the "event-classifier" agent slug from the
    database when *pool* is given. A resolved config overrides model_name,
    system_prompt, max_retries and input_token_limit. If resolution fails,
    the OllamaClient's own settings are used.

    Persists prompt, raw output, and final event to MinIO and PostgreSQL
    when the respective clients are provided. Persistence failures are
    logged and do not fail the classification.

    Args:
        normalized_text: Cleaned text content of the macro article.
        document_id: UUID of the source document.
        ollama_client: An OllamaClient instance (from services.extractor.client).
        pool: Optional asyncpg pool for PostgreSQL persistence.
        minio_client: Optional MinIO client for artifact persistence.

    Returns:
        A GlobalEvent with the classification result.

    Raises:
        ValueError: If classification fails after all retries.
    """
    ts = datetime.now(timezone.utc)
    start_time = time.monotonic()

    resolved = await _resolve_event_classifier_config(pool)
    prompts, prompt_text = _build_classifier_prompts(normalized_text, resolved)
    json_schema = get_event_json_schema()
    # Rough token estimate (≈4 chars/token) of the article text actually sent.
    input_tokens = len(prompt_text) // 4

    # NOTE(review): the variant's model_name is recorded in metadata and
    # artifacts, but it is not passed to _call_ollama below. Whether inference
    # actually uses the variant model depends on the client — confirm.
    model_name = ollama_client._config.model
    if resolved is not None:
        model_name = resolved.model_name

    if minio_client:
        try:
            _upload_classification_prompt(
                minio_client, document_id, prompts, model_name, timestamp=ts,
            )
        except Exception:
            logger.exception("Failed to upload classification prompt for doc %s", document_id)

    max_retries = ollama_client._max_retries if resolved is None else resolved.max_retries
    # A negative configured value would otherwise skip every attempt and
    # raise with no recorded error.
    max_retries = max(0, max_retries)

    last_error: str | None = None
    raw_output = ""
    for attempt_num in range(max_retries + 1):
        attempt = await ollama_client._call_ollama(prompts, json_schema)
        raw_output = attempt.raw_output
        # _call_ollama validates against the *extraction* schema, which does
        # not match event classification output. That validation error is
        # ignored; we parse ourselves whenever there is any output.
        if not raw_output:
            last_error = attempt.error or "empty_response"
        else:
            try:
                event = _parse_classification_response(raw_output, document_id, model_name)
            except (json.JSONDecodeError, KeyError, TypeError, ValueError) as exc:
                last_error = f"parse_error: {exc}"
                logger.warning(
                    "Classification parse error for doc %s attempt %d: %s",
                    document_id, attempt_num + 1, exc,
                )
            else:
                if minio_client:
                    try:
                        _upload_classification_result(
                            minio_client, document_id, raw_output,
                            event, success=True, error=None, timestamp=ts,
                        )
                    except Exception:
                        logger.exception(
                            "Failed to upload classification result for doc %s", document_id,
                        )
                if pool:
                    try:
                        await persist_global_event(pool, event)
                    except Exception:
                        logger.exception(
                            "Failed to persist global event for doc %s", document_id,
                        )
                if pool is not None and resolved is not None:
                    await _log_classifier_performance(
                        pool, resolved, document_id,
                        success=True,
                        duration_ms=int((time.monotonic() - start_time) * 1000),
                        confidence=event.confidence,
                        # Zero-based attempt index == retries performed.
                        retry_count=attempt_num,
                        input_tokens=input_tokens,
                        output_tokens=len(raw_output) // 4,
                        error_message=None,
                    )
                return event

        # Retry with capped exponential backoff.
        if attempt_num < max_retries:
            delay = min(
                ollama_client._base_delay * (ollama_client._backoff_multiplier ** attempt_num),
                ollama_client._max_delay,
            )
            logger.warning(
                "Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
                attempt_num + 1, max_retries + 1, document_id, last_error, delay,
            )
            await asyncio.sleep(delay)

    # All retries exhausted — log failure performance and persist the raw output.
    if pool is not None and resolved is not None:
        await _log_classifier_performance(
            pool, resolved, document_id,
            success=False,
            duration_ms=int((time.monotonic() - start_time) * 1000),
            confidence=0.0,
            # Same meaning as on success (retries performed): max_retries + 1
            # attempts means max_retries retries. Previously this was logged as
            # max_retries + 1, which was inconsistent with the success path.
            retry_count=max_retries,
            input_tokens=input_tokens,
            output_tokens=0,
            error_message=last_error,
        )
    if minio_client:
        try:
            _upload_classification_result(
                minio_client, document_id, raw_output,
                event=None, success=False, error=last_error, timestamp=ts,
            )
        except Exception:
            logger.exception(
                "Failed to upload failed classification result for doc %s", document_id,
            )
    raise ValueError(
        f"Event classification failed for document {document_id} "
        f"after {max_retries + 1} attempts: {last_error}"
    )