feat: competitive intelligence & historical pattern matching layer

This commit is contained in:
Celes Renata
2026-04-14 19:42:48 +00:00
parent b478022ba3
commit f7a11d14ea
203 changed files with 20155 additions and 97 deletions
+549
View File
@@ -0,0 +1,549 @@
"""Event classifier module for macro news articles.
Classifies global/geopolitical news articles into structured GlobalEvent
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
existing OllamaClient for inference and retry logic.
Persists classification prompts, raw outputs, and final events to MinIO
and PostgreSQL for audit and downstream interpolation.
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
import asyncpg
from minio import Minio
from services.shared.schemas import (
EstimatedDuration,
ImpactType,
ModelMetadata,
SeverityLevel,
)
from services.shared.storage import upload_artifact
logger = logging.getLogger("event_classifier")
PROMPT_VERSION = "event-classification-v1"
SCHEMA_VERSION = "1.0.0"
# Valid enum value sets for normalization
_VALID_IMPACT_TYPES = frozenset(e.value for e in ImpactType)
_VALID_SEVERITY_LEVELS = frozenset(e.value for e in SeverityLevel)
_VALID_DURATIONS = frozenset(e.value for e in EstimatedDuration)
# ---------------------------------------------------------------------------
# GlobalEvent dataclass
# ---------------------------------------------------------------------------
@dataclass
class GlobalEvent:
"""Structured classification of a macro news event.
Produced by the event classifier from Ollama structured output.
"""
event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
event_types: list[str] = field(default_factory=list)
severity: str = "low"
affected_regions: list[str] = field(default_factory=list)
affected_sectors: list[str] = field(default_factory=list)
affected_commodities: list[str] = field(default_factory=list)
summary: str = ""
key_facts: list[str] = field(default_factory=list)
estimated_duration: str = "short_term"
confidence: float = 0.5
source_document_id: str = ""
model_metadata: ModelMetadata = field(default_factory=ModelMetadata)
# ---------------------------------------------------------------------------
# JSON schema for Ollama structured output
# ---------------------------------------------------------------------------
class _EventClassificationResult:
"""Schema definition for the Ollama event classification response.
Not a Pydantic model — we build the JSON schema dict directly to keep
it self-contained and Ollama-friendly (no $refs).
"""
pass
def get_event_json_schema() -> dict[str, Any]:
"""Return the JSON schema for Ollama structured event classification output.
The schema forces the model to produce all required fields explicitly.
"""
return {
"type": "object",
"required": [
"event_types",
"severity",
"affected_regions",
"affected_sectors",
"affected_commodities",
"summary",
"key_facts",
"estimated_duration",
"confidence",
],
"properties": {
"event_types": {
"type": "array",
"items": {
"type": "string",
"enum": sorted(_VALID_IMPACT_TYPES),
},
"description": (
"One or more impact types this event represents. "
"Include ALL applicable types — do not collapse to a single category."
),
},
"severity": {
"type": "string",
"enum": sorted(_VALID_SEVERITY_LEVELS),
"description": "Overall severity of the event: low, moderate, high, or critical.",
},
"affected_regions": {
"type": "array",
"items": {"type": "string"},
"description": (
"ISO 3166-1 alpha-2 country codes or region names affected. "
"Use standard codes like US, CN, EU, GB, JP. "
"Only include regions explicitly mentioned or clearly implied."
),
},
"affected_sectors": {
"type": "array",
"items": {"type": "string"},
"description": (
"GICS sector identifiers or sector names affected. "
"Examples: Energy, Materials, Industrials, Consumer Discretionary, "
"Consumer Staples, Health Care, Financials, Information Technology, "
"Communication Services, Utilities, Real Estate."
),
},
"affected_commodities": {
"type": "array",
"items": {"type": "string"},
"description": (
"Commodity identifiers affected, if applicable. "
"Examples: crude_oil, natural_gas, gold, copper, wheat, lithium, "
"semiconductors. Empty list if no commodities are directly affected."
),
},
"summary": {
"type": "string",
"description": "A concise 1-3 sentence summary of the event and its market implications.",
},
"key_facts": {
"type": "array",
"items": {"type": "string"},
"description": (
"Key facts explicitly stated in the article. "
"Do NOT infer, speculate, or fabricate facts. "
"Each fact must be directly supported by the text."
),
},
"estimated_duration": {
"type": "string",
"enum": sorted(_VALID_DURATIONS),
"description": (
"Expected duration of market impact: "
"short_term (days to weeks), medium_term (weeks to months), "
"long_term (months to years)."
),
},
"confidence": {
"type": "number",
"minimum": 0.0,
"maximum": 1.0,
"description": (
"Your confidence in this classification. "
"Lower if the article is ambiguous, speculative, or lacks concrete details."
),
},
},
"additionalProperties": False,
}
# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
_SYSTEM_PROMPT = """\
You classify global news articles into structured macro event intelligence. \
Read the article carefully and extract the event classification. \
Return ONLY valid JSON matching the schema. No commentary, no markdown, no explanation."""
_ANTI_HALLUCINATION_RULES = """\
CRITICAL RULES — read carefully:
1. Only extract information EXPLICITLY stated in the article text.
2. Do NOT infer, speculate, or fabricate facts, regions, sectors, or commodities.
3. If the article mentions multiple distinct impact types, include ALL of them in event_types.
4. For affected_regions, only include regions explicitly mentioned or clearly implied by the event.
5. For affected_sectors, only include sectors with a clear causal link to the event.
6. For affected_commodities, only include commodities directly referenced or obviously impacted.
7. For key_facts, each fact must be directly supported by a specific passage in the text.
8. If the article is vague or speculative, set confidence LOW (below 0.4).
9. Do NOT treat journalist speculation or opinion as confirmed fact.
10. Distinguish between announced policy and proposed/rumored policy."""
def build_event_classification_prompt(text: str) -> dict[str, str]:
"""Build system and user prompts for Ollama event classification.
Args:
text: Normalized text content of the macro news article.
Returns:
Dict with 'system' and 'user' prompt strings.
"""
user_prompt = f"""\
Classify this global news article as a macro event. Fill every field.
{_ANTI_HALLUCINATION_RULES}
Classify the event by:
- event_types: ALL applicable impact types (supply_disruption, demand_shift, cost_increase, \
regulatory_pressure, currency_impact, commodity_shock, trade_barrier, geopolitical_risk)
- severity: low, moderate, high, or critical
- affected_regions: ISO country codes or region names
- affected_sectors: GICS sector names
- affected_commodities: commodity identifiers (empty list if none)
- summary: 1-3 sentence summary of the event and market implications
- key_facts: facts explicitly stated in the text (NO fabrication)
- estimated_duration: short_term, medium_term, or long_term
- confidence: 0.0-1.0 your confidence in this classification
--- ARTICLE TEXT ---
{text}
--- END ARTICLE TEXT ---"""
return {
"system": _SYSTEM_PROMPT,
"user": user_prompt,
}
# ---------------------------------------------------------------------------
# Classification response parsing and normalization
# ---------------------------------------------------------------------------
def _normalize_event_types(raw: list[Any]) -> list[str]:
"""Normalize and filter event_types to valid ImpactType values."""
result = []
for item in raw:
val = str(item).lower().strip()
if val in _VALID_IMPACT_TYPES:
result.append(val)
return result if result else ["geopolitical_risk"]
def _normalize_severity(raw: str) -> str:
"""Normalize severity to a valid SeverityLevel value."""
val = str(raw).lower().strip()
return val if val in _VALID_SEVERITY_LEVELS else "low"
def _normalize_duration(raw: str) -> str:
"""Normalize estimated_duration to a valid EstimatedDuration value."""
val = str(raw).lower().strip()
return val if val in _VALID_DURATIONS else "short_term"
def _parse_classification_response(
raw_json: str,
document_id: str,
model_name: str,
) -> GlobalEvent:
"""Parse raw Ollama JSON output into a GlobalEvent.
Normalizes enum values and clamps numeric fields.
"""
data = json.loads(raw_json)
confidence = data.get("confidence", 0.5)
if isinstance(confidence, (int, float)):
confidence = max(0.0, min(1.0, float(confidence)))
else:
confidence = 0.5
return GlobalEvent(
event_id=str(uuid.uuid4()),
event_types=_normalize_event_types(data.get("event_types", [])),
severity=_normalize_severity(data.get("severity", "low")),
affected_regions=[str(r) for r in data.get("affected_regions", [])],
affected_sectors=[str(s) for s in data.get("affected_sectors", [])],
affected_commodities=[str(c) for c in data.get("affected_commodities", [])],
summary=str(data.get("summary", "")),
key_facts=[str(f) for f in data.get("key_facts", [])],
estimated_duration=_normalize_duration(data.get("estimated_duration", "short_term")),
confidence=confidence,
source_document_id=document_id,
model_metadata=ModelMetadata(
provider="ollama",
model_name=model_name,
prompt_version=PROMPT_VERSION,
schema_version=SCHEMA_VERSION,
),
)
# ---------------------------------------------------------------------------
# MinIO persistence helpers
# ---------------------------------------------------------------------------
def _upload_classification_prompt(
minio_client: Minio,
document_id: str,
prompt_data: dict[str, str],
model_name: str,
timestamp: datetime | None = None,
) -> str:
"""Upload classification prompt and metadata to stonks-llm-prompts."""
ts = timestamp or datetime.now(timezone.utc)
payload = json.dumps({
"prompt_version": PROMPT_VERSION,
"schema_version": SCHEMA_VERSION,
"model": model_name,
"system_prompt": prompt_data["system"],
"user_prompt": prompt_data["user"],
"json_schema": get_event_json_schema(),
}, indent=2).encode()
path = (
f"event_classification/macro/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
f"{document_id}/prompt.json"
)
return upload_artifact(
minio_client, "stonks-llm-prompts", path, payload,
content_type="application/json",
)
def _upload_classification_result(
minio_client: Minio,
document_id: str,
raw_output: str,
event: GlobalEvent | None,
success: bool,
error: str | None,
timestamp: datetime | None = None,
) -> str:
"""Upload raw classification output to stonks-llm-results."""
ts = timestamp or datetime.now(timezone.utc)
payload = json.dumps({
"document_id": document_id,
"success": success,
"error": error,
"raw_output": raw_output,
"parsed_event": {
"event_id": event.event_id,
"event_types": event.event_types,
"severity": event.severity,
"affected_regions": event.affected_regions,
"affected_sectors": event.affected_sectors,
"affected_commodities": event.affected_commodities,
"summary": event.summary,
"key_facts": event.key_facts,
"estimated_duration": event.estimated_duration,
"confidence": event.confidence,
} if event else None,
}, indent=2).encode()
path = (
f"event_classification/macro/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
f"{document_id}/result.json"
)
return upload_artifact(
minio_client, "stonks-llm-results", path, payload,
content_type="application/json",
)
# ---------------------------------------------------------------------------
# PostgreSQL persistence
# ---------------------------------------------------------------------------
async def persist_global_event(
pool: asyncpg.Pool,
event: GlobalEvent,
) -> str:
"""Persist a GlobalEvent record to the global_events PostgreSQL table.
Returns the event row UUID.
"""
row_id = await pool.fetchval(
"""INSERT INTO global_events
(id, event_types, severity, affected_regions, affected_sectors,
affected_commodities, summary, key_facts, estimated_duration,
confidence, source_document_id, model_provider, model_name,
prompt_version, schema_version)
VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10,
$11::uuid, $12, $13, $14, $15)
RETURNING id""",
event.event_id,
event.event_types,
event.severity,
event.affected_regions,
event.affected_sectors,
event.affected_commodities,
event.summary,
json.dumps(event.key_facts),
event.estimated_duration,
event.confidence,
event.source_document_id,
event.model_metadata.provider,
event.model_metadata.model_name,
event.model_metadata.prompt_version,
event.model_metadata.schema_version,
)
logger.info(
"Persisted global event %s for doc %s (severity=%s, types=%s)",
row_id, event.source_document_id, event.severity, event.event_types,
)
return str(row_id)
# ---------------------------------------------------------------------------
# Main classification function
# ---------------------------------------------------------------------------
async def classify_global_event(
normalized_text: str,
document_id: str,
ollama_client: Any,
*,
pool: asyncpg.Pool | None = None,
minio_client: Minio | None = None,
) -> GlobalEvent:
"""Classify a macro news article into a GlobalEvent using Ollama.
Uses the existing OllamaClient's streaming infrastructure with a
dedicated event classification prompt and JSON schema. Follows the
same retry policy as document extraction.
Persists prompt, raw output, and final event to MinIO and PostgreSQL
when the respective clients are provided.
Args:
normalized_text: Cleaned text content of the macro article.
document_id: UUID of the source document.
ollama_client: An OllamaClient instance (from services.extractor.client).
pool: Optional asyncpg pool for PostgreSQL persistence.
minio_client: Optional MinIO client for artifact persistence.
Returns:
A GlobalEvent with the classification result.
Raises:
ValueError: If classification fails after all retries.
"""
ts = datetime.now(timezone.utc)
prompts = build_event_classification_prompt(normalized_text)
json_schema = get_event_json_schema()
model_name = ollama_client._config.model
# Persist prompt to MinIO
prompt_ref = None
if minio_client:
try:
prompt_ref = _upload_classification_prompt(
minio_client, document_id, prompts, model_name, timestamp=ts,
)
except Exception:
logger.exception("Failed to upload classification prompt for doc %s", document_id)
# Call Ollama using the client's internal _call_ollama method
# We reuse the retry logic pattern from OllamaClient.extract()
max_retries = ollama_client._max_retries
last_error: str | None = None
raw_output = ""
for attempt_num in range(max_retries + 1):
attempt = await ollama_client._call_ollama(prompts, json_schema)
raw_output = attempt.raw_output
if attempt.error is None and raw_output:
# Try to parse the response
try:
event = _parse_classification_response(
raw_output, document_id, model_name,
)
# Persist result to MinIO
if minio_client:
try:
_upload_classification_result(
minio_client, document_id, raw_output,
event, success=True, error=None, timestamp=ts,
)
except Exception:
logger.exception(
"Failed to upload classification result for doc %s", document_id,
)
# Persist to PostgreSQL
if pool:
try:
await persist_global_event(pool, event)
except Exception:
logger.exception(
"Failed to persist global event for doc %s", document_id,
)
return event
except (json.JSONDecodeError, KeyError, TypeError) as exc:
last_error = f"parse_error: {exc}"
logger.warning(
"Classification parse error for doc %s attempt %d: %s",
document_id, attempt_num + 1, exc,
)
else:
last_error = attempt.error or "empty_response"
# Retry with backoff
if attempt_num < max_retries:
delay = ollama_client._base_delay * (
ollama_client._backoff_multiplier ** attempt_num
)
delay = min(delay, ollama_client._max_delay)
logger.warning(
"Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
attempt_num + 1, max_retries + 1, document_id, last_error, delay,
)
await asyncio.sleep(delay)
# All retries exhausted — persist failure and raise
if minio_client:
try:
_upload_classification_result(
minio_client, document_id, raw_output,
event=None, success=False, error=last_error, timestamp=ts,
)
except Exception:
logger.exception(
"Failed to upload failed classification result for doc %s", document_id,
)
raise ValueError(
f"Event classification failed for document {document_id} "
f"after {max_retries + 1} attempts: {last_error}"
)
+394
View File
@@ -0,0 +1,394 @@
"""Exposure profile auto-inference from filing extractions.
Infers baseline exposure profiles from company filing extractions when
no manual profile exists. Scans recent filing extractions for geographic
revenue breakdowns, supplier mentions, and commodity references.
Requirements: 9.1, 9.2, 9.3
"""
from __future__ import annotations
import logging
import re
from collections import defaultdict
from services.aggregation.interpolation import build_default_profile
from services.shared.schemas import (
DocumentIntelligence,
ExposureProfileSchema,
MarketPositionTier,
)
logger = logging.getLogger("exposure_inference")
# ---------------------------------------------------------------------------
# Known region patterns for geographic extraction
# ---------------------------------------------------------------------------
_REGION_KEYWORDS: dict[str, str] = {
"united states": "US",
"u.s.": "US",
"us": "US",
"america": "US",
"north america": "US",
"china": "CN",
"chinese": "CN",
"europe": "EU",
"european": "EU",
"eu": "EU",
"japan": "JP",
"japanese": "JP",
"germany": "DE",
"german": "DE",
"united kingdom": "GB",
"uk": "GB",
"britain": "GB",
"british": "GB",
"south korea": "KR",
"korea": "KR",
"india": "IN",
"indian": "IN",
"brazil": "BR",
"brazilian": "BR",
"australia": "AU",
"australian": "AU",
"canada": "CA",
"canadian": "CA",
"taiwan": "TW",
"saudi arabia": "SA",
"russia": "RU",
"russian": "RU",
"mexico": "MX",
"singapore": "SG",
"asia": "CN",
"asia pacific": "CN",
"latin america": "BR",
"middle east": "SA",
}
# ---------------------------------------------------------------------------
# Known commodity patterns
# ---------------------------------------------------------------------------
_COMMODITY_KEYWORDS: dict[str, str] = {
"crude oil": "crude_oil",
"oil": "crude_oil",
"petroleum": "crude_oil",
"natural gas": "natural_gas",
"gas": "natural_gas",
"copper": "copper",
"steel": "steel",
"lithium": "lithium",
"semiconductor": "semiconductors",
"semiconductors": "semiconductors",
"chip": "semiconductors",
"chips": "semiconductors",
"wheat": "wheat",
"corn": "corn",
"gold": "gold",
"aluminum": "aluminum",
"aluminium": "aluminum",
"nickel": "nickel",
"cobalt": "cobalt",
"rare earth": "rare_earth",
}
# Minimum number of filing documents to consider inference meaningful
_MIN_FILINGS_FOR_INFERENCE = 1
# Minimum total mentions to consider a region significant
_MIN_REGION_MENTIONS = 1
# Minimum total mentions to consider a commodity significant
_MIN_COMMODITY_MENTIONS = 1
# ---------------------------------------------------------------------------
# Text scanning helpers
# ---------------------------------------------------------------------------
def _extract_regions_from_text(text: str) -> dict[str, int]:
"""Extract region mentions from text, returning region_code -> count."""
text_lower = text.lower()
region_counts: dict[str, int] = defaultdict(int)
for keyword, code in _REGION_KEYWORDS.items():
# Use word boundary matching for short keywords
if len(keyword) <= 3:
pattern = rf"\b{re.escape(keyword)}\b"
matches = re.findall(pattern, text_lower)
else:
matches = re.findall(re.escape(keyword), text_lower)
if matches:
region_counts[code] += len(matches)
return dict(region_counts)
def _extract_commodities_from_text(text: str) -> dict[str, int]:
"""Extract commodity mentions from text, returning commodity_id -> count."""
text_lower = text.lower()
commodity_counts: dict[str, int] = defaultdict(int)
for keyword, commodity_id in _COMMODITY_KEYWORDS.items():
if len(keyword) <= 4:
pattern = rf"\b{re.escape(keyword)}\b"
matches = re.findall(pattern, text_lower)
else:
matches = re.findall(re.escape(keyword), text_lower)
if matches:
commodity_counts[commodity_id] += len(matches)
return dict(commodity_counts)
def _extract_supply_chain_regions(text: str) -> set[str]:
"""Extract supply chain region mentions from text."""
supply_keywords = [
"supplier", "supply chain", "sourcing", "manufacturing",
"factory", "plant", "warehouse", "distribution",
"import", "export", "procurement",
]
text_lower = text.lower()
regions: set[str] = set()
for keyword in supply_keywords:
if keyword in text_lower:
# Find regions mentioned near supply chain keywords
# Look within a window around each occurrence
for match in re.finditer(re.escape(keyword), text_lower):
start = max(0, match.start() - 200)
end = min(len(text_lower), match.end() + 200)
window = text_lower[start:end]
window_regions = _extract_regions_from_text(window)
regions.update(window_regions.keys())
return regions
# ---------------------------------------------------------------------------
# Revenue mix estimation
# ---------------------------------------------------------------------------
def _estimate_revenue_mix(region_counts: dict[str, int]) -> dict[str, float]:
"""Estimate geographic revenue mix from region mention counts.
Uses mention frequency as a proxy for revenue distribution.
Normalizes to sum to 1.0.
"""
if not region_counts:
return {}
total = sum(region_counts.values())
if total == 0:
return {}
mix = {
region: round(count / total, 4)
for region, count in region_counts.items()
if count >= _MIN_REGION_MENTIONS
}
# Re-normalize after filtering
mix_total = sum(mix.values())
if mix_total > 0 and abs(mix_total - 1.0) > 0.001:
mix = {r: round(v / mix_total, 4) for r, v in mix.items()}
return mix
# ---------------------------------------------------------------------------
# Confidence scoring
# ---------------------------------------------------------------------------
def _compute_inference_confidence(
num_filings: int,
num_regions: int,
num_commodities: int,
total_mentions: int,
) -> float:
"""Compute confidence score for the inferred profile.
Higher confidence when more filings are available and more
geographic/commodity data points are found.
"""
# Base confidence from number of filings (more filings = more reliable)
filing_factor = min(num_filings / 5.0, 1.0) # saturates at 5 filings
# Data richness factor
data_points = num_regions + num_commodities
richness_factor = min(data_points / 8.0, 1.0) # saturates at 8 data points
# Mention volume factor
volume_factor = min(total_mentions / 20.0, 1.0) # saturates at 20 mentions
confidence = 0.4 * filing_factor + 0.35 * richness_factor + 0.25 * volume_factor
return round(max(0.0, min(1.0, confidence)), 4)
# ---------------------------------------------------------------------------
# Main inference function
# ---------------------------------------------------------------------------
def infer_exposure_profile(
document_intelligences: list[DocumentIntelligence],
sector: str,
industry: str,
market_cap_bucket: str,
) -> ExposureProfileSchema:
"""Infer a baseline exposure profile from filing extractions.
Scans recent filing extractions for geographic revenue breakdowns,
supplier mentions, and commodity references. Produces an
ExposureProfile with source='inferred' and a confidence score
reflecting data quality.
Falls back to sector-based default profile when insufficient
filing data is available.
Args:
document_intelligences: List of DocumentIntelligence from recent filings.
sector: Company's GICS sector name.
industry: Company's industry name.
market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap.
Returns:
An ExposureProfileSchema with source='inferred'.
Requirements: 9.1, 9.2, 9.3
"""
# Filter to filing-type documents
filings = [
di for di in document_intelligences
if di.document_type.value in ("filing", "transcript")
]
if len(filings) < _MIN_FILINGS_FOR_INFERENCE:
logger.info(
"Insufficient filing data (%d filings) for inference, "
"falling back to sector-based default profile",
len(filings),
)
return build_default_profile(sector, industry, market_cap_bucket)
# Aggregate region and commodity mentions across all filings
all_region_counts: dict[str, int] = defaultdict(int)
all_commodity_counts: dict[str, int] = defaultdict(int)
all_supply_regions: set[str] = set()
for filing in filings:
# Scan summary text
if filing.summary:
regions = _extract_regions_from_text(filing.summary)
for r, c in regions.items():
all_region_counts[r] += c
commodities = _extract_commodities_from_text(filing.summary)
for com, c in commodities.items():
all_commodity_counts[com] += c
supply_regions = _extract_supply_chain_regions(filing.summary)
all_supply_regions.update(supply_regions)
# Scan company impacts for geographic and commodity mentions
for company in filing.companies:
# Key facts and evidence spans contain geographic details
for text in company.key_facts + company.evidence_spans:
regions = _extract_regions_from_text(text)
for r, c in regions.items():
all_region_counts[r] += c
commodities = _extract_commodities_from_text(text)
for com, c in commodities.items():
all_commodity_counts[com] += c
supply_regions = _extract_supply_chain_regions(text)
all_supply_regions.update(supply_regions)
# Scan macro themes for commodity/region hints
for theme in filing.macro_themes:
regions = _extract_regions_from_text(theme)
for r, c in regions.items():
all_region_counts[r] += c
commodities = _extract_commodities_from_text(theme)
for com, c in commodities.items():
all_commodity_counts[com] += c
# Check if we have enough data to infer
total_mentions = sum(all_region_counts.values()) + sum(all_commodity_counts.values())
has_regions = len(all_region_counts) > 0
has_commodities = len(all_commodity_counts) > 0
if not has_regions and not has_commodities:
logger.info(
"No geographic or commodity data found in %d filings, "
"falling back to sector-based default profile",
len(filings),
)
return build_default_profile(sector, industry, market_cap_bucket)
# Build the inferred profile
geographic_revenue_mix = _estimate_revenue_mix(dict(all_region_counts))
# Filter commodities by minimum mentions
key_commodities = [
com for com, count in all_commodity_counts.items()
if count >= _MIN_COMMODITY_MENTIONS
]
# Supply chain regions: combine extracted supply regions with geo regions
supply_chain_regions = list(all_supply_regions | set(geographic_revenue_mix.keys()))
# Market position tier from market cap bucket
from services.aggregation.interpolation import _CAP_TO_TIER
tier_value = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value)
# Regulatory jurisdictions: top regions by revenue
sorted_regions = sorted(
geographic_revenue_mix.items(), key=lambda x: x[1], reverse=True,
)
regulatory_jurisdictions = [r for r, _ in sorted_regions[:3]]
# Export dependency: fraction of revenue outside the top region
if geographic_revenue_mix:
top_region_pct = max(geographic_revenue_mix.values())
export_pct = round(1.0 - top_region_pct, 4)
else:
export_pct = 0.0
# Confidence score
confidence = _compute_inference_confidence(
num_filings=len(filings),
num_regions=len(all_region_counts),
num_commodities=len(all_commodity_counts),
total_mentions=total_mentions,
)
profile = ExposureProfileSchema(
company_id="",
geographic_revenue_mix=geographic_revenue_mix,
supply_chain_regions=supply_chain_regions,
key_input_commodities=key_commodities,
regulatory_jurisdictions=regulatory_jurisdictions,
market_position_tier=MarketPositionTier(tier_value),
export_dependency_pct=max(0.0, min(1.0, export_pct)),
source="inferred",
confidence=confidence,
version=1,
)
logger.info(
"Inferred exposure profile: regions=%d, commodities=%d, "
"supply_chain=%d, confidence=%.3f",
len(geographic_revenue_mix),
len(key_commodities),
len(supply_chain_regions),
confidence,
)
return profile
+234 -4
View File
@@ -9,13 +9,21 @@ import asyncpg
import redis.asyncio as aioredis
from minio import Minio
from services.aggregation.interpolation import (
build_default_profile,
compute_macro_impact_with_sector,
filter_low_confidence_events,
persist_macro_impact_records,
)
from services.extractor.client import OllamaClient
from services.extractor.event_classifier import classify_global_event
from services.extractor.worker import persist_extraction
from services.shared.config import load_config
from services.shared.logging import inject_trace_context, setup_logging
from services.shared.redis_keys import (
QUEUE_AGGREGATION,
QUEUE_EXTRACTION,
QUEUE_MACRO_CLASSIFICATION,
queue_key,
)
@@ -28,6 +36,198 @@ async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]:
return {row["ticker"]: str(row["id"]) for row in rows}
async def _fetch_document_type(pool: asyncpg.Pool, document_id: str) -> str | None:
"""Fetch the document_type for a document."""
row = await pool.fetchrow(
"SELECT document_type FROM documents WHERE id = $1::uuid",
document_id,
)
return row["document_type"] if row else None
async def _fetch_company_info(pool: asyncpg.Pool) -> list[dict]:
"""Fetch company info needed for exposure profile loading and interpolation."""
rows = await pool.fetch(
"""SELECT id, ticker, sector, industry, market_cap_bucket
FROM companies WHERE active = TRUE"""
)
return [dict(r) for r in rows]
async def _load_exposure_profile(pool: asyncpg.Pool, company_id: str, sector: str, industry: str, market_cap_bucket: str):
"""Load exposure profile for a company: manual > inferred > default.
Requirements: 4.1
"""
from services.shared.schemas import ExposureProfileSchema, MarketPositionTier
# Try manual or inferred profile from DB
row = await pool.fetchrow(
"""SELECT company_id, geographic_revenue_mix, supply_chain_regions,
key_input_commodities, regulatory_jurisdictions, market_position_tier,
export_dependency_pct, source, confidence, version
FROM exposure_profiles
WHERE company_id = $1 AND active = TRUE
ORDER BY version DESC LIMIT 1""",
company_id,
)
if row:
geo_mix = row["geographic_revenue_mix"]
if isinstance(geo_mix, str):
geo_mix = json.loads(geo_mix)
tier_val = row["market_position_tier"]
try:
tier = MarketPositionTier(tier_val)
except ValueError:
tier = MarketPositionTier.REGIONAL
return ExposureProfileSchema(
company_id=str(row["company_id"]),
geographic_revenue_mix=geo_mix or {},
supply_chain_regions=list(row["supply_chain_regions"] or []),
key_input_commodities=list(row["key_input_commodities"] or []),
regulatory_jurisdictions=list(row["regulatory_jurisdictions"] or []),
market_position_tier=tier,
export_dependency_pct=float(row["export_dependency_pct"] or 0.0),
source=row["source"] or "manual",
confidence=float(row["confidence"] or 1.0),
version=row["version"] or 1,
)
# Fall back to default profile
profile = build_default_profile(sector or "", industry or "", market_cap_bucket or "small_cap")
profile.company_id = str(company_id)
return profile
async def _compute_and_persist_macro_impacts(
pool: asyncpg.Pool,
event,
companies: list[dict],
confidence_threshold: float = 0.4,
) -> list[str]:
"""Compute MacroImpactRecords for all tracked companies and persist non-zero ones.
Requirements: 4.1, 4.5
"""
# Filter low-confidence events
filtered = filter_low_confidence_events([event], confidence_threshold)
if not filtered:
logger.info("Event %s excluded: confidence %.3f below threshold %.3f",
event.event_id, event.confidence, confidence_threshold)
return []
records = []
for company in companies:
company_id = str(company["id"])
ticker = company["ticker"]
sector = company.get("sector") or ""
industry = company.get("industry") or ""
market_cap_bucket = company.get("market_cap_bucket") or "small_cap"
profile = await _load_exposure_profile(pool, company_id, sector, industry, market_cap_bucket)
record = compute_macro_impact_with_sector(event, profile, company_sector=sector)
record.ticker = ticker
record.company_id = company_id
if record.macro_impact_score > 0.0:
records.append(record)
if records:
ids = await persist_macro_impact_records(pool, records)
logger.info(
"Persisted %d macro impact records for event %s",
len(ids), event.event_id,
)
return [r.ticker for r in records]
return []
# Track consecutive macro classification failures for alerting (Requirement 10.4)
_macro_consecutive_failures = 0
_MACRO_FAILURE_ALERT_THRESHOLD = 3
async def _process_macro_classification(
*,
pool: asyncpg.Pool,
minio_client: Minio,
ollama: OllamaClient,
redis_client: aioredis.Redis,
document_id: str,
text: str,
company_id_map: dict[str, str],
confidence_threshold: float = 0.4,
) -> None:
"""Route a macro_event document to event classification, compute interpolation,
and trigger aggregation for affected tickers.
Requirements: 2.1, 2.2, 2.3, 4.1, 4.5, 10.4
"""
global _macro_consecutive_failures
agg_queue = queue_key(QUEUE_AGGREGATION)
try:
event = await classify_global_event(
normalized_text=text,
document_id=document_id,
ollama_client=ollama,
pool=pool,
minio_client=minio_client,
)
logger.info(
"Classified macro event %s for doc %s: severity=%s types=%s",
event.event_id, document_id, event.severity, event.event_types,
)
# Reset failure counter on success
_macro_consecutive_failures = 0
# Load all tracked companies and compute macro impacts
companies = await _fetch_company_info(pool)
affected_tickers = await _compute_and_persist_macro_impacts(
pool, event, companies, confidence_threshold,
)
# Trigger aggregation for affected tickers (those with non-zero impact)
enqueued_tickers = set()
for ticker in affected_tickers:
if ticker not in enqueued_tickers:
await redis_client.rpush(
agg_queue,
json.dumps(inject_trace_context({
"ticker": ticker,
"macro_event_id": event.event_id,
})),
)
enqueued_tickers.add(ticker)
logger.info(
"Enqueued aggregation jobs for %d affected tickers after macro event %s",
len(enqueued_tickers), event.event_id,
)
except ValueError as e:
_macro_consecutive_failures += 1
logger.error("Macro event classification failed for doc %s: %s", document_id, e)
if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD:
logger.critical(
"ALERT: Sustained macro classification failures (%d consecutive). "
"Continuing with company-only signals. Operator action required.",
_macro_consecutive_failures,
)
except Exception:
_macro_consecutive_failures += 1
logger.exception("Unexpected error classifying macro event for doc %s", document_id)
if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD:
logger.critical(
"ALERT: Sustained macro classification failures (%d consecutive). "
"Continuing with company-only signals. Operator action required.",
_macro_consecutive_failures,
)
async def main() -> None:
config = load_config()
setup_logging("extractor", level=config.log_level, json_output=config.json_logs)
@@ -42,8 +242,10 @@ async def main() -> None:
ollama = OllamaClient(config.ollama)
redis_client = aioredis.from_url(config.redis.url)
queue = queue_key(QUEUE_EXTRACTION)
macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
agg_queue = queue_key(QUEUE_AGGREGATION)
logger.info("Extractor worker started, polling %s", queue)
confidence_threshold = config.macro.macro_confidence_threshold
logger.info("Extractor worker started, polling %s and %s", queue, macro_queue)
# Pre-load company ID map (refreshed periodically)
company_id_map = await _build_company_id_map(pool)
@@ -51,7 +253,13 @@ async def main() -> None:
try:
while True:
raw = await redis_client.lpop(queue)
# Check macro classification queue first (priority)
raw = await redis_client.lpop(macro_queue)
is_macro_job = raw is not None
if raw is None:
raw = await redis_client.lpop(queue)
if raw is None:
await asyncio.sleep(1)
continue
@@ -80,13 +288,35 @@ async def main() -> None:
except Exception as e:
logger.warning("Could not fetch normalized text for doc %s: %s", document_id, e)
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
# Refresh company map every 100 jobs
refresh_counter += 1
if refresh_counter % 100 == 0:
company_id_map = await _build_company_id_map(pool)
# Route macro_event documents to event classification (Requirement 2.1)
doc_type = None
if is_macro_job:
doc_type = "macro_event"
else:
doc_type = await _fetch_document_type(pool, document_id)
if doc_type == "macro_event":
logger.info("Routing macro_event doc %s to event classifier", document_id)
await _process_macro_classification(
pool=pool,
minio_client=minio_client,
ollama=ollama,
redis_client=redis_client,
document_id=document_id,
text=text,
company_id_map=company_id_map,
confidence_threshold=confidence_threshold,
)
continue
# Standard extraction pipeline for non-macro documents
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
try:
# Pass all tracked tickers so the model can identify any mentioned companies
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)