feat: competitive intelligence & historical pattern matching layer
This commit is contained in:
@@ -0,0 +1,549 @@
|
||||
"""Event classifier module for macro news articles.
|
||||
|
||||
Classifies global/geopolitical news articles into structured GlobalEvent
|
||||
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
|
||||
existing OllamaClient for inference and retry logic.
|
||||
|
||||
Persists classification prompts, raw outputs, and final events to MinIO
|
||||
and PostgreSQL for audit and downstream interpolation.
|
||||
|
||||
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.shared.schemas import (
|
||||
EstimatedDuration,
|
||||
ImpactType,
|
||||
ModelMetadata,
|
||||
SeverityLevel,
|
||||
)
|
||||
from services.shared.storage import upload_artifact
|
||||
|
||||
# Logger shared by every function in this module.
logger = logging.getLogger("event_classifier")

# Version tags stored alongside each uploaded prompt artifact and each
# persisted event so results can be traced to the prompt/schema that made them.
PROMPT_VERSION = "event-classification-v1"
SCHEMA_VERSION = "1.0.0"

# String values accepted for each enum-backed field. They feed both the JSON
# schema handed to the model and the normalizers applied to its output.
_VALID_IMPACT_TYPES = frozenset(member.value for member in ImpactType)
_VALID_SEVERITY_LEVELS = frozenset(member.value for member in SeverityLevel)
_VALID_DURATIONS = frozenset(member.value for member in EstimatedDuration)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GlobalEvent dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class GlobalEvent:
    """Structured classification of a macro news event.

    Produced by the event classifier from Ollama structured output.
    Every field has a default, so an instance can be built with no
    arguments; the defaults match the fallbacks the response parser in
    this module applies when the model output is missing or invalid.
    """

    # Unique identifier; a fresh UUID4 string is generated per instance.
    event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Impact type identifiers (the parser keeps only valid ImpactType values).
    event_types: list[str] = field(default_factory=list)
    # Severity level identifier; "low" is also the parser's fallback.
    severity: str = "low"
    # Country codes or region names, as returned by the model.
    affected_regions: list[str] = field(default_factory=list)
    # Sector names, as returned by the model.
    affected_sectors: list[str] = field(default_factory=list)
    # Commodity identifiers; empty when none are affected.
    affected_commodities: list[str] = field(default_factory=list)
    # Short free-text summary of the event.
    summary: str = ""
    # Facts the model extracted from the article text.
    key_facts: list[str] = field(default_factory=list)
    # Duration identifier; "short_term" is also the parser's fallback.
    estimated_duration: str = "short_term"
    # Classification confidence; the parser clamps it to [0.0, 1.0].
    confidence: float = 0.5
    # Identifier of the source document (written to a uuid column on persist).
    source_document_id: str = ""
    # Provider/model/prompt/schema provenance for this classification.
    model_metadata: ModelMetadata = field(default_factory=ModelMetadata)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON schema for Ollama structured output
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _EventClassificationResult:
    """Placeholder naming the shape of the Ollama classification response.

    Deliberately not a Pydantic model: the JSON schema is assembled by hand
    as a plain dict so it stays self-contained and Ollama-friendly (no $refs).
    The class carries no attributes or behavior.
    """
|
||||
|
||||
|
||||
def get_event_json_schema() -> dict[str, Any]:
    """Build the JSON schema that constrains Ollama's classification output.

    Every property is listed as required and additional properties are
    rejected, so the model has to fill each field explicitly. Enum values
    come from the module-level valid-value sets, sorted for a stable order.

    Returns:
        A new schema dict on every call (callers may mutate it freely).
    """

    def _string_array(description: str) -> dict[str, Any]:
        # Shape shared by all free-form list-of-strings properties.
        return {
            "type": "array",
            "items": {"type": "string"},
            "description": description,
        }

    properties: dict[str, Any] = {}

    properties["event_types"] = {
        "type": "array",
        "items": {
            "type": "string",
            "enum": sorted(_VALID_IMPACT_TYPES),
        },
        "description": (
            "One or more impact types this event represents. "
            "Include ALL applicable types — do not collapse to a single category."
        ),
    }
    properties["severity"] = {
        "type": "string",
        "enum": sorted(_VALID_SEVERITY_LEVELS),
        "description": "Overall severity of the event: low, moderate, high, or critical.",
    }
    properties["affected_regions"] = _string_array(
        "ISO 3166-1 alpha-2 country codes or region names affected. "
        "Use standard codes like US, CN, EU, GB, JP. "
        "Only include regions explicitly mentioned or clearly implied."
    )
    properties["affected_sectors"] = _string_array(
        "GICS sector identifiers or sector names affected. "
        "Examples: Energy, Materials, Industrials, Consumer Discretionary, "
        "Consumer Staples, Health Care, Financials, Information Technology, "
        "Communication Services, Utilities, Real Estate."
    )
    properties["affected_commodities"] = _string_array(
        "Commodity identifiers affected, if applicable. "
        "Examples: crude_oil, natural_gas, gold, copper, wheat, lithium, "
        "semiconductors. Empty list if no commodities are directly affected."
    )
    properties["summary"] = {
        "type": "string",
        "description": "A concise 1-3 sentence summary of the event and its market implications.",
    }
    properties["key_facts"] = _string_array(
        "Key facts explicitly stated in the article. "
        "Do NOT infer, speculate, or fabricate facts. "
        "Each fact must be directly supported by the text."
    )
    properties["estimated_duration"] = {
        "type": "string",
        "enum": sorted(_VALID_DURATIONS),
        "description": (
            "Expected duration of market impact: "
            "short_term (days to weeks), medium_term (weeks to months), "
            "long_term (months to years)."
        ),
    }
    properties["confidence"] = {
        "type": "number",
        "minimum": 0.0,
        "maximum": 1.0,
        "description": (
            "Your confidence in this classification. "
            "Lower if the article is ambiguous, speculative, or lacks concrete details."
        ),
    }

    # All properties are required, in declaration order.
    return {
        "type": "object",
        "required": list(properties),
        "properties": properties,
        "additionalProperties": False,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# System message returned under the "system" key by
# build_event_classification_prompt(). The trailing backslashes are line
# continuations inside the literal, so the value is a single line of text.
_SYSTEM_PROMPT = """\
You classify global news articles into structured macro event intelligence. \
Read the article carefully and extract the event classification. \
Return ONLY valid JSON matching the schema. No commentary, no markdown, no explanation."""

# Grounding rules interpolated verbatim into the user prompt. Changing this
# text changes what is sent to the model; PROMPT_VERSION is the tag recorded
# with each uploaded prompt.
_ANTI_HALLUCINATION_RULES = """\
CRITICAL RULES — read carefully:
1. Only extract information EXPLICITLY stated in the article text.
2. Do NOT infer, speculate, or fabricate facts, regions, sectors, or commodities.
3. If the article mentions multiple distinct impact types, include ALL of them in event_types.
4. For affected_regions, only include regions explicitly mentioned or clearly implied by the event.
5. For affected_sectors, only include sectors with a clear causal link to the event.
6. For affected_commodities, only include commodities directly referenced or obviously impacted.
7. For key_facts, each fact must be directly supported by a specific passage in the text.
8. If the article is vague or speculative, set confidence LOW (below 0.4).
9. Do NOT treat journalist speculation or opinion as confirmed fact.
10. Distinguish between announced policy and proposed/rumored policy."""
|
||||
|
||||
|
||||
def build_event_classification_prompt(text: str) -> dict[str, str]:
    """Assemble the system and user prompts for one classification request.

    The user prompt is made of four sections separated by blank lines: a
    one-line instruction, the anti-hallucination rules, a per-field guide,
    and the article wrapped in explicit start/end markers.

    Args:
        text: Normalized text content of the macro news article.

    Returns:
        Dict with 'system' and 'user' prompt strings.
    """
    instruction = "Classify this global news article as a macro event. Fill every field."

    field_guide = "\n".join((
        "Classify the event by:",
        "- event_types: ALL applicable impact types (supply_disruption, demand_shift, cost_increase, "
        "regulatory_pressure, currency_impact, commodity_shock, trade_barrier, geopolitical_risk)",
        "- severity: low, moderate, high, or critical",
        "- affected_regions: ISO country codes or region names",
        "- affected_sectors: GICS sector names",
        "- affected_commodities: commodity identifiers (empty list if none)",
        "- summary: 1-3 sentence summary of the event and market implications",
        "- key_facts: facts explicitly stated in the text (NO fabrication)",
        "- estimated_duration: short_term, medium_term, or long_term",
        "- confidence: 0.0-1.0 your confidence in this classification",
    ))

    article_block = "\n".join((
        "--- ARTICLE TEXT ---",
        f"{text}",
        "--- END ARTICLE TEXT ---",
    ))

    sections = (instruction, f"{_ANTI_HALLUCINATION_RULES}", field_guide, article_block)
    return {"system": _SYSTEM_PROMPT, "user": "\n\n".join(sections)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Classification response parsing and normalization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _normalize_event_types(raw: list[Any]) -> list[str]:
    """Normalize and filter event_types to valid ImpactType values.

    Each item is stringified, lower-cased and stripped; items that are not
    valid impact types are dropped. Duplicates are removed while keeping
    first-seen order.

    Args:
        raw: The model's ``event_types`` value. Normally a list, but a bare
            string is accepted and treated as a single item (iterating it
            would otherwise split it into characters and lose the value).

    Returns:
        The valid impact types, or ``["geopolitical_risk"]`` when none remain.
    """
    if isinstance(raw, str):
        raw = [raw]

    cleaned = (str(item).lower().strip() for item in raw)
    # dict.fromkeys de-duplicates while preserving insertion order.
    unique_valid = dict.fromkeys(
        value for value in cleaned if value in _VALID_IMPACT_TYPES
    )
    return list(unique_valid) or ["geopolitical_risk"]
|
||||
|
||||
|
||||
def _normalize_severity(raw: str) -> str:
    """Map a raw severity value onto a valid SeverityLevel string.

    The input is stringified, lower-cased and stripped; anything that is not
    a known severity level falls back to "low".
    """
    candidate = str(raw).lower().strip()
    if candidate in _VALID_SEVERITY_LEVELS:
        return candidate
    return "low"
|
||||
|
||||
|
||||
def _normalize_duration(raw: str) -> str:
    """Map a raw duration value onto a valid EstimatedDuration string.

    The input is stringified, lower-cased and stripped; anything that is not
    a known duration falls back to "short_term".
    """
    candidate = str(raw).lower().strip()
    if candidate in _VALID_DURATIONS:
        return candidate
    return "short_term"
|
||||
|
||||
|
||||
def _coerce_string_list(value: Any) -> list[str]:
    """Coerce a decoded JSON value into a list of strings.

    ``None`` becomes an empty list and a bare string becomes a one-item
    list (iterating it would split it into characters). Any other iterable
    has each item stringified. A non-iterable value raises ``TypeError``.
    """
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return [str(item) for item in value]


def _parse_classification_response(
    raw_json: str,
    document_id: str,
    model_name: str,
) -> GlobalEvent:
    """Parse raw Ollama JSON output into a GlobalEvent.

    Normalizes enum values and clamps numeric fields. Missing or null
    fields fall back to the same defaults GlobalEvent declares.

    Args:
        raw_json: The model's raw response text, expected to be a JSON object.
        document_id: Identifier of the source document, stored on the event.
        model_name: Name of the model that produced the output.

    Returns:
        A GlobalEvent with a fresh event_id and model metadata attached.

    Raises:
        json.JSONDecodeError: If ``raw_json`` is not valid JSON.
        TypeError: If the decoded payload is not a JSON object, or a list
            field holds a non-iterable value.
    """
    data = json.loads(raw_json)
    if not isinstance(data, dict):
        # Valid JSON that is not an object (list, string, number, null).
        # Raise TypeError rather than letting ``.get`` fail with an
        # AttributeError, which retry loops built around parse errors
        # typically do not catch.
        raise TypeError(
            f"expected a JSON object, got {type(data).__name__}"
        )

    confidence = data.get("confidence", 0.5)
    # ``confidence == confidence`` is False only for NaN, which json.loads
    # accepts and which min/max would otherwise turn into 1.0.
    if isinstance(confidence, (int, float)) and confidence == confidence:
        # Clamp before converting so huge ints cannot overflow float().
        confidence = float(max(0.0, min(1.0, confidence)))
    else:
        confidence = 0.5

    return GlobalEvent(
        event_id=str(uuid.uuid4()),
        event_types=_normalize_event_types(
            _coerce_string_list(data.get("event_types"))
        ),
        severity=_normalize_severity(data.get("severity", "low")),
        affected_regions=_coerce_string_list(data.get("affected_regions")),
        affected_sectors=_coerce_string_list(data.get("affected_sectors")),
        affected_commodities=_coerce_string_list(data.get("affected_commodities")),
        # ``or ""`` keeps a JSON null from becoming the literal text "None".
        summary=str(data.get("summary") or ""),
        key_facts=_coerce_string_list(data.get("key_facts")),
        estimated_duration=_normalize_duration(
            data.get("estimated_duration", "short_term")
        ),
        confidence=confidence,
        source_document_id=document_id,
        model_metadata=ModelMetadata(
            provider="ollama",
            model_name=model_name,
            prompt_version=PROMPT_VERSION,
            schema_version=SCHEMA_VERSION,
        ),
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MinIO persistence helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _upload_classification_prompt(
    minio_client: Minio,
    document_id: str,
    prompt_data: dict[str, str],
    model_name: str,
    timestamp: datetime | None = None,
) -> str:
    """Store the prompt sent for a document in the stonks-llm-prompts bucket.

    The artifact is a JSON document holding the prompt/schema versions, the
    model name, both prompt strings and the JSON schema. It is written under
    a date-partitioned path ending in ``<document_id>/prompt.json``.

    Args:
        minio_client: Client used for the upload.
        document_id: Identifier of the source document.
        prompt_data: Dict with 'system' and 'user' prompt strings.
        model_name: Name of the model the prompt is sent to.
        timestamp: Moment used for the date partition; defaults to now (UTC).

    Returns:
        Whatever ``upload_artifact`` returns for the stored object.
    """
    when = timestamp if timestamp else datetime.now(timezone.utc)

    document = {
        "prompt_version": PROMPT_VERSION,
        "schema_version": SCHEMA_VERSION,
        "model": model_name,
        "system_prompt": prompt_data["system"],
        "user_prompt": prompt_data["user"],
        "json_schema": get_event_json_schema(),
    }
    body = json.dumps(document, indent=2).encode()

    day_prefix = f"{when.year}/{when.month:02d}/{when.day:02d}"
    object_path = f"event_classification/macro/{day_prefix}/{document_id}/prompt.json"

    return upload_artifact(
        minio_client,
        "stonks-llm-prompts",
        object_path,
        body,
        content_type="application/json",
    )
|
||||
|
||||
|
||||
def _upload_classification_result(
    minio_client: Minio,
    document_id: str,
    raw_output: str,
    event: GlobalEvent | None,
    success: bool,
    error: str | None,
    timestamp: datetime | None = None,
) -> str:
    """Store the outcome of a classification in the stonks-llm-results bucket.

    The artifact records the document id, success flag, error text, the raw
    model output and, when an event was parsed, its classification fields.
    It is written under a date-partitioned path ending in
    ``<document_id>/result.json``.

    Args:
        minio_client: Client used for the upload.
        document_id: Identifier of the source document.
        raw_output: Raw text returned by the model.
        event: The parsed event, or None when classification failed.
        success: Whether classification succeeded.
        error: Error description for failed classifications.
        timestamp: Moment used for the date partition; defaults to now (UTC).

    Returns:
        Whatever ``upload_artifact`` returns for the stored object.
    """
    when = timestamp if timestamp else datetime.now(timezone.utc)

    # Event fields copied into the artifact, in output order.
    exported_fields = (
        "event_id",
        "event_types",
        "severity",
        "affected_regions",
        "affected_sectors",
        "affected_commodities",
        "summary",
        "key_facts",
        "estimated_duration",
        "confidence",
    )
    if event:
        parsed_event = {name: getattr(event, name) for name in exported_fields}
    else:
        parsed_event = None

    document = {
        "document_id": document_id,
        "success": success,
        "error": error,
        "raw_output": raw_output,
        "parsed_event": parsed_event,
    }
    body = json.dumps(document, indent=2).encode()

    day_prefix = f"{when.year}/{when.month:02d}/{when.day:02d}"
    object_path = f"event_classification/macro/{day_prefix}/{document_id}/result.json"

    return upload_artifact(
        minio_client,
        "stonks-llm-results",
        object_path,
        body,
        content_type="application/json",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PostgreSQL persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def persist_global_event(
    pool: asyncpg.Pool,
    event: GlobalEvent,
) -> str:
    """Insert one GlobalEvent into the global_events PostgreSQL table.

    Args:
        pool: Connection pool used to run the INSERT.
        event: The event to store. ``key_facts`` is sent as a JSON string
            (cast to jsonb); the other list fields are passed as lists.

    Returns:
        The inserted row's id, as a string.
    """
    insert_sql = """INSERT INTO global_events
               (id, event_types, severity, affected_regions, affected_sectors,
                affected_commodities, summary, key_facts, estimated_duration,
                confidence, source_document_id, model_provider, model_name,
                prompt_version, schema_version)
           VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10,
                   $11::uuid, $12, $13, $14, $15)
           RETURNING id"""

    metadata = event.model_metadata
    # Positional parameters, in the order of the $1..$15 placeholders.
    params = (
        event.event_id,
        event.event_types,
        event.severity,
        event.affected_regions,
        event.affected_sectors,
        event.affected_commodities,
        event.summary,
        json.dumps(event.key_facts),
        event.estimated_duration,
        event.confidence,
        event.source_document_id,
        metadata.provider,
        metadata.model_name,
        metadata.prompt_version,
        metadata.schema_version,
    )

    inserted_id = await pool.fetchval(insert_sql, *params)
    logger.info(
        "Persisted global event %s for doc %s (severity=%s, types=%s)",
        inserted_id, event.source_document_id, event.severity, event.event_types,
    )
    return str(inserted_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main classification function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def classify_global_event(
    normalized_text: str,
    document_id: str,
    ollama_client: Any,
    *,
    pool: asyncpg.Pool | None = None,
    minio_client: Minio | None = None,
) -> GlobalEvent:
    """Classify a macro news article into a GlobalEvent using Ollama.

    Sends a dedicated event classification prompt and JSON schema through
    the client's ``_call_ollama`` method. Failed or unparseable attempts are
    retried with exponential backoff, using the retry settings read from the
    client (``_max_retries``, ``_base_delay``, ``_backoff_multiplier``,
    ``_max_delay``).

    Persists prompt, raw output, and final event to MinIO and PostgreSQL
    when the respective clients are provided. Persistence is best-effort:
    failures are logged and never prevent the event from being returned.

    Args:
        normalized_text: Cleaned text content of the macro article.
        document_id: UUID of the source document.
        ollama_client: An OllamaClient instance (from services.extractor.client).
        pool: Optional asyncpg pool for PostgreSQL persistence.
        minio_client: Optional MinIO client for artifact persistence.

    Returns:
        A GlobalEvent with the classification result.

    Raises:
        ValueError: If classification fails after all retries.
    """
    ts = datetime.now(timezone.utc)
    prompts = build_event_classification_prompt(normalized_text)
    json_schema = get_event_json_schema()
    model_name = ollama_client._config.model

    # NOTE(review): the MinIO helpers are called synchronously from this
    # coroutine; if upload_artifact does blocking network I/O it stalls the
    # event loop for the duration of the upload — confirm and consider
    # offloading to a thread.
    if minio_client:
        try:
            _upload_classification_prompt(
                minio_client, document_id, prompts, model_name, timestamp=ts,
            )
        except Exception:
            logger.exception("Failed to upload classification prompt for doc %s", document_id)

    max_retries = ollama_client._max_retries
    last_error: str | None = None
    raw_output = ""

    for attempt_num in range(max_retries + 1):
        attempt = await ollama_client._call_ollama(prompts, json_schema)
        raw_output = attempt.raw_output
        event: GlobalEvent | None = None

        if attempt.error is None and raw_output:
            # Only the parse is guarded here. ValueError covers
            # json.JSONDecodeError; AttributeError covers payloads that are
            # valid JSON but not an object. Either way the attempt counts
            # as failed and is retried instead of escaping the loop.
            try:
                event = _parse_classification_response(
                    raw_output, document_id, model_name,
                )
            except (ValueError, KeyError, TypeError, AttributeError) as exc:
                last_error = f"parse_error: {exc}"
                logger.warning(
                    "Classification parse error for doc %s attempt %d: %s",
                    document_id, attempt_num + 1, exc,
                )
        else:
            last_error = attempt.error or "empty_response"

        if event is not None:
            # Success: record the result artifact and the database row, then
            # return the event even if either write fails.
            if minio_client:
                try:
                    _upload_classification_result(
                        minio_client, document_id, raw_output,
                        event, success=True, error=None, timestamp=ts,
                    )
                except Exception:
                    logger.exception(
                        "Failed to upload classification result for doc %s", document_id,
                    )

            if pool:
                try:
                    await persist_global_event(pool, event)
                except Exception:
                    logger.exception(
                        "Failed to persist global event for doc %s", document_id,
                    )

            return event

        # Exponential backoff, capped at the client's maximum delay. No
        # sleep after the final attempt.
        if attempt_num < max_retries:
            delay = ollama_client._base_delay * (
                ollama_client._backoff_multiplier ** attempt_num
            )
            delay = min(delay, ollama_client._max_delay)
            logger.warning(
                "Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
                attempt_num + 1, max_retries + 1, document_id, last_error, delay,
            )
            await asyncio.sleep(delay)

    # All retries exhausted — record the last raw output and error, then raise.
    if minio_client:
        try:
            _upload_classification_result(
                minio_client, document_id, raw_output,
                event=None, success=False, error=last_error, timestamp=ts,
            )
        except Exception:
            logger.exception(
                "Failed to upload failed classification result for doc %s", document_id,
            )

    raise ValueError(
        f"Event classification failed for document {document_id} "
        f"after {max_retries + 1} attempts: {last_error}"
    )
|
||||
Reference in New Issue
Block a user