Files
stonks-oracle/services/aggregation/interpolation.py
T

742 lines
24 KiB
Python

"""Interpolation engine — macro-to-company impact scoring.
Computes per-company macro impact scores by evaluating overlap between
global event classifications and company exposure profiles. Produces
MacroImpactRecord objects that feed into the aggregation engine as
additional weighted signals.
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 3.2
"""
from __future__ import annotations
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
import asyncpg
from services.extractor.event_classifier import GlobalEvent
from services.shared.schemas import (
ExposureProfileSchema,
MarketPositionTier,
SeverityLevel,
)
logger = logging.getLogger("interpolation")
# ---------------------------------------------------------------------------
# Default configuration constants
# ---------------------------------------------------------------------------
# Events whose confidence is below this value are dropped by
# filter_low_confidence_events (this is its default threshold).
DEFAULT_CONFIDENCE_THRESHOLD = 0.4
# "short_term" events older than this many hours receive accelerated decay
# (default staleness window of apply_accelerated_decay).
DEFAULT_SHORT_TERM_STALENESS_HOURS = 48
ACCELERATED_DECAY_MULTIPLIER = 0.5  # multiplies the standard recency decay for stale short-term events
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Severity weights: SeverityLevel value -> multiplier applied to the weighted
# overlap sum. The scoring functions fall back to 0.25 (the LOW weight) for
# severities missing from this table.
SEVERITY_WEIGHTS: dict[str, float] = {
    SeverityLevel.CRITICAL.value: 1.0,
    SeverityLevel.HIGH.value: 0.75,
    SeverityLevel.MODERATE.value: 0.5,
    SeverityLevel.LOW.value: 0.25,
}
# Component weights in the scoring formula. They sum to 1.0, so the raw score
# equals the severity weight when every overlap component is 1.0.
GEO_WEIGHT = 0.35
SUPPLY_WEIGHT = 0.25
COMMODITY_WEIGHT = 0.25
SECTOR_WEIGHT = 0.15
# Resilience modifiers for international events (events naming more than one
# region). Values below 1 dampen the score, values above 1 amplify it; tiers
# missing from this table use 1.0.
RESILIENCE_MODIFIERS: dict[str, float] = {
    MarketPositionTier.GLOBAL_LEADER.value: 0.7,
    MarketPositionTier.MULTINATIONAL.value: 0.85,
    MarketPositionTier.REGIONAL.value: 1.0,
    MarketPositionTier.DOMESTIC.value: 1.2,
}
# Event types that are typically negative
_NEGATIVE_EVENT_TYPES = frozenset({
    "supply_disruption",
    "cost_increase",
    "regulatory_pressure",
    "geopolitical_risk",
    "trade_barrier",
})
# Event types that can be positive
_POSITIVE_EVENT_TYPES = frozenset({
    "demand_shift",
})
# Event types that can go either way; _determine_impact_direction counts
# them toward both the positive and the negative side.
_AMBIGUOUS_EVENT_TYPES = frozenset({
    "commodity_shock",
    "currency_impact",
})
# Market cap bucket -> market position tier mapping for default profiles.
# build_default_profile falls back to REGIONAL for unknown buckets.
_CAP_TO_TIER: dict[str, str] = {
    "large_cap": MarketPositionTier.GLOBAL_LEADER.value,
    "mid_cap": MarketPositionTier.MULTINATIONAL.value,
    "small_cap": MarketPositionTier.REGIONAL.value,
    "micro_cap": MarketPositionTier.DOMESTIC.value,
}
# Sector-based default geographic revenue mixes, keyed by sector name
# (region code -> revenue share). Every mix sums to 1.0. Region codes are
# mostly two-letter country codes, plus "EU".
_SECTOR_DEFAULT_GEO: dict[str, dict[str, float]] = {
    "Information Technology": {"US": 0.45, "CN": 0.15, "EU": 0.15, "JP": 0.10, "KR": 0.15},
    "Health Care": {"US": 0.50, "EU": 0.25, "JP": 0.10, "CN": 0.15},
    "Financials": {"US": 0.55, "EU": 0.20, "GB": 0.15, "JP": 0.10},
    "Energy": {"US": 0.30, "SA": 0.20, "RU": 0.15, "CA": 0.15, "AE": 0.20},
    "Materials": {"US": 0.25, "CN": 0.25, "AU": 0.20, "BR": 0.15, "IN": 0.15},
    "Industrials": {"US": 0.40, "DE": 0.15, "CN": 0.15, "JP": 0.15, "KR": 0.15},
    "Consumer Discretionary": {"US": 0.45, "CN": 0.20, "EU": 0.15, "JP": 0.10, "IN": 0.10},
    "Consumer Staples": {"US": 0.45, "EU": 0.20, "CN": 0.15, "IN": 0.10, "BR": 0.10},
    "Communication Services": {"US": 0.50, "CN": 0.15, "EU": 0.15, "JP": 0.10, "IN": 0.10},
    "Utilities": {"US": 0.70, "EU": 0.15, "JP": 0.15},
    "Real Estate": {"US": 0.60, "CN": 0.15, "EU": 0.15, "JP": 0.10},
}
# Fallback revenue mix for sectors not listed above (also sums to 1.0).
_DEFAULT_GEO = {"US": 0.50, "EU": 0.20, "CN": 0.15, "JP": 0.15}
# ---------------------------------------------------------------------------
# MacroImpactRecord dataclass
# ---------------------------------------------------------------------------
@dataclass
class MacroImpactRecord:
    """A computed macro impact score for a specific company-event pair.

    The scoring functions in this module return a record with
    ``impact_direction="neutral"``, score 0.0 and confidence 0.0 when the
    event has no overlap with the company's exposure profile.
    """
    event_id: str = ""  # persisted with a ::uuid cast, so must be a UUID string when stored
    company_id: str = ""  # persisted with a ::uuid cast, so must be a UUID string when stored
    ticker: str = ""  # the compute_* functions in this module always leave this empty
    macro_impact_score: float = 0.0  # [0, 1]
    impact_direction: str = "neutral"  # positive|negative|mixed|neutral
    contributing_factors: list[str] = field(default_factory=list)  # "label:value" strings, stored as JSON
    confidence: float = 0.5  # [0, 1]
    computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))  # timezone-aware UTC
# ---------------------------------------------------------------------------
# Overlap computation functions
# ---------------------------------------------------------------------------
def compute_geographic_overlap(
    event_regions: list[str],
    revenue_mix: dict[str, float],
) -> float:
    """Return the share of company revenue earned in regions hit by an event.

    Region codes are compared case-insensitively. Each revenue-mix entry whose
    region is among ``event_regions`` contributes its percentage, and the
    total is clamped into [0, 1].

    Args:
        event_regions: Region codes from the global event.
        revenue_mix: Company's geographic_revenue_mix (region -> pct).

    Returns:
        Clamped sum of matching revenue percentages; 0.0 when either input
        is empty.
    """
    if not (event_regions and revenue_mix):
        return 0.0
    hit_regions = {code.upper() for code in event_regions}
    # Summation starts at 0.0 and follows revenue_mix order, like a manual loop.
    exposed_share = sum(
        (share for code, share in revenue_mix.items() if code.upper() in hit_regions),
        0.0,
    )
    return min(max(exposed_share, 0.0), 1.0)
def compute_supply_chain_overlap(
    event_regions: list[str],
    supply_regions: list[str],
) -> float:
    """Return the fraction of supply-chain regions affected by an event.

    Both region lists are upper-cased and de-duplicated. The result is the
    number of distinct supply regions also named by the event, divided by the
    number of distinct supply regions.

    Args:
        event_regions: Region codes from the global event.
        supply_regions: Company's supply_chain_regions.

    Returns:
        Ratio in [0, 1]; 0.0 when either list is empty.
    """
    if not (event_regions and supply_regions):
        return 0.0
    distinct_supply = {code.upper() for code in supply_regions}
    distinct_event = {code.upper() for code in event_regions}
    shared = distinct_supply.intersection(distinct_event)
    return len(shared) / len(distinct_supply)
def compute_commodity_overlap(
    event_commodities: list[str],
    company_commodities: list[str],
) -> float:
    """Return the fraction of a company's key commodities hit by an event.

    Commodity identifiers are lower-cased and de-duplicated. The result is
    the number of distinct company commodities also named by the event,
    divided by the number of distinct company commodities.

    Args:
        event_commodities: Commodity identifiers from the global event.
        company_commodities: Company's key_input_commodities.

    Returns:
        Ratio in [0, 1]; 0.0 when either list is empty.
    """
    if not (event_commodities and company_commodities):
        return 0.0
    owned = {name.lower() for name in company_commodities}
    affected = {name.lower() for name in event_commodities}
    return len(owned.intersection(affected)) / len(owned)
# ---------------------------------------------------------------------------
# Resilience modifier
# ---------------------------------------------------------------------------
def apply_resilience_modifier(
    raw_score: float,
    tier: str,
    event_is_international: bool = True,
) -> float:
    """Scale a raw impact score by the company's market-position resilience.

    Only international events are scaled, using RESILIENCE_MODIFIERS: for
    example, global leaders are dampened (x0.7) and domestic companies are
    amplified (x1.2). Unknown tiers use a factor of 1.0. Domestic-only events
    pass through unscaled. The result is always clamped to [0, 1].

    Args:
        raw_score: The raw impact score before resilience adjustment.
        tier: Market position tier value.
        event_is_international: Whether the event affects multiple countries.

    Returns:
        The (possibly scaled) score clamped to [0, 1].
    """
    adjusted = raw_score
    if event_is_international:
        adjusted = raw_score * RESILIENCE_MODIFIERS.get(tier, 1.0)
    return min(max(adjusted, 0.0), 1.0)
# ---------------------------------------------------------------------------
# Impact direction determination
# ---------------------------------------------------------------------------
def _determine_impact_direction(
event_types: list[str],
) -> tuple[str, list[str], list[str]]:
"""Determine impact direction from event types.
Returns:
Tuple of (direction, positive_factors, negative_factors).
"""
positive_factors: list[str] = []
negative_factors: list[str] = []
for et in event_types:
if et in _NEGATIVE_EVENT_TYPES:
negative_factors.append(et)
elif et in _POSITIVE_EVENT_TYPES:
positive_factors.append(et)
elif et in _AMBIGUOUS_EVENT_TYPES:
# Ambiguous types contribute to both sides
positive_factors.append(et)
negative_factors.append(et)
has_positive = len(positive_factors) > 0
has_negative = len(negative_factors) > 0
if has_positive and has_negative:
return "mixed", positive_factors, negative_factors
elif has_positive:
return "positive", positive_factors, negative_factors
elif has_negative:
return "negative", positive_factors, negative_factors
else:
return "negative", positive_factors, negative_factors
# ---------------------------------------------------------------------------
# Core scoring function
# ---------------------------------------------------------------------------
def compute_macro_impact(
    event: GlobalEvent,
    profile: ExposureProfileSchema,
) -> MacroImpactRecord:
    """Compute the macro impact of a global event on a company (no sector data).

    ExposureProfileSchema does not carry the company's sector, so this entry
    point cannot evaluate the sector component: ``sector_match`` is always
    0.0. It delegates to :func:`compute_macro_impact_with_sector` with an
    empty sector. That previously duplicated, line for line, exactly the
    computation this function performed, so the two entry points can no
    longer drift apart.

    Scoring formula (the sector term is always zero here)::

        raw_score = severity_weight * (
            0.35 * geographic_overlap
            + 0.25 * supply_chain_overlap
            + 0.25 * commodity_overlap
            + 0.15 * sector_match
        )
        final_score = apply_resilience_modifier(raw_score, tier, is_international)

    Use :func:`compute_macro_impact_with_sector` directly when the company's
    sector is known.

    Args:
        event: The classified global event.
        profile: The company's exposure profile.

    Returns:
        A MacroImpactRecord with the computed score and metadata. When there
        is no overlap at all, the record is "neutral" with score 0.0 and
        confidence 0.0.
    """
    return compute_macro_impact_with_sector(event, profile, company_sector="")
def compute_macro_impact_with_sector(
    event: GlobalEvent,
    profile: ExposureProfileSchema,
    company_sector: str = "",
) -> MacroImpactRecord:
    """Compute the macro impact of a global event on a company, with sector matching.

    Scoring formula::

        raw_score = severity_weight * (
            GEO_WEIGHT * geographic_overlap          # 0.35
            + SUPPLY_WEIGHT * supply_chain_overlap   # 0.25
            + COMMODITY_WEIGHT * commodity_overlap   # 0.25
            + SECTOR_WEIGHT * sector_match           # 0.15
        )
        final_score = apply_resilience_modifier(raw_score, tier, is_international)

    ``sector_match`` is 1.0 when ``company_sector`` equals any of the event's
    affected sectors (case-insensitive, surrounding whitespace ignored), and
    0.0 otherwise. The event counts as international when it names more than
    one distinct region (case-insensitive).

    Args:
        event: The classified global event.
        profile: The company's exposure profile.
        company_sector: The company's GICS sector name; an empty string
            disables sector matching.

    Returns:
        A MacroImpactRecord with the computed score and metadata; ``ticker``
        is always left as an empty string by this function. When every
        overlap component is zero, the record is "neutral" with score 0.0
        and confidence 0.0.
    """
    now = datetime.now(timezone.utc)

    # Per-dimension overlaps, each in [0, 1].
    geo_overlap = compute_geographic_overlap(
        event.affected_regions,
        profile.geographic_revenue_mix,
    )
    supply_overlap = compute_supply_chain_overlap(
        event.affected_regions,
        profile.supply_chain_regions,
    )
    commodity_overlap = compute_commodity_overlap(
        event.affected_commodities,
        profile.key_input_commodities,
    )

    sector_match = 0.0
    if company_sector and event.affected_sectors:
        wanted_sector = company_sector.lower().strip()
        if any(
            sector.lower().strip() == wanted_sector
            for sector in event.affected_sectors
        ):
            sector_match = 1.0

    total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match
    if total_overlap == 0.0:
        # No exposure along any dimension: emit an explicit neutral record.
        return MacroImpactRecord(
            event_id=event.event_id,
            company_id=profile.company_id,
            ticker="",
            macro_impact_score=0.0,
            impact_direction="neutral",
            contributing_factors=[],
            confidence=0.0,
            computed_at=now,
        )

    contributing: list[str] = []
    if geo_overlap > 0:
        contributing.append(f"geographic_overlap:{geo_overlap:.3f}")
    if supply_overlap > 0:
        contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}")
    if commodity_overlap > 0:
        contributing.append(f"commodity_overlap:{commodity_overlap:.3f}")
    if sector_match > 0:
        contributing.append(f"sector_match:{company_sector}")

    # SEVERITY_WEIGHTS is keyed by SeverityLevel *values*. Normalize enum
    # members first (mirroring the tier handling below) so that a plain Enum
    # member does not silently miss the table. Unknown severities are
    # weighted like LOW.
    severity = event.severity
    if isinstance(severity, SeverityLevel):
        severity = severity.value
    severity_weight = SEVERITY_WEIGHTS.get(severity, 0.25)

    raw_score = severity_weight * (
        GEO_WEIGHT * geo_overlap
        + SUPPLY_WEIGHT * supply_overlap
        + COMMODITY_WEIGHT * commodity_overlap
        + SECTOR_WEIGHT * sector_match
    )

    # Count distinct regions case-insensitively, as the overlap helpers do.
    # Duplicate codes such as ["US", "us"] therefore do not make an event
    # "international", and a missing region list is treated as empty.
    distinct_regions = {region.upper() for region in (event.affected_regions or ())}
    is_international = len(distinct_regions) > 1

    tier = profile.market_position_tier
    if isinstance(tier, MarketPositionTier):
        tier = tier.value
    # apply_resilience_modifier already clamps its result to [0, 1].
    final_score = apply_resilience_modifier(raw_score, tier, is_international)

    direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types)
    if pos_factors:
        contributing.append(f"positive_types:{','.join(pos_factors)}")
    if neg_factors:
        contributing.append(f"negative_types:{','.join(neg_factors)}")

    # Scale the event's own confidence by overlap strength
    # (total_overlap + 0.3, capped at 1.0), then cap the product at 1.0.
    confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0)

    return MacroImpactRecord(
        event_id=event.event_id,
        company_id=profile.company_id,
        ticker="",
        macro_impact_score=round(final_score, 6),
        impact_direction=direction,
        contributing_factors=contributing,
        confidence=round(confidence, 6),
        computed_at=now,
    )
# ---------------------------------------------------------------------------
# Default profile builder
# ---------------------------------------------------------------------------
def build_default_profile(
    sector: str,
    industry: str,
    market_cap_bucket: str,
) -> ExposureProfileSchema:
    """Build an inferred ExposureProfile for a company without a manual profile.

    The geographic revenue mix comes from the sector defaults in
    _SECTOR_DEFAULT_GEO, falling back to _DEFAULT_GEO for unlisted sectors.
    Its region codes double as the supply-chain regions, and the first three
    become the regulatory jurisdictions.

    The market cap bucket maps to a market position tier (unknown buckets map
    to regional):
        large_cap -> global_leader
        mid_cap   -> multinational
        small_cap -> regional
        micro_cap -> domestic

    The tier in turn sets the export dependency (0.5 / 0.35 / 0.15 / 0.05,
    with 0.15 for anything else).

    Args:
        sector: GICS sector name.
        industry: Industry name, forwarded to _infer_commodities.
        market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap.

    Returns:
        An ExposureProfileSchema with source='inferred', confidence 0.5,
        version 1 and an empty company_id.
    """
    tier_value = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value)
    revenue_mix = _SECTOR_DEFAULT_GEO.get(sector, _DEFAULT_GEO)
    region_codes = list(revenue_mix)

    export_share_by_tier = {
        MarketPositionTier.GLOBAL_LEADER.value: 0.5,
        MarketPositionTier.MULTINATIONAL.value: 0.35,
        MarketPositionTier.REGIONAL.value: 0.15,
        MarketPositionTier.DOMESTIC.value: 0.05,
    }

    # Fresh containers everywhere so the shared default tables are never
    # aliased into (and mutated through) a profile.
    return ExposureProfileSchema(
        company_id="",
        geographic_revenue_mix=dict(revenue_mix),
        supply_chain_regions=list(region_codes),
        key_input_commodities=_infer_commodities(sector, industry),
        regulatory_jurisdictions=region_codes[:3],
        market_position_tier=MarketPositionTier(tier_value),
        export_dependency_pct=export_share_by_tier.get(tier_value, 0.15),
        source="inferred",
        confidence=0.5,
        version=1,
    )
def _infer_commodities(sector: str, industry: str) -> list[str]:
"""Infer key input commodities from sector and industry."""
sector_commodities: dict[str, list[str]] = {
"Energy": ["crude_oil", "natural_gas"],
"Materials": ["copper", "steel", "lithium"],
"Industrials": ["steel", "copper"],
"Information Technology": ["semiconductors", "lithium"],
"Consumer Staples": ["wheat", "corn"],
"Consumer Discretionary": ["steel", "semiconductors"],
"Health Care": [],
"Financials": [],
"Communication Services": [],
"Utilities": ["natural_gas"],
"Real Estate": ["steel"],
}
return sector_commodities.get(sector, [])
# ---------------------------------------------------------------------------
# PostgreSQL persistence
# ---------------------------------------------------------------------------
async def persist_macro_impact_record(
    pool: asyncpg.Pool,
    record: MacroImpactRecord,
) -> str:
    """Insert one MacroImpactRecord into the ``macro_impact_records`` table.

    ``event_id`` and ``company_id`` are cast to ``uuid`` in SQL, and
    ``contributing_factors`` is JSON-encoded into a ``jsonb`` parameter. An
    info-level log line is written after the insert.

    Returns:
        The ``id`` produced by ``RETURNING id``, as a string.
    """
    insert_sql = (
        "INSERT INTO macro_impact_records\n"
        "(event_id, company_id, ticker, macro_impact_score,\n"
        "impact_direction, contributing_factors, confidence, computed_at)\n"
        "VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6::jsonb, $7, $8)\n"
        "RETURNING id"
    )
    bind_values = (
        record.event_id,
        record.company_id,
        record.ticker,
        record.macro_impact_score,
        record.impact_direction,
        json.dumps(record.contributing_factors),
        record.confidence,
        record.computed_at,
    )
    new_id = await pool.fetchval(insert_sql, *bind_values)
    logger.info(
        "Persisted macro impact record for event=%s company=%s score=%.4f direction=%s",
        record.event_id,
        record.company_id,
        record.macro_impact_score,
        record.impact_direction,
    )
    return str(new_id)
async def persist_macro_impact_records(
    pool: asyncpg.Pool,
    records: list[MacroImpactRecord],
) -> list[str]:
    """Persist every record that carries a strictly positive impact score.

    Records whose ``macro_impact_score`` is not greater than 0.0 (such as the
    zero-overlap "neutral" records) are skipped. The remaining records are
    inserted one at a time, in input order, via persist_macro_impact_record.

    Returns:
        Row ids of the inserted records, in input order.
    """
    # Async list comprehension: each await completes before the next starts,
    # so inserts stay sequential.
    return [
        await persist_macro_impact_record(pool, rec)
        for rec in records
        if rec.macro_impact_score > 0.0
    ]
# ---------------------------------------------------------------------------
# Low-confidence event exclusion (Requirements: 10.1)
# ---------------------------------------------------------------------------
def filter_low_confidence_events(
    events: list[GlobalEvent],
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
) -> list[GlobalEvent]:
    """Drop events whose confidence is below a configurable threshold.

    Each excluded event is logged at info level with its confidence and the
    threshold. Events at or above the threshold are kept in their original
    order.

    Args:
        events: List of GlobalEvent classifications.
        confidence_threshold: Minimum confidence for inclusion (default 0.4).

    Returns:
        The events that pass the confidence threshold.

    Requirements: 10.1
    """
    kept: list[GlobalEvent] = []
    for candidate in events:
        if candidate.confidence < confidence_threshold:
            logger.info(
                "Excluding low-confidence event %s: confidence=%.3f < threshold=%.3f",
                candidate.event_id,
                candidate.confidence,
                confidence_threshold,
            )
            continue
        kept.append(candidate)
    return kept
# ---------------------------------------------------------------------------
# Accelerated decay for stale short-term events (Requirements: 10.2)
# ---------------------------------------------------------------------------
def compute_standard_recency_decay(
    age_hours: float,
    half_life_hours: float = 168.0,
) -> float:
    """Compute standard exponential recency decay.

    The factor halves every ``half_life_hours``:
    ``0.5 ** (age_hours / half_life_hours)``, which equals
    ``exp(-ln(2) * age_hours / half_life_hours)``. Using an exact base of
    one half means an event exactly one half-life old weighs exactly 0.5;
    the previous ``0.693`` approximation of ln(2) gave about 0.50007.

    Args:
        age_hours: Age of the event in hours. Non-positive ages (including
            future-dated events) receive the full weight 1.0.
        half_life_hours: Half-life for the decay function (default 168h,
            i.e. 7 days). A value of 0 raises ZeroDivisionError for
            positive ages.

    Returns:
        Decay factor in [0, 1] for a positive half-life; very large ages can
        underflow to 0.0.
    """
    if age_hours <= 0:
        return 1.0
    return 0.5 ** (age_hours / half_life_hours)
def apply_accelerated_decay(
    age_hours: float,
    estimated_duration: str,
    staleness_hours: float = DEFAULT_SHORT_TERM_STALENESS_HOURS,
    half_life_hours: float = 168.0,
) -> float:
    """Return an event's signal weight, decaying stale short-term events faster.

    The baseline is compute_standard_recency_decay(age_hours, half_life_hours).
    When ``estimated_duration`` is "short_term" and the event is older than
    ``staleness_hours`` (default 48h), the baseline is multiplied by
    ACCELERATED_DECAY_MULTIPLIER and a debug line is logged. That multiplier
    is below 1, so the weight is strictly below the baseline whenever the
    baseline is positive. All other events get the baseline unchanged.

    Args:
        age_hours: Age of the event in hours.
        estimated_duration: Event's estimated_duration field.
        staleness_hours: Hours after which short-term events get accelerated decay.
        half_life_hours: Half-life for the standard decay function.

    Returns:
        Effective signal weight in (0, 1].

    Requirements: 10.2
    """
    baseline = compute_standard_recency_decay(age_hours, half_life_hours)
    is_stale_short_term = (
        estimated_duration == "short_term" and age_hours > staleness_hours
    )
    if not is_stale_short_term:
        return baseline

    accelerated_weight = baseline * ACCELERATED_DECAY_MULTIPLIER
    logger.debug(
        "Accelerated decay for short_term event: age=%.1fh, "
        "standard=%.4f, accelerated=%.4f",
        age_hours, baseline, accelerated_weight,
    )
    return accelerated_weight