feat: competitive intelligence & historical pattern matching layer
This commit is contained in:
@@ -0,0 +1,549 @@
|
||||
"""Event classifier module for macro news articles.
|
||||
|
||||
Classifies global/geopolitical news articles into structured GlobalEvent
|
||||
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
|
||||
existing OllamaClient for inference and retry logic.
|
||||
|
||||
Persists classification prompts, raw outputs, and final events to MinIO
|
||||
and PostgreSQL for audit and downstream interpolation.
|
||||
|
||||
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.shared.schemas import (
|
||||
EstimatedDuration,
|
||||
ImpactType,
|
||||
ModelMetadata,
|
||||
SeverityLevel,
|
||||
)
|
||||
from services.shared.storage import upload_artifact
|
||||
|
||||
# Module logger for the event classification pipeline.
logger = logging.getLogger("event_classifier")

# Version tags stored alongside every prompt and persisted event.
PROMPT_VERSION = "event-classification-v1"
SCHEMA_VERSION = "1.0.0"

# Allowed enum values, precomputed once so normalization is a set lookup.
_VALID_IMPACT_TYPES = frozenset(member.value for member in ImpactType)
_VALID_SEVERITY_LEVELS = frozenset(member.value for member in SeverityLevel)
_VALID_DURATIONS = frozenset(member.value for member in EstimatedDuration)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GlobalEvent dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class GlobalEvent:
    """Structured classification of a macro news event.

    Produced by the event classifier from Ollama structured output. Every
    field has a default, so a bare ``GlobalEvent()`` is a valid low-severity,
    short-term placeholder with a freshly generated ``event_id``.
    """

    # Random UUID4 string, generated per instance.
    event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    # Impact type values; the parser normalizes these against ImpactType.
    event_types: list[str] = field(default_factory=list)
    # Severity value; the parser normalizes this against SeverityLevel.
    severity: str = "low"
    # Region codes or names as returned by the model (stored as strings).
    affected_regions: list[str] = field(default_factory=list)
    # Sector names as returned by the model (stored as strings).
    affected_sectors: list[str] = field(default_factory=list)
    # Commodity identifiers as returned by the model; may be empty.
    affected_commodities: list[str] = field(default_factory=list)
    # Short free-text summary of the event.
    summary: str = ""
    # Facts extracted from the article text.
    key_facts: list[str] = field(default_factory=list)
    # Duration value; the parser normalizes this against EstimatedDuration.
    estimated_duration: str = "short_term"
    # Classification confidence; the parser clamps it to [0.0, 1.0].
    confidence: float = 0.5
    # Identifier of the document the event was classified from.
    source_document_id: str = ""
    # Provider / model / prompt / schema versions used for the classification.
    model_metadata: ModelMetadata = field(default_factory=ModelMetadata)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON schema for Ollama structured output
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _EventClassificationResult:
    """Marker for the Ollama event classification response schema.

    Deliberately not a Pydantic model: the JSON schema dict is built by hand
    so that it stays self-contained and Ollama-friendly (no $refs).
    """
|
||||
|
||||
|
||||
def get_event_json_schema() -> dict[str, Any]:
    """Build the JSON schema passed to Ollama for structured event output.

    Every property is listed as required and ``additionalProperties`` is
    disabled, so the model has to emit each field explicitly and nothing else.

    Returns:
        A fresh schema dict (no ``$ref`` indirection) on every call.
    """

    def string_array(description: str) -> dict[str, Any]:
        # Free-form list of strings with guidance text for the model.
        return {
            "type": "array",
            "items": {"type": "string"},
            "description": description,
        }

    # Insertion order matters: it defines both the property order and the
    # order of the "required" list below.
    properties: dict[str, Any] = {}

    properties["event_types"] = {
        "type": "array",
        "items": {
            "type": "string",
            "enum": sorted(_VALID_IMPACT_TYPES),
        },
        "description": (
            "One or more impact types this event represents. "
            "Include ALL applicable types — do not collapse to a single category."
        ),
    }
    properties["severity"] = {
        "type": "string",
        "enum": sorted(_VALID_SEVERITY_LEVELS),
        "description": "Overall severity of the event: low, moderate, high, or critical.",
    }
    properties["affected_regions"] = string_array(
        "ISO 3166-1 alpha-2 country codes or region names affected. "
        "Use standard codes like US, CN, EU, GB, JP. "
        "Only include regions explicitly mentioned or clearly implied."
    )
    properties["affected_sectors"] = string_array(
        "GICS sector identifiers or sector names affected. "
        "Examples: Energy, Materials, Industrials, Consumer Discretionary, "
        "Consumer Staples, Health Care, Financials, Information Technology, "
        "Communication Services, Utilities, Real Estate."
    )
    properties["affected_commodities"] = string_array(
        "Commodity identifiers affected, if applicable. "
        "Examples: crude_oil, natural_gas, gold, copper, wheat, lithium, "
        "semiconductors. Empty list if no commodities are directly affected."
    )
    properties["summary"] = {
        "type": "string",
        "description": "A concise 1-3 sentence summary of the event and its market implications.",
    }
    properties["key_facts"] = string_array(
        "Key facts explicitly stated in the article. "
        "Do NOT infer, speculate, or fabricate facts. "
        "Each fact must be directly supported by the text."
    )
    properties["estimated_duration"] = {
        "type": "string",
        "enum": sorted(_VALID_DURATIONS),
        "description": (
            "Expected duration of market impact: "
            "short_term (days to weeks), medium_term (weeks to months), "
            "long_term (months to years)."
        ),
    }
    properties["confidence"] = {
        "type": "number",
        "minimum": 0.0,
        "maximum": 1.0,
        "description": (
            "Your confidence in this classification. "
            "Lower if the article is ambiguous, speculative, or lacks concrete details."
        ),
    }

    return {
        "type": "object",
        "required": list(properties),
        "properties": properties,
        "additionalProperties": False,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# System prompt: a single line of text, assembled from adjacent literals.
_SYSTEM_PROMPT = (
    "You classify global news articles into structured macro event intelligence. "
    "Read the article carefully and extract the event classification. "
    "Return ONLY valid JSON matching the schema. No commentary, no markdown, no explanation."
)

# Grounding rules embedded verbatim in every user prompt, one rule per line.
_ANTI_HALLUCINATION_RULES = "\n".join(
    (
        "CRITICAL RULES — read carefully:",
        "1. Only extract information EXPLICITLY stated in the article text.",
        "2. Do NOT infer, speculate, or fabricate facts, regions, sectors, or commodities.",
        "3. If the article mentions multiple distinct impact types, include ALL of them in event_types.",
        "4. For affected_regions, only include regions explicitly mentioned or clearly implied by the event.",
        "5. For affected_sectors, only include sectors with a clear causal link to the event.",
        "6. For affected_commodities, only include commodities directly referenced or obviously impacted.",
        "7. For key_facts, each fact must be directly supported by a specific passage in the text.",
        "8. If the article is vague or speculative, set confidence LOW (below 0.4).",
        "9. Do NOT treat journalist speculation or opinion as confirmed fact.",
        "10. Distinguish between announced policy and proposed/rumored policy.",
    )
)


def build_event_classification_prompt(text: str) -> dict[str, str]:
    """Assemble the system and user prompts for event classification.

    The user prompt contains, in order: the task statement, the grounding
    rules, a per-field guide, and the article wrapped in explicit
    start/end markers.

    Args:
        text: Normalized text content of the macro news article.

    Returns:
        Dict with 'system' and 'user' prompt strings.
    """
    field_guide = (
        "- event_types: ALL applicable impact types (supply_disruption, demand_shift, cost_increase, "
        "regulatory_pressure, currency_impact, commodity_shock, trade_barrier, geopolitical_risk)",
        "- severity: low, moderate, high, or critical",
        "- affected_regions: ISO country codes or region names",
        "- affected_sectors: GICS sector names",
        "- affected_commodities: commodity identifiers (empty list if none)",
        "- summary: 1-3 sentence summary of the event and market implications",
        "- key_facts: facts explicitly stated in the text (NO fabrication)",
        "- estimated_duration: short_term, medium_term, or long_term",
        "- confidence: 0.0-1.0 your confidence in this classification",
    )

    # Empty strings become the blank lines separating the prompt sections.
    sections = [
        "Classify this global news article as a macro event. Fill every field.",
        "",
        _ANTI_HALLUCINATION_RULES,
        "",
        "Classify the event by:",
        *field_guide,
        "",
        "--- ARTICLE TEXT ---",
        f"{text}",
        "--- END ARTICLE TEXT ---",
    ]

    return {"system": _SYSTEM_PROMPT, "user": "\n".join(sections)}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Classification response parsing and normalization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _normalize_event_types(raw: list[Any]) -> list[str]:
    """Normalize and filter event_types to valid ImpactType values.

    Each candidate is stringified, lower-cased and stripped, then kept only
    if it is a known impact type. Duplicates are dropped while preserving
    first-seen order.

    Args:
        raw: The ``event_types`` value from the model output. Expected to be
            a list; a bare string is treated as a single candidate rather
            than being iterated character by character.

    Returns:
        The valid, de-duplicated impact types, or ``["geopolitical_risk"]``
        when none of the candidates is valid.

    Raises:
        TypeError: If ``raw`` is neither a string nor iterable (e.g. ``None``).
    """
    candidates = [raw] if isinstance(raw, str) else raw

    normalized: list[str] = []
    for item in candidates:
        value = str(item).lower().strip()
        if value in _VALID_IMPACT_TYPES and value not in normalized:
            normalized.append(value)

    # Fall back to the generic category when nothing valid was supplied.
    return normalized or ["geopolitical_risk"]
|
||||
|
||||
|
||||
def _normalize_severity(raw: str) -> str:
    """Coerce ``raw`` to a known SeverityLevel value, defaulting to "low"."""
    candidate = str(raw).lower().strip()
    if candidate in _VALID_SEVERITY_LEVELS:
        return candidate
    return "low"
|
||||
|
||||
|
||||
def _normalize_duration(raw: str) -> str:
    """Coerce ``raw`` to a known EstimatedDuration value, defaulting to "short_term"."""
    candidate = str(raw).lower().strip()
    if candidate in _VALID_DURATIONS:
        return candidate
    return "short_term"
|
||||
|
||||
|
||||
def _parse_classification_response(
    raw_json: str,
    document_id: str,
    model_name: str,
) -> GlobalEvent:
    """Parse raw Ollama JSON output into a GlobalEvent.

    Enum-like fields are normalized to known values and ``confidence`` is
    clamped to [0.0, 1.0]. A confidence that is missing, non-numeric, a bool,
    or NaN falls back to the neutral default of 0.5.

    Args:
        raw_json: The model's raw JSON text.
        document_id: Identifier of the source document, stored on the event.
        model_name: Name of the model that produced the output.

    Returns:
        A GlobalEvent with a freshly generated ``event_id``.

    Raises:
        json.JSONDecodeError: If ``raw_json`` is not valid JSON.
        TypeError: If the decoded JSON is not an object, or a list-valued
            field is not iterable.
    """
    import math  # local import: the module does not import math at file level

    data = json.loads(raw_json)
    if not isinstance(data, dict):
        # Without this guard a top-level list/string/number would raise
        # AttributeError on .get(), which the retry loop does not treat as a
        # parse error.
        raise TypeError(
            f"expected a JSON object for event classification, got {type(data).__name__}"
        )

    confidence = data.get("confidence", 0.5)
    is_number = isinstance(confidence, (int, float)) and not isinstance(confidence, bool)
    if not is_number or (isinstance(confidence, float) and math.isnan(confidence)):
        # json.loads accepts the literal NaN; min()/max() would silently turn
        # it into 1.0, so treat it like any other unusable value.
        confidence = 0.5
    else:
        # Clamp before converting so oversized ints cannot overflow float().
        confidence = float(max(0.0, min(1.0, confidence)))

    return GlobalEvent(
        event_id=str(uuid.uuid4()),
        event_types=_normalize_event_types(data.get("event_types", [])),
        severity=_normalize_severity(data.get("severity", "low")),
        affected_regions=[str(r) for r in data.get("affected_regions", [])],
        affected_sectors=[str(s) for s in data.get("affected_sectors", [])],
        affected_commodities=[str(c) for c in data.get("affected_commodities", [])],
        summary=str(data.get("summary", "")),
        key_facts=[str(f) for f in data.get("key_facts", [])],
        estimated_duration=_normalize_duration(data.get("estimated_duration", "short_term")),
        confidence=confidence,
        source_document_id=document_id,
        model_metadata=ModelMetadata(
            provider="ollama",
            model_name=model_name,
            prompt_version=PROMPT_VERSION,
            schema_version=SCHEMA_VERSION,
        ),
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MinIO persistence helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _upload_classification_prompt(
    minio_client: Minio,
    document_id: str,
    prompt_data: dict[str, str],
    model_name: str,
    timestamp: datetime | None = None,
) -> str:
    """Store the classification prompt and its metadata in stonks-llm-prompts.

    The object is written under a date-partitioned path,
    ``event_classification/macro/YYYY/MM/DD/<document_id>/prompt.json``.

    Args:
        minio_client: MinIO client handed to ``upload_artifact``.
        document_id: Identifier of the document being classified.
        prompt_data: Dict holding the 'system' and 'user' prompt strings.
        model_name: Name of the model the prompt is sent to.
        timestamp: Time used for the date partition; defaults to now (UTC).

    Returns:
        Whatever ``upload_artifact`` returns for the stored object.
    """
    when = timestamp if timestamp else datetime.now(timezone.utc)

    document = {
        "prompt_version": PROMPT_VERSION,
        "schema_version": SCHEMA_VERSION,
        "model": model_name,
        "system_prompt": prompt_data["system"],
        "user_prompt": prompt_data["user"],
        "json_schema": get_event_json_schema(),
    }
    body = json.dumps(document, indent=2).encode()

    date_prefix = f"{when.year}/{when.month:02d}/{when.day:02d}"
    object_path = f"event_classification/macro/{date_prefix}/{document_id}/prompt.json"

    return upload_artifact(
        minio_client,
        "stonks-llm-prompts",
        object_path,
        body,
        content_type="application/json",
    )
|
||||
|
||||
|
||||
def _upload_classification_result(
    minio_client: Minio,
    document_id: str,
    raw_output: str,
    event: GlobalEvent | None,
    success: bool,
    error: str | None,
    timestamp: datetime | None = None,
) -> str:
    """Store the raw model output and parsed event in stonks-llm-results.

    The object is written under a date-partitioned path,
    ``event_classification/macro/YYYY/MM/DD/<document_id>/result.json``.

    Args:
        minio_client: MinIO client handed to ``upload_artifact``.
        document_id: Identifier of the classified document.
        raw_output: Raw text returned by the model.
        event: Parsed event, or None when classification failed.
        success: Whether classification succeeded.
        error: Error description for failed runs, otherwise None.
        timestamp: Time used for the date partition; defaults to now (UTC).

    Returns:
        Whatever ``upload_artifact`` returns for the stored object.
    """
    when = timestamp if timestamp else datetime.now(timezone.utc)

    # Only the classification fields are archived; model metadata and the
    # source document id are not part of the parsed_event payload.
    parsed_event = None
    if event:
        parsed_event = {
            "event_id": event.event_id,
            "event_types": event.event_types,
            "severity": event.severity,
            "affected_regions": event.affected_regions,
            "affected_sectors": event.affected_sectors,
            "affected_commodities": event.affected_commodities,
            "summary": event.summary,
            "key_facts": event.key_facts,
            "estimated_duration": event.estimated_duration,
            "confidence": event.confidence,
        }

    document = {
        "document_id": document_id,
        "success": success,
        "error": error,
        "raw_output": raw_output,
        "parsed_event": parsed_event,
    }
    body = json.dumps(document, indent=2).encode()

    date_prefix = f"{when.year}/{when.month:02d}/{when.day:02d}"
    object_path = f"event_classification/macro/{date_prefix}/{document_id}/result.json"

    return upload_artifact(
        minio_client,
        "stonks-llm-results",
        object_path,
        body,
        content_type="application/json",
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PostgreSQL persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def persist_global_event(
    pool: asyncpg.Pool,
    event: GlobalEvent,
) -> str:
    """Insert a GlobalEvent into the global_events PostgreSQL table.

    ``key_facts`` is serialized to JSON text and cast to jsonb by the query;
    the event id and source document id are cast to uuid.

    Args:
        pool: asyncpg pool used to run the INSERT.
        event: The event to store.

    Returns:
        The inserted row's id as a string.
    """
    insert_sql = """INSERT INTO global_events
           (id, event_types, severity, affected_regions, affected_sectors,
            affected_commodities, summary, key_facts, estimated_duration,
            confidence, source_document_id, model_provider, model_name,
            prompt_version, schema_version)
           VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10,
                   $11::uuid, $12, $13, $14, $15)
           RETURNING id"""

    metadata = event.model_metadata
    # Positional parameters, in the same order as $1..$15 above.
    params = (
        event.event_id,
        event.event_types,
        event.severity,
        event.affected_regions,
        event.affected_sectors,
        event.affected_commodities,
        event.summary,
        json.dumps(event.key_facts),
        event.estimated_duration,
        event.confidence,
        event.source_document_id,
        metadata.provider,
        metadata.model_name,
        metadata.prompt_version,
        metadata.schema_version,
    )

    inserted_id = await pool.fetchval(insert_sql, *params)
    logger.info(
        "Persisted global event %s for doc %s (severity=%s, types=%s)",
        inserted_id, event.source_document_id, event.severity, event.event_types,
    )
    return str(inserted_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main classification function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def classify_global_event(
    normalized_text: str,
    document_id: str,
    ollama_client: Any,
    *,
    pool: asyncpg.Pool | None = None,
    minio_client: Minio | None = None,
) -> GlobalEvent:
    """Classify a macro news article into a GlobalEvent using Ollama.

    Builds the dedicated classification prompt and JSON schema, calls the
    client's ``_call_ollama`` and retries with exponential backoff using the
    client's own retry settings. A response that cannot be parsed counts as a
    failed attempt and is retried like a transport error.

    Persistence is best-effort: prompt, raw output, and final event are
    written to MinIO / PostgreSQL when the respective clients are provided,
    and failures there are logged without aborting the classification.

    Args:
        normalized_text: Cleaned text content of the macro article.
        document_id: UUID of the source document.
        ollama_client: An OllamaClient instance (from services.extractor.client).
        pool: Optional asyncpg pool for PostgreSQL persistence.
        minio_client: Optional MinIO client for artifact persistence.

    Returns:
        A GlobalEvent with the classification result.

    Raises:
        ValueError: If classification fails after all retries.
    """
    ts = datetime.now(timezone.utc)
    prompts = build_event_classification_prompt(normalized_text)
    json_schema = get_event_json_schema()
    model_name = ollama_client._config.model

    # NOTE(review): the MinIO helpers below are synchronous calls made from a
    # coroutine — confirm whether blocking the event loop here is acceptable.
    if minio_client:
        try:
            _upload_classification_prompt(
                minio_client, document_id, prompts, model_name, timestamp=ts,
            )
        except Exception:
            logger.exception("Failed to upload classification prompt for doc %s", document_id)

    # Reuse the client's retry configuration so classification follows the
    # same policy values as the client itself.
    max_retries = ollama_client._max_retries
    last_error: str | None = None
    raw_output = ""

    for attempt_num in range(max_retries + 1):
        attempt = await ollama_client._call_ollama(prompts, json_schema)
        raw_output = attempt.raw_output

        if attempt.error is None and raw_output:
            try:
                event = _parse_classification_response(
                    raw_output, document_id, model_name,
                )
            except (ValueError, KeyError, TypeError, AttributeError) as exc:
                # ValueError covers json.JSONDecodeError. AttributeError shows
                # up when the model returns valid JSON that is not an object.
                # Any of these is a bad response worth retrying, not a crash.
                last_error = f"parse_error: {exc}"
                logger.warning(
                    "Classification parse error for doc %s attempt %d: %s",
                    document_id, attempt_num + 1, exc,
                )
            else:
                if minio_client:
                    try:
                        _upload_classification_result(
                            minio_client, document_id, raw_output,
                            event, success=True, error=None, timestamp=ts,
                        )
                    except Exception:
                        logger.exception(
                            "Failed to upload classification result for doc %s", document_id,
                        )

                if pool:
                    try:
                        await persist_global_event(pool, event)
                    except Exception:
                        logger.exception(
                            "Failed to persist global event for doc %s", document_id,
                        )

                return event
        else:
            last_error = attempt.error or "empty_response"

        # Exponential backoff, capped at the client's maximum delay.
        if attempt_num < max_retries:
            delay = ollama_client._base_delay * (
                ollama_client._backoff_multiplier ** attempt_num
            )
            delay = min(delay, ollama_client._max_delay)
            logger.warning(
                "Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
                attempt_num + 1, max_retries + 1, document_id, last_error, delay,
            )
            await asyncio.sleep(delay)

    # All retries exhausted — archive the last raw output, then raise.
    if minio_client:
        try:
            _upload_classification_result(
                minio_client, document_id, raw_output,
                event=None, success=False, error=last_error, timestamp=ts,
            )
        except Exception:
            logger.exception(
                "Failed to upload failed classification result for doc %s", document_id,
            )

    raise ValueError(
        f"Event classification failed for document {document_id} "
        f"after {max_retries + 1} attempts: {last_error}"
    )
|
||||
@@ -0,0 +1,394 @@
|
||||
"""Exposure profile auto-inference from filing extractions.
|
||||
|
||||
Infers baseline exposure profiles from company filing extractions when
|
||||
no manual profile exists. Scans recent filing extractions for geographic
|
||||
revenue breakdowns, supplier mentions, and commodity references.
|
||||
|
||||
Requirements: 9.1, 9.2, 9.3
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from services.aggregation.interpolation import build_default_profile
|
||||
from services.shared.schemas import (
|
||||
DocumentIntelligence,
|
||||
ExposureProfileSchema,
|
||||
MarketPositionTier,
|
||||
)
|
||||
|
||||
# Module-level logger, registered under the "exposure_inference" name.
logger = logging.getLogger("exposure_inference")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Known region patterns for geographic extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Lower-case keyword -> region code. Broad areas are mapped onto a single
# representative code (e.g. "asia" -> CN, "middle east" -> SA).
_REGION_KEYWORDS: dict[str, str] = {
    "united states": "US",
    "u.s.": "US",
    "us": "US",
    "america": "US",
    "north america": "US",
    "china": "CN",
    "chinese": "CN",
    "europe": "EU",
    "european": "EU",
    "eu": "EU",
    "japan": "JP",
    "japanese": "JP",
    "germany": "DE",
    "german": "DE",
    "united kingdom": "GB",
    "uk": "GB",
    "britain": "GB",
    "british": "GB",
    "south korea": "KR",
    "korea": "KR",
    "india": "IN",
    "indian": "IN",
    "brazil": "BR",
    "brazilian": "BR",
    "australia": "AU",
    "australian": "AU",
    "canada": "CA",
    "canadian": "CA",
    "taiwan": "TW",
    "saudi arabia": "SA",
    "russia": "RU",
    "russian": "RU",
    "mexico": "MX",
    "singapore": "SG",
    "asia": "CN",
    "asia pacific": "CN",
    "latin america": "BR",
    "middle east": "SA",
}

# ---------------------------------------------------------------------------
# Known commodity patterns
# ---------------------------------------------------------------------------

# Lower-case keyword -> commodity identifier.
_COMMODITY_KEYWORDS: dict[str, str] = {
    "crude oil": "crude_oil",
    "oil": "crude_oil",
    "petroleum": "crude_oil",
    "natural gas": "natural_gas",
    "gas": "natural_gas",
    "copper": "copper",
    "steel": "steel",
    "lithium": "lithium",
    "semiconductor": "semiconductors",
    "semiconductors": "semiconductors",
    "chip": "semiconductors",
    "chips": "semiconductors",
    "wheat": "wheat",
    "corn": "corn",
    "gold": "gold",
    "aluminum": "aluminum",
    "aluminium": "aluminum",
    "nickel": "nickel",
    "cobalt": "cobalt",
    "rare earth": "rare_earth",
}

# Minimum number of filing documents to consider inference meaningful
_MIN_FILINGS_FOR_INFERENCE = 1

# Minimum total mentions to consider a region significant
_MIN_REGION_MENTIONS = 1

# Minimum total mentions to consider a commodity significant
_MIN_COMMODITY_MENTIONS = 1


# ---------------------------------------------------------------------------
# Text scanning helpers
# ---------------------------------------------------------------------------

# Region keywords that are also ordinary lower-case English words. They only
# count when written in upper case in the source text ("US", not "us").
_CASE_SENSITIVE_REGION_KEYWORDS = frozenset({"us"})

# Terms that mark a passage as being about the supply chain.
_SUPPLY_CHAIN_KEYWORDS = (
    "supplier", "supply chain", "sourcing", "manufacturing",
    "factory", "plant", "warehouse", "distribution",
    "import", "export", "procurement",
)

# Characters inspected on each side of a supply-chain keyword.
_SUPPLY_CHAIN_WINDOW = 200


def _compile_keyword_pattern(
    keyword_map: dict[str, str],
    boundary_max_len: int,
) -> "re.Pattern[str]":
    """Compile one alternation that matches every keyword of ``keyword_map``.

    Keywords no longer than ``boundary_max_len`` must match as whole words;
    longer ones match as substrings, so derived forms such as "korean" still
    hit "korea". Alternatives are ordered longest-first so that, at any
    position, "north america" wins over "america" and the mention is counted
    once.
    """
    alternatives = []
    for keyword in sorted(keyword_map, key=len, reverse=True):
        escaped = re.escape(keyword)
        if len(keyword) <= boundary_max_len:
            escaped = rf"\b{escaped}\b"
        alternatives.append(escaped)
    return re.compile("|".join(alternatives))


# Compiled once at import time instead of once per keyword per call.
_REGION_PATTERN = _compile_keyword_pattern(_REGION_KEYWORDS, 3)
_COMMODITY_PATTERN = _compile_keyword_pattern(_COMMODITY_KEYWORDS, 4)
_SUPPLY_CHAIN_PATTERN = re.compile(
    "|".join(re.escape(keyword) for keyword in _SUPPLY_CHAIN_KEYWORDS)
)


def _count_keyword_mentions(
    text: str,
    pattern: "re.Pattern[str]",
    keyword_map: dict[str, str],
    case_sensitive: frozenset[str] = frozenset(),
) -> dict[str, int]:
    """Count non-overlapping keyword mentions in ``text``, grouped by mapped id."""
    lowered = text.lower()
    # str.lower() can change the string length for a few code points; offsets
    # into ``lowered`` are only reused on ``text`` when the lengths agree.
    aligned = len(lowered) == len(text)

    counts: dict[str, int] = defaultdict(int)
    for match in pattern.finditer(lowered):
        keyword = match.group(0)
        if keyword in case_sensitive and aligned:
            original = text[match.start():match.end()]
            if not original.isupper():
                # e.g. the pronoun "us" rather than the country "US"
                continue
        counts[keyword_map[keyword]] += 1
    return dict(counts)


def _extract_regions_from_text(text: str) -> dict[str, int]:
    """Extract region mentions from text, returning region_code -> count.

    Each mention is counted once even when several keywords overlap on it
    ("north america" is one US mention, "latin america" is BR only). The
    keyword "us" is counted only when it appears in upper case, so the
    pronoun is ignored; callers should therefore pass text in its original
    casing.
    """
    return _count_keyword_mentions(
        text, _REGION_PATTERN, _REGION_KEYWORDS, _CASE_SENSITIVE_REGION_KEYWORDS,
    )


def _extract_commodities_from_text(text: str) -> dict[str, int]:
    """Extract commodity mentions from text, returning commodity_id -> count.

    Matching is case-insensitive and each mention is counted once even when
    keywords overlap ("crude oil" is a single crude_oil mention).
    """
    return _count_keyword_mentions(text, _COMMODITY_PATTERN, _COMMODITY_KEYWORDS)


def _extract_supply_chain_regions(text: str) -> set[str]:
    """Extract supply chain region mentions from text.

    A region qualifies when it is mentioned within ``_SUPPLY_CHAIN_WINDOW``
    characters of a supply-chain keyword such as "supplier" or "factory".
    """
    lowered = text.lower()
    # Prefer the original casing for the windows so case-sensitive region
    # keywords still work; fall back to the lowered text if offsets diverge.
    haystack = text if len(lowered) == len(text) else lowered

    regions: set[str] = set()
    for match in _SUPPLY_CHAIN_PATTERN.finditer(lowered):
        start = max(0, match.start() - _SUPPLY_CHAIN_WINDOW)
        end = min(len(lowered), match.end() + _SUPPLY_CHAIN_WINDOW)
        regions.update(_extract_regions_from_text(haystack[start:end]))

    return regions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Revenue mix estimation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _estimate_revenue_mix(region_counts: dict[str, int]) -> dict[str, float]:
    """Turn region mention counts into an approximate revenue split.

    Mention frequency stands in for revenue share. Shares are rounded to
    four decimals; if dropping sparsely mentioned regions leaves the shares
    more than 0.001 away from 1.0, they are rescaled.

    Returns:
        region code -> share, or an empty dict when there is nothing to count.
    """
    if not region_counts:
        return {}

    mention_total = sum(region_counts.values())
    if mention_total == 0:
        return {}

    shares: dict[str, float] = {}
    for region, mentions in region_counts.items():
        if mentions < _MIN_REGION_MENTIONS:
            continue
        shares[region] = round(mentions / mention_total, 4)

    # Rescale only when filtering (or rounding) moved the sum noticeably.
    share_total = sum(shares.values())
    needs_rescale = share_total > 0 and abs(share_total - 1.0) > 0.001
    if needs_rescale:
        rescaled: dict[str, float] = {}
        for region, share in shares.items():
            rescaled[region] = round(share / share_total, 4)
        shares = rescaled

    return shares
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Confidence scoring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_inference_confidence(
    num_filings: int,
    num_regions: int,
    num_commodities: int,
    total_mentions: int,
) -> float:
    """Score how trustworthy an inferred profile is, in [0.0, 1.0].

    Three saturating factors are blended: filings scanned (weight 0.4, full
    at 5), distinct regions plus commodities found (weight 0.35, full at 8),
    and total keyword mentions (weight 0.25, full at 20).

    Returns:
        The weighted score, clamped to [0.0, 1.0] and rounded to 4 decimals.
    """

    def saturating(value: float, full_at: float) -> float:
        # Linear ramp that tops out at 1.0 once ``value`` reaches ``full_at``.
        return min(value / full_at, 1.0)

    filing_factor = saturating(num_filings, 5.0)
    richness_factor = saturating(num_regions + num_commodities, 8.0)
    volume_factor = saturating(total_mentions, 20.0)

    score = 0.4 * filing_factor + 0.35 * richness_factor + 0.25 * volume_factor
    bounded = max(0.0, min(1.0, score))
    return round(bounded, 4)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main inference function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _accumulate_mentions(
    text: str,
    region_counts: dict[str, int],
    commodity_counts: dict[str, int],
) -> None:
    """Add the region and commodity mentions found in ``text`` to the running totals."""
    for region, count in _extract_regions_from_text(text).items():
        region_counts[region] += count
    for commodity, count in _extract_commodities_from_text(text).items():
        commodity_counts[commodity] += count


def infer_exposure_profile(
    document_intelligences: list[DocumentIntelligence],
    sector: str,
    industry: str,
    market_cap_bucket: str,
) -> ExposureProfileSchema:
    """Infer a baseline exposure profile from filing extractions.

    Scans the summary, per-company key facts / evidence spans, and macro
    themes of each filing or transcript for region and commodity keywords,
    then derives a profile from the mention counts:

    * geographic revenue mix — mention shares per region;
    * supply chain regions — regions near supply-chain keywords plus every
      region in the revenue mix, in sorted order;
    * regulatory jurisdictions — the three regions with the largest share;
    * export dependency — one minus the largest regional share.

    Falls back to the sector-based default profile from
    ``build_default_profile`` when no usable documents or no keyword hits
    are available.

    Args:
        document_intelligences: List of DocumentIntelligence from recent filings.
        sector: Company's GICS sector name.
        industry: Company's industry name.
        market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap.

    Returns:
        An ExposureProfileSchema with source='inferred', or the default
        profile when inference is not possible.

    Requirements: 9.1, 9.2, 9.3
    """
    # Only filings and transcripts are considered.
    filings = [
        di for di in document_intelligences
        if di.document_type.value in ("filing", "transcript")
    ]

    if len(filings) < _MIN_FILINGS_FOR_INFERENCE:
        logger.info(
            "Insufficient filing data (%d filings) for inference, "
            "falling back to sector-based default profile",
            len(filings),
        )
        return build_default_profile(sector, industry, market_cap_bucket)

    # Aggregate region and commodity mentions across all filings
    all_region_counts: dict[str, int] = defaultdict(int)
    all_commodity_counts: dict[str, int] = defaultdict(int)
    all_supply_regions: set[str] = set()

    for filing in filings:
        if filing.summary:
            _accumulate_mentions(filing.summary, all_region_counts, all_commodity_counts)
            all_supply_regions.update(_extract_supply_chain_regions(filing.summary))

        # Key facts and evidence spans of each company impact are scanned too.
        for company in filing.companies:
            for text in company.key_facts + company.evidence_spans:
                _accumulate_mentions(text, all_region_counts, all_commodity_counts)
                all_supply_regions.update(_extract_supply_chain_regions(text))

        # Macro themes feed the counts but are not scanned for supply-chain
        # context.
        for theme in filing.macro_themes:
            _accumulate_mentions(theme, all_region_counts, all_commodity_counts)

    total_mentions = sum(all_region_counts.values()) + sum(all_commodity_counts.values())

    if not all_region_counts and not all_commodity_counts:
        logger.info(
            "No geographic or commodity data found in %d filings, "
            "falling back to sector-based default profile",
            len(filings),
        )
        return build_default_profile(sector, industry, market_cap_bucket)

    geographic_revenue_mix = _estimate_revenue_mix(dict(all_region_counts))

    key_commodities = [
        com for com, count in all_commodity_counts.items()
        if count >= _MIN_COMMODITY_MENTIONS
    ]

    # Sorted so the stored profile is identical across runs; plain set
    # iteration order for strings changes with hash randomization.
    supply_chain_regions = sorted(all_supply_regions | set(geographic_revenue_mix))

    # Market position tier from market cap bucket
    from services.aggregation.interpolation import _CAP_TO_TIER
    tier_value = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value)

    # Regulatory jurisdictions: top regions by revenue share
    sorted_regions = sorted(
        geographic_revenue_mix.items(), key=lambda x: x[1], reverse=True,
    )
    regulatory_jurisdictions = [r for r, _ in sorted_regions[:3]]

    # Export dependency: fraction of revenue outside the top region
    if geographic_revenue_mix:
        top_region_pct = max(geographic_revenue_mix.values())
        export_pct = round(1.0 - top_region_pct, 4)
    else:
        export_pct = 0.0

    confidence = _compute_inference_confidence(
        num_filings=len(filings),
        num_regions=len(all_region_counts),
        num_commodities=len(all_commodity_counts),
        total_mentions=total_mentions,
    )

    profile = ExposureProfileSchema(
        company_id="",
        geographic_revenue_mix=geographic_revenue_mix,
        supply_chain_regions=supply_chain_regions,
        key_input_commodities=key_commodities,
        regulatory_jurisdictions=regulatory_jurisdictions,
        market_position_tier=MarketPositionTier(tier_value),
        export_dependency_pct=max(0.0, min(1.0, export_pct)),
        source="inferred",
        confidence=confidence,
        version=1,
    )

    logger.info(
        "Inferred exposure profile: regions=%d, commodities=%d, "
        "supply_chain=%d, confidence=%.3f",
        len(geographic_revenue_mix),
        len(key_commodities),
        len(supply_chain_regions),
        confidence,
    )

    return profile
|
||||
+234
-4
@@ -9,13 +9,21 @@ import asyncpg
|
||||
import redis.asyncio as aioredis
|
||||
from minio import Minio
|
||||
|
||||
from services.aggregation.interpolation import (
|
||||
build_default_profile,
|
||||
compute_macro_impact_with_sector,
|
||||
filter_low_confidence_events,
|
||||
persist_macro_impact_records,
|
||||
)
|
||||
from services.extractor.client import OllamaClient
|
||||
from services.extractor.event_classifier import classify_global_event
|
||||
from services.extractor.worker import persist_extraction
|
||||
from services.shared.config import load_config
|
||||
from services.shared.logging import inject_trace_context, setup_logging
|
||||
from services.shared.redis_keys import (
|
||||
QUEUE_AGGREGATION,
|
||||
QUEUE_EXTRACTION,
|
||||
QUEUE_MACRO_CLASSIFICATION,
|
||||
queue_key,
|
||||
)
|
||||
|
||||
@@ -28,6 +36,198 @@ async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]:
|
||||
return {row["ticker"]: str(row["id"]) for row in rows}
|
||||
|
||||
|
||||
async def _fetch_document_type(pool: asyncpg.Pool, document_id: str) -> str | None:
    """Look up the ``document_type`` column for a single document.

    Args:
        pool: Connection pool used to run the query.
        document_id: Document identifier; cast to ``uuid`` inside the query.

    Returns:
        The stored ``document_type``, or ``None`` when no row matches.
    """
    query = "SELECT document_type FROM documents WHERE id = $1::uuid"
    record = await pool.fetchrow(query, document_id)
    if not record:
        return None
    return record["document_type"]
|
||||
|
||||
|
||||
async def _fetch_company_info(pool: asyncpg.Pool) -> list[dict]:
    """Return the attributes of every active company as plain dicts.

    Each dict carries the selected columns: ``id``, ``ticker``, ``sector``,
    ``industry`` and ``market_cap_bucket``. Only rows with ``active = TRUE``
    are included.
    """
    records = await pool.fetch(
        """SELECT id, ticker, sector, industry, market_cap_bucket
           FROM companies WHERE active = TRUE"""
    )
    # Convert each fetched row into a plain dict for the callers.
    return list(map(dict, records))
|
||||
|
||||
|
||||
async def _load_exposure_profile(pool: asyncpg.Pool, company_id: str, sector: str, industry: str, market_cap_bucket: str):
    """Load exposure profile for a company: manual > inferred > default.

    Reads the highest-version active row from ``exposure_profiles`` for the
    company. When no such row exists, a sector-based default profile is
    built and stamped with ``company_id``.

    Args:
        pool: Connection pool used to run the query.
        company_id: Identifier matched against ``exposure_profiles.company_id``.
        sector: Company sector, used only for the default-profile fallback.
        industry: Company industry, used only for the default-profile fallback.
        market_cap_bucket: Market-cap bucket for the fallback; empty/None
            falls back to ``"small_cap"``.

    Returns:
        An ``ExposureProfileSchema`` built from the stored row, or the
        default profile when the company has no active stored profile.

    Requirements: 4.1
    """
    from services.shared.schemas import ExposureProfileSchema, MarketPositionTier

    # NOTE(review): the query takes the highest active version regardless of
    # ``source``; the manual-over-inferred precedence is presumably enforced
    # by whatever writes/activates rows -- confirm.
    row = await pool.fetchrow(
        """SELECT company_id, geographic_revenue_mix, supply_chain_regions,
                  key_input_commodities, regulatory_jurisdictions, market_position_tier,
                  export_dependency_pct, source, confidence, version
           FROM exposure_profiles
           WHERE company_id = $1 AND active = TRUE
           ORDER BY version DESC LIMIT 1""",
        company_id,
    )
    if row:
        # The revenue mix may arrive as a JSON-encoded string; decode it.
        geo_mix = row["geographic_revenue_mix"]
        if isinstance(geo_mix, str):
            geo_mix = json.loads(geo_mix)

        # Unknown or missing tier values degrade to REGIONAL instead of failing.
        try:
            tier = MarketPositionTier(row["market_position_tier"])
        except ValueError:
            tier = MarketPositionTier.REGIONAL

        # Only a missing confidence defaults to 1.0. A stored 0.0 must be
        # kept as-is: ``value or 1.0`` would silently promote a
        # zero-confidence profile to full confidence.
        stored_confidence = row["confidence"]
        confidence = 1.0 if stored_confidence is None else float(stored_confidence)

        return ExposureProfileSchema(
            company_id=str(row["company_id"]),
            geographic_revenue_mix=geo_mix or {},
            supply_chain_regions=list(row["supply_chain_regions"] or []),
            key_input_commodities=list(row["key_input_commodities"] or []),
            regulatory_jurisdictions=list(row["regulatory_jurisdictions"] or []),
            market_position_tier=tier,
            export_dependency_pct=float(row["export_dependency_pct"] or 0.0),
            source=row["source"] or "manual",
            confidence=confidence,
            version=row["version"] or 1,
        )

    # Fall back to default profile
    profile = build_default_profile(sector or "", industry or "", market_cap_bucket or "small_cap")
    profile.company_id = str(company_id)
    return profile
|
||||
|
||||
|
||||
async def _compute_and_persist_macro_impacts(
    pool: asyncpg.Pool,
    event,
    companies: list[dict],
    confidence_threshold: float = 0.4,
) -> list[str]:
    """Score one macro event against every company and persist the positive impacts.

    The event is first passed through ``filter_low_confidence_events``; if it
    is dropped, nothing is computed. Otherwise an exposure profile is loaded
    for each company, an impact record is computed, and only records whose
    ``macro_impact_score`` is greater than zero are persisted.

    Args:
        pool: Connection pool used for profile loading and persistence.
        event: Classified macro event (reads ``event_id`` and ``confidence``).
        companies: Company dicts with ``id`` and ``ticker`` and optionally
            ``sector``, ``industry`` and ``market_cap_bucket``.
        confidence_threshold: Threshold handed to the low-confidence filter.

    Returns:
        Tickers of the companies whose records were persisted; an empty list
        when the event was filtered out or no record scored above zero.

    Requirements: 4.1, 4.5
    """
    # Guard: drop events the confidence filter rejects.
    if not filter_low_confidence_events([event], confidence_threshold):
        logger.info("Event %s excluded: confidence %.3f below threshold %.3f",
                    event.event_id, event.confidence, confidence_threshold)
        return []

    impacted = []
    for info in companies:
        cid = str(info["id"])
        symbol = info["ticker"]
        company_sector = info.get("sector") or ""
        company_industry = info.get("industry") or ""
        cap_bucket = info.get("market_cap_bucket") or "small_cap"

        exposure = await _load_exposure_profile(
            pool, cid, company_sector, company_industry, cap_bucket,
        )

        impact = compute_macro_impact_with_sector(event, exposure, company_sector=company_sector)
        impact.ticker = symbol
        impact.company_id = cid

        # Keep only records with a strictly positive score.
        if impact.macro_impact_score > 0.0:
            impacted.append(impact)

    if not impacted:
        return []

    persisted_ids = await persist_macro_impact_records(pool, impacted)
    logger.info(
        "Persisted %d macro impact records for event %s",
        len(persisted_ids), event.event_id,
    )
    return [rec.ticker for rec in impacted]
|
||||
|
||||
|
||||
# Track consecutive macro classification failures for alerting (Requirement 10.4)
_macro_consecutive_failures = 0
_MACRO_FAILURE_ALERT_THRESHOLD = 3


def _record_macro_failure() -> None:
    """Count one more consecutive macro failure and alert at the threshold.

    Increments the module-level failure counter and emits a CRITICAL log
    once the counter reaches ``_MACRO_FAILURE_ALERT_THRESHOLD``.
    """
    global _macro_consecutive_failures
    _macro_consecutive_failures += 1
    if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD:
        logger.critical(
            "ALERT: Sustained macro classification failures (%d consecutive). "
            "Continuing with company-only signals. Operator action required.",
            _macro_consecutive_failures,
        )


async def _process_macro_classification(
    *,
    pool: asyncpg.Pool,
    minio_client: Minio,
    ollama: OllamaClient,
    redis_client: aioredis.Redis,
    document_id: str,
    text: str,
    company_id_map: dict[str, str],
    confidence_threshold: float = 0.4,
) -> None:
    """Route a macro_event document to event classification, compute interpolation,
    and trigger aggregation for affected tickers.

    Steps: classify the document into an event, reset the consecutive-failure
    counter, compute and persist macro impacts for all active companies, then
    push one aggregation job per distinct affected ticker onto the
    aggregation queue.

    Any exception raised inside the pipeline is logged and counted as a
    failure instead of propagating. Note that failures after the
    classification step are counted as well, because the whole pipeline
    shares one ``try`` block.

    Args:
        pool: Connection pool passed to classification and impact persistence.
        minio_client: Object-store client passed to the classifier.
        ollama: Inference client passed to the classifier.
        redis_client: Redis client used to enqueue aggregation jobs.
        document_id: Identifier of the document being classified.
        text: Normalized document text handed to the classifier.
        company_id_map: Accepted for interface compatibility; not used here.
        confidence_threshold: Threshold forwarded to the impact computation.

    Requirements: 2.1, 2.2, 2.3, 4.1, 4.5, 10.4
    """
    global _macro_consecutive_failures
    agg_queue = queue_key(QUEUE_AGGREGATION)

    try:
        event = await classify_global_event(
            normalized_text=text,
            document_id=document_id,
            ollama_client=ollama,
            pool=pool,
            minio_client=minio_client,
        )
        logger.info(
            "Classified macro event %s for doc %s: severity=%s types=%s",
            event.event_id, document_id, event.severity, event.event_types,
        )

        # Reset failure counter on success
        _macro_consecutive_failures = 0

        # Load all tracked companies and compute macro impacts
        companies = await _fetch_company_info(pool)
        affected_tickers = await _compute_and_persist_macro_impacts(
            pool, event, companies, confidence_threshold,
        )

        # Trigger aggregation once per distinct affected ticker;
        # dict.fromkeys dedupes while keeping first-seen order.
        unique_tickers = list(dict.fromkeys(affected_tickers))
        for ticker in unique_tickers:
            await redis_client.rpush(
                agg_queue,
                json.dumps(inject_trace_context({
                    "ticker": ticker,
                    "macro_event_id": event.event_id,
                })),
            )

        logger.info(
            "Enqueued aggregation jobs for %d affected tickers after macro event %s",
            len(unique_tickers), event.event_id,
        )

    except ValueError as e:
        # Logged without a traceback, unlike the generic branch below.
        logger.error("Macro event classification failed for doc %s: %s", document_id, e)
        _record_macro_failure()
    except Exception:
        logger.exception("Unexpected error classifying macro event for doc %s", document_id)
        _record_macro_failure()
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("extractor", level=config.log_level, json_output=config.json_logs)
|
||||
@@ -42,8 +242,10 @@ async def main() -> None:
|
||||
ollama = OllamaClient(config.ollama)
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_EXTRACTION)
|
||||
macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
|
||||
agg_queue = queue_key(QUEUE_AGGREGATION)
|
||||
logger.info("Extractor worker started, polling %s", queue)
|
||||
confidence_threshold = config.macro.macro_confidence_threshold
|
||||
logger.info("Extractor worker started, polling %s and %s", queue, macro_queue)
|
||||
|
||||
# Pre-load company ID map (refreshed periodically)
|
||||
company_id_map = await _build_company_id_map(pool)
|
||||
@@ -51,7 +253,13 @@ async def main() -> None:
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw = await redis_client.lpop(queue)
|
||||
# Check macro classification queue first (priority)
|
||||
raw = await redis_client.lpop(macro_queue)
|
||||
is_macro_job = raw is not None
|
||||
|
||||
if raw is None:
|
||||
raw = await redis_client.lpop(queue)
|
||||
|
||||
if raw is None:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
@@ -80,13 +288,35 @@ async def main() -> None:
|
||||
except Exception as e:
|
||||
logger.warning("Could not fetch normalized text for doc %s: %s", document_id, e)
|
||||
|
||||
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
|
||||
|
||||
# Refresh company map every 100 jobs
|
||||
refresh_counter += 1
|
||||
if refresh_counter % 100 == 0:
|
||||
company_id_map = await _build_company_id_map(pool)
|
||||
|
||||
# Route macro_event documents to event classification (Requirement 2.1)
|
||||
doc_type = None
|
||||
if is_macro_job:
|
||||
doc_type = "macro_event"
|
||||
else:
|
||||
doc_type = await _fetch_document_type(pool, document_id)
|
||||
|
||||
if doc_type == "macro_event":
|
||||
logger.info("Routing macro_event doc %s to event classifier", document_id)
|
||||
await _process_macro_classification(
|
||||
pool=pool,
|
||||
minio_client=minio_client,
|
||||
ollama=ollama,
|
||||
redis_client=redis_client,
|
||||
document_id=document_id,
|
||||
text=text,
|
||||
company_id_map=company_id_map,
|
||||
confidence_threshold=confidence_threshold,
|
||||
)
|
||||
continue
|
||||
|
||||
# Standard extraction pipeline for non-macro documents
|
||||
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
|
||||
|
||||
try:
|
||||
# Pass all tracked tickers so the model can identify any mentioned companies
|
||||
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
|
||||
|
||||
Reference in New Issue
Block a user