"""Interpolation engine — macro-to-company impact scoring. Computes per-company macro impact scores by evaluating overlap between global event classifications and company exposure profiles. Produces MacroImpactRecord objects that feed into the aggregation engine as additional weighted signals. Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 3.2 """ from __future__ import annotations import json import logging from dataclasses import dataclass, field from datetime import datetime, timezone import asyncpg from services.extractor.event_classifier import GlobalEvent from services.shared.schemas import ( ExposureProfileSchema, MarketPositionTier, SeverityLevel, ) logger = logging.getLogger("interpolation") # --------------------------------------------------------------------------- # Default configuration constants # --------------------------------------------------------------------------- DEFAULT_CONFIDENCE_THRESHOLD = 0.4 DEFAULT_SHORT_TERM_STALENESS_HOURS = 48 ACCELERATED_DECAY_MULTIPLIER = 0.5 # applied on top of standard recency decay # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- # Severity weights SEVERITY_WEIGHTS: dict[str, float] = { SeverityLevel.CRITICAL.value: 1.0, SeverityLevel.HIGH.value: 0.75, SeverityLevel.MODERATE.value: 0.5, SeverityLevel.LOW.value: 0.25, } # Component weights in the scoring formula GEO_WEIGHT = 0.35 SUPPLY_WEIGHT = 0.25 COMMODITY_WEIGHT = 0.25 SECTOR_WEIGHT = 0.15 # Resilience modifiers for international events RESILIENCE_MODIFIERS: dict[str, float] = { MarketPositionTier.GLOBAL_LEADER.value: 0.7, MarketPositionTier.MULTINATIONAL.value: 0.85, MarketPositionTier.REGIONAL.value: 1.0, MarketPositionTier.DOMESTIC.value: 1.2, } # Event types that are typically negative _NEGATIVE_EVENT_TYPES = frozenset({ "supply_disruption", "cost_increase", "regulatory_pressure", "geopolitical_risk", "trade_barrier", }) # Event types that can be positive _POSITIVE_EVENT_TYPES = frozenset({ "demand_shift", }) # Event types that can go either way _AMBIGUOUS_EVENT_TYPES = frozenset({ "commodity_shock", "currency_impact", }) # Market cap bucket → market position tier mapping for default profiles _CAP_TO_TIER: dict[str, str] = { "large_cap": MarketPositionTier.GLOBAL_LEADER.value, "mid_cap": MarketPositionTier.MULTINATIONAL.value, "small_cap": MarketPositionTier.REGIONAL.value, "micro_cap": MarketPositionTier.DOMESTIC.value, } # Sector-based default geographic revenue mixes _SECTOR_DEFAULT_GEO: dict[str, dict[str, float]] = { "Information Technology": {"US": 0.45, "CN": 0.15, "EU": 0.15, "JP": 0.10, "KR": 0.15}, "Health Care": {"US": 0.50, "EU": 0.25, "JP": 0.10, "CN": 0.15}, "Financials": {"US": 0.55, "EU": 0.20, "GB": 0.15, "JP": 0.10}, "Energy": {"US": 0.30, "SA": 0.20, "RU": 0.15, "CA": 0.15, "AE": 0.20}, "Materials": {"US": 0.25, "CN": 0.25, "AU": 0.20, "BR": 0.15, "IN": 0.15}, "Industrials": {"US": 0.40, "DE": 0.15, "CN": 0.15, "JP": 0.15, "KR": 0.15}, "Consumer Discretionary": {"US": 0.45, "CN": 0.20, "EU": 0.15, "JP": 0.10, "IN": 0.10}, "Consumer Staples": {"US": 0.45, "EU": 0.20, "CN": 0.15, "IN": 0.10, "BR": 0.10}, "Communication Services": {"US": 0.50, "CN": 0.15, "EU": 0.15, "JP": 0.10, "IN": 0.10}, "Utilities": {"US": 0.70, "EU": 0.15, "JP": 0.15}, "Real Estate": {"US": 0.60, "CN": 0.15, "EU": 0.15, "JP": 0.10}, } _DEFAULT_GEO = {"US": 0.50, "EU": 0.20, "CN": 0.15, "JP": 0.15} # --------------------------------------------------------------------------- # MacroImpactRecord dataclass # --------------------------------------------------------------------------- @dataclass class MacroImpactRecord: """A computed macro impact score for a specific company-event pair.""" event_id: str = "" company_id: str = "" ticker: str = "" macro_impact_score: float = 0.0 # [0, 1] impact_direction: str = "neutral" # positive|negative|mixed contributing_factors: list[str] = field(default_factory=list) confidence: float = 0.5 # [0, 1] computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) # --------------------------------------------------------------------------- # Overlap computation functions # --------------------------------------------------------------------------- def compute_geographic_overlap( event_regions: list[str], revenue_mix: dict[str, float], ) -> float: """Compute geographic overlap using revenue percentage weighting. For each event region that appears in the company's revenue mix, sum the revenue percentage. Returns a value in [0, 1]. Args: event_regions: Region codes from the global event. revenue_mix: Company's geographic_revenue_mix (region -> pct). Returns: Sum of revenue percentages for overlapping regions, clamped to [0, 1]. """ if not event_regions or not revenue_mix: return 0.0 event_set = {r.upper() for r in event_regions} overlap = 0.0 for region, pct in revenue_mix.items(): if region.upper() in event_set: overlap += pct return min(max(overlap, 0.0), 1.0) def compute_supply_chain_overlap( event_regions: list[str], supply_regions: list[str], ) -> float: """Compute supply chain overlap using set intersection ratio. Returns the fraction of the company's supply chain regions that overlap with the event's affected regions. Args: event_regions: Region codes from the global event. supply_regions: Company's supply_chain_regions. Returns: Intersection ratio in [0, 1]. 0.0 if supply_regions is empty. """ if not event_regions or not supply_regions: return 0.0 event_set = {r.upper() for r in event_regions} supply_set = {r.upper() for r in supply_regions} intersection = event_set & supply_set return len(intersection) / len(supply_set) def compute_commodity_overlap( event_commodities: list[str], company_commodities: list[str], ) -> float: """Compute commodity overlap using set intersection ratio. Returns the fraction of the company's key commodities that overlap with the event's affected commodities. Args: event_commodities: Commodity identifiers from the global event. company_commodities: Company's key_input_commodities. Returns: Intersection ratio in [0, 1]. 0.0 if company_commodities is empty. """ if not event_commodities or not company_commodities: return 0.0 event_set = {c.lower() for c in event_commodities} company_set = {c.lower() for c in company_commodities} intersection = event_set & company_set return len(intersection) / len(company_set) # --------------------------------------------------------------------------- # Resilience modifier # --------------------------------------------------------------------------- def apply_resilience_modifier( raw_score: float, tier: str, event_is_international: bool = True, ) -> float: """Apply a resilience modifier based on market position tier. For international events, global leaders get a dampening factor (0.7) while domestic companies get an amplification factor (1.2). For domestic-only events, no modifier is applied. Args: raw_score: The raw impact score before resilience adjustment. tier: Market position tier value. event_is_international: Whether the event affects multiple countries. Returns: Modified score clamped to [0, 1]. """ if not event_is_international: return min(max(raw_score, 0.0), 1.0) modifier = RESILIENCE_MODIFIERS.get(tier, 1.0) return min(max(raw_score * modifier, 0.0), 1.0) # --------------------------------------------------------------------------- # Impact direction determination # --------------------------------------------------------------------------- def _determine_impact_direction( event_types: list[str], ) -> tuple[str, list[str], list[str]]: """Determine impact direction from event types. Returns: Tuple of (direction, positive_factors, negative_factors). """ positive_factors: list[str] = [] negative_factors: list[str] = [] for et in event_types: if et in _NEGATIVE_EVENT_TYPES: negative_factors.append(et) elif et in _POSITIVE_EVENT_TYPES: positive_factors.append(et) elif et in _AMBIGUOUS_EVENT_TYPES: # Ambiguous types contribute to both sides positive_factors.append(et) negative_factors.append(et) has_positive = len(positive_factors) > 0 has_negative = len(negative_factors) > 0 if has_positive and has_negative: return "mixed", positive_factors, negative_factors elif has_positive: return "positive", positive_factors, negative_factors elif has_negative: return "negative", positive_factors, negative_factors else: return "negative", positive_factors, negative_factors # --------------------------------------------------------------------------- # Core scoring function # --------------------------------------------------------------------------- def compute_macro_impact( event: GlobalEvent, profile: ExposureProfileSchema, ) -> MacroImpactRecord: """Compute the macro impact of a global event on a company. Scoring formula: raw_score = severity_weight * ( 0.35 * geographic_overlap + 0.25 * supply_chain_overlap + 0.25 * commodity_overlap + 0.15 * sector_match ) final_score = apply_resilience_modifier(raw_score, tier, is_international) Args: event: The classified global event. profile: The company's exposure profile. Returns: A MacroImpactRecord with the computed score and metadata. """ now = datetime.now(timezone.utc) # Compute overlaps geo_overlap = compute_geographic_overlap( event.affected_regions, profile.geographic_revenue_mix, ) supply_overlap = compute_supply_chain_overlap( event.affected_regions, profile.supply_chain_regions, ) commodity_overlap = compute_commodity_overlap( event.affected_commodities, profile.key_input_commodities, ) # Sector match: 1.0 if any event sector matches the company's sector # We check against the profile's regulatory_jurisdictions as a proxy, # but the real sector comes from the company data. For now, we use # a simple heuristic: check if any affected_sectors appear in the # profile's geographic_revenue_mix keys or supply_chain_regions. # The actual sector is not stored in ExposureProfileSchema, so we # check if any event sectors match. This will be 0.0 unless the # caller provides sector info through contributing_factors. sector_match = 0.0 # We'll compute sector_match based on event sectors — the caller # should ensure the profile has relevant sector info. For the # default implementation, we always set sector_match to 0.0 here # and let the caller override if needed. # Check zero-overlap case contributing = [] if geo_overlap > 0: contributing.append(f"geographic_overlap:{geo_overlap:.3f}") if supply_overlap > 0: contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}") if commodity_overlap > 0: contributing.append(f"commodity_overlap:{commodity_overlap:.3f}") total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match if total_overlap == 0.0: return MacroImpactRecord( event_id=event.event_id, company_id=profile.company_id, ticker="", macro_impact_score=0.0, impact_direction="neutral", contributing_factors=[], confidence=0.0, computed_at=now, ) # Severity weight severity_weight = SEVERITY_WEIGHTS.get(event.severity, 0.25) # Raw score raw_score = severity_weight * ( GEO_WEIGHT * geo_overlap + SUPPLY_WEIGHT * supply_overlap + COMMODITY_WEIGHT * commodity_overlap + SECTOR_WEIGHT * sector_match ) # Determine if event is international (affects multiple regions) is_international = len(event.affected_regions) > 1 # Apply resilience modifier tier = profile.market_position_tier if isinstance(tier, MarketPositionTier): tier = tier.value final_score = apply_resilience_modifier(raw_score, tier, is_international) # Determine impact direction direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types) # Build contributing factors list all_factors = list(contributing) if pos_factors: all_factors.append(f"positive_types:{','.join(pos_factors)}") if neg_factors: all_factors.append(f"negative_types:{','.join(neg_factors)}") # Confidence: combine event confidence with overlap strength confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0) return MacroImpactRecord( event_id=event.event_id, company_id=profile.company_id, ticker="", macro_impact_score=round(min(final_score, 1.0), 6), impact_direction=direction, contributing_factors=all_factors, confidence=round(confidence, 6), computed_at=now, ) def compute_macro_impact_with_sector( event: GlobalEvent, profile: ExposureProfileSchema, company_sector: str = "", ) -> MacroImpactRecord: """Compute macro impact with explicit sector matching. Like compute_macro_impact but accepts a company_sector parameter for proper sector_match computation. Args: event: The classified global event. profile: The company's exposure profile. company_sector: The company's GICS sector name. Returns: A MacroImpactRecord with the computed score and metadata. """ now = datetime.now(timezone.utc) # Compute overlaps geo_overlap = compute_geographic_overlap( event.affected_regions, profile.geographic_revenue_mix, ) supply_overlap = compute_supply_chain_overlap( event.affected_regions, profile.supply_chain_regions, ) commodity_overlap = compute_commodity_overlap( event.affected_commodities, profile.key_input_commodities, ) # Sector match sector_match = 0.0 if company_sector and event.affected_sectors: company_sector_lower = company_sector.lower().strip() for es in event.affected_sectors: if es.lower().strip() == company_sector_lower: sector_match = 1.0 break # Contributing factors contributing: list[str] = [] if geo_overlap > 0: contributing.append(f"geographic_overlap:{geo_overlap:.3f}") if supply_overlap > 0: contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}") if commodity_overlap > 0: contributing.append(f"commodity_overlap:{commodity_overlap:.3f}") if sector_match > 0: contributing.append(f"sector_match:{company_sector}") total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match if total_overlap == 0.0: return MacroImpactRecord( event_id=event.event_id, company_id=profile.company_id, ticker="", macro_impact_score=0.0, impact_direction="neutral", contributing_factors=[], confidence=0.0, computed_at=now, ) # Severity weight severity_weight = SEVERITY_WEIGHTS.get(event.severity, 0.25) # Raw score raw_score = severity_weight * ( GEO_WEIGHT * geo_overlap + SUPPLY_WEIGHT * supply_overlap + COMMODITY_WEIGHT * commodity_overlap + SECTOR_WEIGHT * sector_match ) # International check is_international = len(event.affected_regions) > 1 # Resilience modifier tier = profile.market_position_tier if isinstance(tier, MarketPositionTier): tier = tier.value final_score = apply_resilience_modifier(raw_score, tier, is_international) # Direction direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types) all_factors = list(contributing) if pos_factors: all_factors.append(f"positive_types:{','.join(pos_factors)}") if neg_factors: all_factors.append(f"negative_types:{','.join(neg_factors)}") confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0) return MacroImpactRecord( event_id=event.event_id, company_id=profile.company_id, ticker="", macro_impact_score=round(min(final_score, 1.0), 6), impact_direction=direction, contributing_factors=all_factors, confidence=round(confidence, 6), computed_at=now, ) # --------------------------------------------------------------------------- # Default profile builder # --------------------------------------------------------------------------- def build_default_profile( sector: str, industry: str, market_cap_bucket: str, ) -> ExposureProfileSchema: """Build a default ExposureProfile for companies without manual profiles. Uses sector-based geographic revenue defaults and maps market_cap_bucket to market_position_tier: large_cap → global_leader mid_cap → multinational small_cap → regional micro_cap → domestic Args: sector: GICS sector name. industry: Industry name (used for commodity defaults). market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap. Returns: An ExposureProfileSchema with source='inferred'. """ tier = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value) geo_mix = _SECTOR_DEFAULT_GEO.get(sector, _DEFAULT_GEO) # Derive supply chain regions from geo mix keys supply_regions = list(geo_mix.keys()) # Derive commodities from sector/industry commodities = _infer_commodities(sector, industry) # Export dependency based on tier export_pct = { MarketPositionTier.GLOBAL_LEADER.value: 0.5, MarketPositionTier.MULTINATIONAL.value: 0.35, MarketPositionTier.REGIONAL.value: 0.15, MarketPositionTier.DOMESTIC.value: 0.05, }.get(tier, 0.15) return ExposureProfileSchema( company_id="", geographic_revenue_mix=dict(geo_mix), supply_chain_regions=supply_regions, key_input_commodities=commodities, regulatory_jurisdictions=list(geo_mix.keys())[:3], market_position_tier=MarketPositionTier(tier), export_dependency_pct=export_pct, source="inferred", confidence=0.5, version=1, ) def _infer_commodities(sector: str, industry: str) -> list[str]: """Infer key input commodities from sector and industry.""" sector_commodities: dict[str, list[str]] = { "Energy": ["crude_oil", "natural_gas"], "Materials": ["copper", "steel", "lithium"], "Industrials": ["steel", "copper"], "Information Technology": ["semiconductors", "lithium"], "Consumer Staples": ["wheat", "corn"], "Consumer Discretionary": ["steel", "semiconductors"], "Health Care": [], "Financials": [], "Communication Services": [], "Utilities": ["natural_gas"], "Real Estate": ["steel"], } return sector_commodities.get(sector, []) # --------------------------------------------------------------------------- # PostgreSQL persistence # --------------------------------------------------------------------------- async def persist_macro_impact_record( pool: asyncpg.Pool, record: MacroImpactRecord, ) -> str: """Persist a MacroImpactRecord to the macro_impact_records table. Returns the row UUID. """ row_id = await pool.fetchval( """INSERT INTO macro_impact_records (event_id, company_id, ticker, macro_impact_score, impact_direction, contributing_factors, confidence, computed_at) VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6::jsonb, $7, $8) RETURNING id""", record.event_id, record.company_id, record.ticker, record.macro_impact_score, record.impact_direction, json.dumps(record.contributing_factors), record.confidence, record.computed_at, ) logger.info( "Persisted macro impact record for event=%s company=%s score=%.4f direction=%s", record.event_id, record.company_id, record.macro_impact_score, record.impact_direction, ) return str(row_id) async def persist_macro_impact_records( pool: asyncpg.Pool, records: list[MacroImpactRecord], ) -> list[str]: """Persist multiple MacroImpactRecords. Returns list of row UUIDs.""" ids: list[str] = [] for record in records: if record.macro_impact_score > 0.0: row_id = await persist_macro_impact_record(pool, record) ids.append(row_id) return ids # --------------------------------------------------------------------------- # Low-confidence event exclusion (Requirements: 10.1) # --------------------------------------------------------------------------- def filter_low_confidence_events( events: list[GlobalEvent], confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD, ) -> list[GlobalEvent]: """Filter out events with confidence below the configurable threshold. Events with confidence below the threshold are excluded from macro impact computation and the exclusion reason is logged. Args: events: List of GlobalEvent classifications. confidence_threshold: Minimum confidence for inclusion (default 0.4). Returns: List of events that pass the confidence threshold. Requirements: 10.1 """ included: list[GlobalEvent] = [] for event in events: if event.confidence < confidence_threshold: logger.info( "Excluding low-confidence event %s: confidence=%.3f < threshold=%.3f", event.event_id, event.confidence, confidence_threshold, ) else: included.append(event) return included # --------------------------------------------------------------------------- # Accelerated decay for stale short-term events (Requirements: 10.2) # --------------------------------------------------------------------------- def compute_standard_recency_decay( age_hours: float, half_life_hours: float = 168.0, ) -> float: """Compute standard exponential recency decay. Args: age_hours: Age of the event in hours. half_life_hours: Half-life for the decay function (default 7 days). Returns: Decay factor in (0, 1]. """ import math if age_hours <= 0: return 1.0 return math.exp(-0.693 * age_hours / half_life_hours) def apply_accelerated_decay( age_hours: float, estimated_duration: str, staleness_hours: float = DEFAULT_SHORT_TERM_STALENESS_HOURS, half_life_hours: float = 168.0, ) -> float: """Apply accelerated decay for stale short-term events. For short_term events older than staleness_hours (default 48h), the effective weight is strictly less than standard recency decay. For non-short-term events or events within the staleness window, standard recency decay is applied. Args: age_hours: Age of the event in hours. estimated_duration: Event's estimated_duration field. staleness_hours: Hours after which short-term events get accelerated decay. half_life_hours: Half-life for the standard decay function. Returns: Effective signal weight in (0, 1]. Requirements: 10.2 """ standard_decay = compute_standard_recency_decay(age_hours, half_life_hours) if estimated_duration == "short_term" and age_hours > staleness_hours: # Apply accelerated decay: multiply standard decay by a factor < 1 accelerated = standard_decay * ACCELERATED_DECAY_MULTIPLIER logger.debug( "Accelerated decay for short_term event: age=%.1fh, " "standard=%.4f, accelerated=%.4f", age_hours, standard_decay, accelerated, ) return accelerated return standard_decay