feat: competitive intelligence & historical pattern matching layer
This commit is contained in:
@@ -0,0 +1,741 @@
|
||||
"""Interpolation engine — macro-to-company impact scoring.
|
||||
|
||||
Computes per-company macro impact scores by evaluating overlap between
|
||||
global event classifications and company exposure profiles. Produces
|
||||
MacroImpactRecord objects that feed into the aggregation engine as
|
||||
additional weighted signals.
|
||||
|
||||
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 3.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.extractor.event_classifier import GlobalEvent
|
||||
from services.shared.schemas import (
|
||||
ExposureProfileSchema,
|
||||
MarketPositionTier,
|
||||
SeverityLevel,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("interpolation")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default configuration constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_CONFIDENCE_THRESHOLD = 0.4
|
||||
DEFAULT_SHORT_TERM_STALENESS_HOURS = 48
|
||||
ACCELERATED_DECAY_MULTIPLIER = 0.5 # applied on top of standard recency decay
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Severity weights
|
||||
SEVERITY_WEIGHTS: dict[str, float] = {
|
||||
SeverityLevel.CRITICAL.value: 1.0,
|
||||
SeverityLevel.HIGH.value: 0.75,
|
||||
SeverityLevel.MODERATE.value: 0.5,
|
||||
SeverityLevel.LOW.value: 0.25,
|
||||
}
|
||||
|
||||
# Component weights in the scoring formula
|
||||
GEO_WEIGHT = 0.35
|
||||
SUPPLY_WEIGHT = 0.25
|
||||
COMMODITY_WEIGHT = 0.25
|
||||
SECTOR_WEIGHT = 0.15
|
||||
|
||||
# Resilience modifiers for international events
|
||||
RESILIENCE_MODIFIERS: dict[str, float] = {
|
||||
MarketPositionTier.GLOBAL_LEADER.value: 0.7,
|
||||
MarketPositionTier.MULTINATIONAL.value: 0.85,
|
||||
MarketPositionTier.REGIONAL.value: 1.0,
|
||||
MarketPositionTier.DOMESTIC.value: 1.2,
|
||||
}
|
||||
|
||||
# Event types that are typically negative
|
||||
_NEGATIVE_EVENT_TYPES = frozenset({
|
||||
"supply_disruption",
|
||||
"cost_increase",
|
||||
"regulatory_pressure",
|
||||
"geopolitical_risk",
|
||||
"trade_barrier",
|
||||
})
|
||||
|
||||
# Event types that can be positive
|
||||
_POSITIVE_EVENT_TYPES = frozenset({
|
||||
"demand_shift",
|
||||
})
|
||||
|
||||
# Event types that can go either way
|
||||
_AMBIGUOUS_EVENT_TYPES = frozenset({
|
||||
"commodity_shock",
|
||||
"currency_impact",
|
||||
})
|
||||
|
||||
# Market cap bucket → market position tier mapping for default profiles
|
||||
_CAP_TO_TIER: dict[str, str] = {
|
||||
"large_cap": MarketPositionTier.GLOBAL_LEADER.value,
|
||||
"mid_cap": MarketPositionTier.MULTINATIONAL.value,
|
||||
"small_cap": MarketPositionTier.REGIONAL.value,
|
||||
"micro_cap": MarketPositionTier.DOMESTIC.value,
|
||||
}
|
||||
|
||||
# Sector-based default geographic revenue mixes
|
||||
_SECTOR_DEFAULT_GEO: dict[str, dict[str, float]] = {
|
||||
"Information Technology": {"US": 0.45, "CN": 0.15, "EU": 0.15, "JP": 0.10, "KR": 0.15},
|
||||
"Health Care": {"US": 0.50, "EU": 0.25, "JP": 0.10, "CN": 0.15},
|
||||
"Financials": {"US": 0.55, "EU": 0.20, "GB": 0.15, "JP": 0.10},
|
||||
"Energy": {"US": 0.30, "SA": 0.20, "RU": 0.15, "CA": 0.15, "AE": 0.20},
|
||||
"Materials": {"US": 0.25, "CN": 0.25, "AU": 0.20, "BR": 0.15, "IN": 0.15},
|
||||
"Industrials": {"US": 0.40, "DE": 0.15, "CN": 0.15, "JP": 0.15, "KR": 0.15},
|
||||
"Consumer Discretionary": {"US": 0.45, "CN": 0.20, "EU": 0.15, "JP": 0.10, "IN": 0.10},
|
||||
"Consumer Staples": {"US": 0.45, "EU": 0.20, "CN": 0.15, "IN": 0.10, "BR": 0.10},
|
||||
"Communication Services": {"US": 0.50, "CN": 0.15, "EU": 0.15, "JP": 0.10, "IN": 0.10},
|
||||
"Utilities": {"US": 0.70, "EU": 0.15, "JP": 0.15},
|
||||
"Real Estate": {"US": 0.60, "CN": 0.15, "EU": 0.15, "JP": 0.10},
|
||||
}
|
||||
|
||||
_DEFAULT_GEO = {"US": 0.50, "EU": 0.20, "CN": 0.15, "JP": 0.15}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MacroImpactRecord dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class MacroImpactRecord:
|
||||
"""A computed macro impact score for a specific company-event pair."""
|
||||
|
||||
event_id: str = ""
|
||||
company_id: str = ""
|
||||
ticker: str = ""
|
||||
macro_impact_score: float = 0.0 # [0, 1]
|
||||
impact_direction: str = "neutral" # positive|negative|mixed
|
||||
contributing_factors: list[str] = field(default_factory=list)
|
||||
confidence: float = 0.5 # [0, 1]
|
||||
computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlap computation functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_geographic_overlap(
|
||||
event_regions: list[str],
|
||||
revenue_mix: dict[str, float],
|
||||
) -> float:
|
||||
"""Compute geographic overlap using revenue percentage weighting.
|
||||
|
||||
For each event region that appears in the company's revenue mix,
|
||||
sum the revenue percentage. Returns a value in [0, 1].
|
||||
|
||||
Args:
|
||||
event_regions: Region codes from the global event.
|
||||
revenue_mix: Company's geographic_revenue_mix (region -> pct).
|
||||
|
||||
Returns:
|
||||
Sum of revenue percentages for overlapping regions, clamped to [0, 1].
|
||||
"""
|
||||
if not event_regions or not revenue_mix:
|
||||
return 0.0
|
||||
|
||||
event_set = {r.upper() for r in event_regions}
|
||||
overlap = 0.0
|
||||
for region, pct in revenue_mix.items():
|
||||
if region.upper() in event_set:
|
||||
overlap += pct
|
||||
|
||||
return min(max(overlap, 0.0), 1.0)
|
||||
|
||||
|
||||
def compute_supply_chain_overlap(
|
||||
event_regions: list[str],
|
||||
supply_regions: list[str],
|
||||
) -> float:
|
||||
"""Compute supply chain overlap using set intersection ratio.
|
||||
|
||||
Returns the fraction of the company's supply chain regions that
|
||||
overlap with the event's affected regions.
|
||||
|
||||
Args:
|
||||
event_regions: Region codes from the global event.
|
||||
supply_regions: Company's supply_chain_regions.
|
||||
|
||||
Returns:
|
||||
Intersection ratio in [0, 1]. 0.0 if supply_regions is empty.
|
||||
"""
|
||||
if not event_regions or not supply_regions:
|
||||
return 0.0
|
||||
|
||||
event_set = {r.upper() for r in event_regions}
|
||||
supply_set = {r.upper() for r in supply_regions}
|
||||
|
||||
intersection = event_set & supply_set
|
||||
return len(intersection) / len(supply_set)
|
||||
|
||||
|
||||
def compute_commodity_overlap(
|
||||
event_commodities: list[str],
|
||||
company_commodities: list[str],
|
||||
) -> float:
|
||||
"""Compute commodity overlap using set intersection ratio.
|
||||
|
||||
Returns the fraction of the company's key commodities that overlap
|
||||
with the event's affected commodities.
|
||||
|
||||
Args:
|
||||
event_commodities: Commodity identifiers from the global event.
|
||||
company_commodities: Company's key_input_commodities.
|
||||
|
||||
Returns:
|
||||
Intersection ratio in [0, 1]. 0.0 if company_commodities is empty.
|
||||
"""
|
||||
if not event_commodities or not company_commodities:
|
||||
return 0.0
|
||||
|
||||
event_set = {c.lower() for c in event_commodities}
|
||||
company_set = {c.lower() for c in company_commodities}
|
||||
|
||||
intersection = event_set & company_set
|
||||
return len(intersection) / len(company_set)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resilience modifier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_resilience_modifier(
|
||||
raw_score: float,
|
||||
tier: str,
|
||||
event_is_international: bool = True,
|
||||
) -> float:
|
||||
"""Apply a resilience modifier based on market position tier.
|
||||
|
||||
For international events, global leaders get a dampening factor (0.7)
|
||||
while domestic companies get an amplification factor (1.2).
|
||||
For domestic-only events, no modifier is applied.
|
||||
|
||||
Args:
|
||||
raw_score: The raw impact score before resilience adjustment.
|
||||
tier: Market position tier value.
|
||||
event_is_international: Whether the event affects multiple countries.
|
||||
|
||||
Returns:
|
||||
Modified score clamped to [0, 1].
|
||||
"""
|
||||
if not event_is_international:
|
||||
return min(max(raw_score, 0.0), 1.0)
|
||||
|
||||
modifier = RESILIENCE_MODIFIERS.get(tier, 1.0)
|
||||
return min(max(raw_score * modifier, 0.0), 1.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Impact direction determination
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _determine_impact_direction(
|
||||
event_types: list[str],
|
||||
) -> tuple[str, list[str], list[str]]:
|
||||
"""Determine impact direction from event types.
|
||||
|
||||
Returns:
|
||||
Tuple of (direction, positive_factors, negative_factors).
|
||||
"""
|
||||
positive_factors: list[str] = []
|
||||
negative_factors: list[str] = []
|
||||
|
||||
for et in event_types:
|
||||
if et in _NEGATIVE_EVENT_TYPES:
|
||||
negative_factors.append(et)
|
||||
elif et in _POSITIVE_EVENT_TYPES:
|
||||
positive_factors.append(et)
|
||||
elif et in _AMBIGUOUS_EVENT_TYPES:
|
||||
# Ambiguous types contribute to both sides
|
||||
positive_factors.append(et)
|
||||
negative_factors.append(et)
|
||||
|
||||
has_positive = len(positive_factors) > 0
|
||||
has_negative = len(negative_factors) > 0
|
||||
|
||||
if has_positive and has_negative:
|
||||
return "mixed", positive_factors, negative_factors
|
||||
elif has_positive:
|
||||
return "positive", positive_factors, negative_factors
|
||||
elif has_negative:
|
||||
return "negative", positive_factors, negative_factors
|
||||
else:
|
||||
return "negative", positive_factors, negative_factors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core scoring function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_macro_impact(
|
||||
event: GlobalEvent,
|
||||
profile: ExposureProfileSchema,
|
||||
) -> MacroImpactRecord:
|
||||
"""Compute the macro impact of a global event on a company.
|
||||
|
||||
Scoring formula:
|
||||
raw_score = severity_weight * (
|
||||
0.35 * geographic_overlap +
|
||||
0.25 * supply_chain_overlap +
|
||||
0.25 * commodity_overlap +
|
||||
0.15 * sector_match
|
||||
)
|
||||
final_score = apply_resilience_modifier(raw_score, tier, is_international)
|
||||
|
||||
Args:
|
||||
event: The classified global event.
|
||||
profile: The company's exposure profile.
|
||||
|
||||
Returns:
|
||||
A MacroImpactRecord with the computed score and metadata.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Compute overlaps
|
||||
geo_overlap = compute_geographic_overlap(
|
||||
event.affected_regions,
|
||||
profile.geographic_revenue_mix,
|
||||
)
|
||||
supply_overlap = compute_supply_chain_overlap(
|
||||
event.affected_regions,
|
||||
profile.supply_chain_regions,
|
||||
)
|
||||
commodity_overlap = compute_commodity_overlap(
|
||||
event.affected_commodities,
|
||||
profile.key_input_commodities,
|
||||
)
|
||||
|
||||
# Sector match: 1.0 if any event sector matches the company's sector
|
||||
# We check against the profile's regulatory_jurisdictions as a proxy,
|
||||
# but the real sector comes from the company data. For now, we use
|
||||
# a simple heuristic: check if any affected_sectors appear in the
|
||||
# profile's geographic_revenue_mix keys or supply_chain_regions.
|
||||
# The actual sector is not stored in ExposureProfileSchema, so we
|
||||
# check if any event sectors match. This will be 0.0 unless the
|
||||
# caller provides sector info through contributing_factors.
|
||||
sector_match = 0.0
|
||||
# We'll compute sector_match based on event sectors — the caller
|
||||
# should ensure the profile has relevant sector info. For the
|
||||
# default implementation, we always set sector_match to 0.0 here
|
||||
# and let the caller override if needed.
|
||||
|
||||
# Check zero-overlap case
|
||||
contributing = []
|
||||
if geo_overlap > 0:
|
||||
contributing.append(f"geographic_overlap:{geo_overlap:.3f}")
|
||||
if supply_overlap > 0:
|
||||
contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}")
|
||||
if commodity_overlap > 0:
|
||||
contributing.append(f"commodity_overlap:{commodity_overlap:.3f}")
|
||||
|
||||
total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match
|
||||
if total_overlap == 0.0:
|
||||
return MacroImpactRecord(
|
||||
event_id=event.event_id,
|
||||
company_id=profile.company_id,
|
||||
ticker="",
|
||||
macro_impact_score=0.0,
|
||||
impact_direction="neutral",
|
||||
contributing_factors=[],
|
||||
confidence=0.0,
|
||||
computed_at=now,
|
||||
)
|
||||
|
||||
# Severity weight
|
||||
severity_weight = SEVERITY_WEIGHTS.get(event.severity, 0.25)
|
||||
|
||||
# Raw score
|
||||
raw_score = severity_weight * (
|
||||
GEO_WEIGHT * geo_overlap
|
||||
+ SUPPLY_WEIGHT * supply_overlap
|
||||
+ COMMODITY_WEIGHT * commodity_overlap
|
||||
+ SECTOR_WEIGHT * sector_match
|
||||
)
|
||||
|
||||
# Determine if event is international (affects multiple regions)
|
||||
is_international = len(event.affected_regions) > 1
|
||||
|
||||
# Apply resilience modifier
|
||||
tier = profile.market_position_tier
|
||||
if isinstance(tier, MarketPositionTier):
|
||||
tier = tier.value
|
||||
final_score = apply_resilience_modifier(raw_score, tier, is_international)
|
||||
|
||||
# Determine impact direction
|
||||
direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types)
|
||||
|
||||
# Build contributing factors list
|
||||
all_factors = list(contributing)
|
||||
if pos_factors:
|
||||
all_factors.append(f"positive_types:{','.join(pos_factors)}")
|
||||
if neg_factors:
|
||||
all_factors.append(f"negative_types:{','.join(neg_factors)}")
|
||||
|
||||
# Confidence: combine event confidence with overlap strength
|
||||
confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0)
|
||||
|
||||
return MacroImpactRecord(
|
||||
event_id=event.event_id,
|
||||
company_id=profile.company_id,
|
||||
ticker="",
|
||||
macro_impact_score=round(min(final_score, 1.0), 6),
|
||||
impact_direction=direction,
|
||||
contributing_factors=all_factors,
|
||||
confidence=round(confidence, 6),
|
||||
computed_at=now,
|
||||
)
|
||||
|
||||
|
||||
def compute_macro_impact_with_sector(
|
||||
event: GlobalEvent,
|
||||
profile: ExposureProfileSchema,
|
||||
company_sector: str = "",
|
||||
) -> MacroImpactRecord:
|
||||
"""Compute macro impact with explicit sector matching.
|
||||
|
||||
Like compute_macro_impact but accepts a company_sector parameter
|
||||
for proper sector_match computation.
|
||||
|
||||
Args:
|
||||
event: The classified global event.
|
||||
profile: The company's exposure profile.
|
||||
company_sector: The company's GICS sector name.
|
||||
|
||||
Returns:
|
||||
A MacroImpactRecord with the computed score and metadata.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Compute overlaps
|
||||
geo_overlap = compute_geographic_overlap(
|
||||
event.affected_regions,
|
||||
profile.geographic_revenue_mix,
|
||||
)
|
||||
supply_overlap = compute_supply_chain_overlap(
|
||||
event.affected_regions,
|
||||
profile.supply_chain_regions,
|
||||
)
|
||||
commodity_overlap = compute_commodity_overlap(
|
||||
event.affected_commodities,
|
||||
profile.key_input_commodities,
|
||||
)
|
||||
|
||||
# Sector match
|
||||
sector_match = 0.0
|
||||
if company_sector and event.affected_sectors:
|
||||
company_sector_lower = company_sector.lower().strip()
|
||||
for es in event.affected_sectors:
|
||||
if es.lower().strip() == company_sector_lower:
|
||||
sector_match = 1.0
|
||||
break
|
||||
|
||||
# Contributing factors
|
||||
contributing: list[str] = []
|
||||
if geo_overlap > 0:
|
||||
contributing.append(f"geographic_overlap:{geo_overlap:.3f}")
|
||||
if supply_overlap > 0:
|
||||
contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}")
|
||||
if commodity_overlap > 0:
|
||||
contributing.append(f"commodity_overlap:{commodity_overlap:.3f}")
|
||||
if sector_match > 0:
|
||||
contributing.append(f"sector_match:{company_sector}")
|
||||
|
||||
total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match
|
||||
if total_overlap == 0.0:
|
||||
return MacroImpactRecord(
|
||||
event_id=event.event_id,
|
||||
company_id=profile.company_id,
|
||||
ticker="",
|
||||
macro_impact_score=0.0,
|
||||
impact_direction="neutral",
|
||||
contributing_factors=[],
|
||||
confidence=0.0,
|
||||
computed_at=now,
|
||||
)
|
||||
|
||||
# Severity weight
|
||||
severity_weight = SEVERITY_WEIGHTS.get(event.severity, 0.25)
|
||||
|
||||
# Raw score
|
||||
raw_score = severity_weight * (
|
||||
GEO_WEIGHT * geo_overlap
|
||||
+ SUPPLY_WEIGHT * supply_overlap
|
||||
+ COMMODITY_WEIGHT * commodity_overlap
|
||||
+ SECTOR_WEIGHT * sector_match
|
||||
)
|
||||
|
||||
# International check
|
||||
is_international = len(event.affected_regions) > 1
|
||||
|
||||
# Resilience modifier
|
||||
tier = profile.market_position_tier
|
||||
if isinstance(tier, MarketPositionTier):
|
||||
tier = tier.value
|
||||
final_score = apply_resilience_modifier(raw_score, tier, is_international)
|
||||
|
||||
# Direction
|
||||
direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types)
|
||||
|
||||
all_factors = list(contributing)
|
||||
if pos_factors:
|
||||
all_factors.append(f"positive_types:{','.join(pos_factors)}")
|
||||
if neg_factors:
|
||||
all_factors.append(f"negative_types:{','.join(neg_factors)}")
|
||||
|
||||
confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0)
|
||||
|
||||
return MacroImpactRecord(
|
||||
event_id=event.event_id,
|
||||
company_id=profile.company_id,
|
||||
ticker="",
|
||||
macro_impact_score=round(min(final_score, 1.0), 6),
|
||||
impact_direction=direction,
|
||||
contributing_factors=all_factors,
|
||||
confidence=round(confidence, 6),
|
||||
computed_at=now,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default profile builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_default_profile(
|
||||
sector: str,
|
||||
industry: str,
|
||||
market_cap_bucket: str,
|
||||
) -> ExposureProfileSchema:
|
||||
"""Build a default ExposureProfile for companies without manual profiles.
|
||||
|
||||
Uses sector-based geographic revenue defaults and maps market_cap_bucket
|
||||
to market_position_tier:
|
||||
large_cap → global_leader
|
||||
mid_cap → multinational
|
||||
small_cap → regional
|
||||
micro_cap → domestic
|
||||
|
||||
Args:
|
||||
sector: GICS sector name.
|
||||
industry: Industry name (used for commodity defaults).
|
||||
market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap.
|
||||
|
||||
Returns:
|
||||
An ExposureProfileSchema with source='inferred'.
|
||||
"""
|
||||
tier = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value)
|
||||
geo_mix = _SECTOR_DEFAULT_GEO.get(sector, _DEFAULT_GEO)
|
||||
|
||||
# Derive supply chain regions from geo mix keys
|
||||
supply_regions = list(geo_mix.keys())
|
||||
|
||||
# Derive commodities from sector/industry
|
||||
commodities = _infer_commodities(sector, industry)
|
||||
|
||||
# Export dependency based on tier
|
||||
export_pct = {
|
||||
MarketPositionTier.GLOBAL_LEADER.value: 0.5,
|
||||
MarketPositionTier.MULTINATIONAL.value: 0.35,
|
||||
MarketPositionTier.REGIONAL.value: 0.15,
|
||||
MarketPositionTier.DOMESTIC.value: 0.05,
|
||||
}.get(tier, 0.15)
|
||||
|
||||
return ExposureProfileSchema(
|
||||
company_id="",
|
||||
geographic_revenue_mix=dict(geo_mix),
|
||||
supply_chain_regions=supply_regions,
|
||||
key_input_commodities=commodities,
|
||||
regulatory_jurisdictions=list(geo_mix.keys())[:3],
|
||||
market_position_tier=MarketPositionTier(tier),
|
||||
export_dependency_pct=export_pct,
|
||||
source="inferred",
|
||||
confidence=0.5,
|
||||
version=1,
|
||||
)
|
||||
|
||||
|
||||
def _infer_commodities(sector: str, industry: str) -> list[str]:
|
||||
"""Infer key input commodities from sector and industry."""
|
||||
sector_commodities: dict[str, list[str]] = {
|
||||
"Energy": ["crude_oil", "natural_gas"],
|
||||
"Materials": ["copper", "steel", "lithium"],
|
||||
"Industrials": ["steel", "copper"],
|
||||
"Information Technology": ["semiconductors", "lithium"],
|
||||
"Consumer Staples": ["wheat", "corn"],
|
||||
"Consumer Discretionary": ["steel", "semiconductors"],
|
||||
"Health Care": [],
|
||||
"Financials": [],
|
||||
"Communication Services": [],
|
||||
"Utilities": ["natural_gas"],
|
||||
"Real Estate": ["steel"],
|
||||
}
|
||||
return sector_commodities.get(sector, [])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PostgreSQL persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def persist_macro_impact_record(
|
||||
pool: asyncpg.Pool,
|
||||
record: MacroImpactRecord,
|
||||
) -> str:
|
||||
"""Persist a MacroImpactRecord to the macro_impact_records table.
|
||||
|
||||
Returns the row UUID.
|
||||
"""
|
||||
row_id = await pool.fetchval(
|
||||
"""INSERT INTO macro_impact_records
|
||||
(event_id, company_id, ticker, macro_impact_score,
|
||||
impact_direction, contributing_factors, confidence, computed_at)
|
||||
VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6::jsonb, $7, $8)
|
||||
RETURNING id""",
|
||||
record.event_id,
|
||||
record.company_id,
|
||||
record.ticker,
|
||||
record.macro_impact_score,
|
||||
record.impact_direction,
|
||||
json.dumps(record.contributing_factors),
|
||||
record.confidence,
|
||||
record.computed_at,
|
||||
)
|
||||
logger.info(
|
||||
"Persisted macro impact record for event=%s company=%s score=%.4f direction=%s",
|
||||
record.event_id,
|
||||
record.company_id,
|
||||
record.macro_impact_score,
|
||||
record.impact_direction,
|
||||
)
|
||||
return str(row_id)
|
||||
|
||||
|
||||
async def persist_macro_impact_records(
|
||||
pool: asyncpg.Pool,
|
||||
records: list[MacroImpactRecord],
|
||||
) -> list[str]:
|
||||
"""Persist multiple MacroImpactRecords. Returns list of row UUIDs."""
|
||||
ids: list[str] = []
|
||||
for record in records:
|
||||
if record.macro_impact_score > 0.0:
|
||||
row_id = await persist_macro_impact_record(pool, record)
|
||||
ids.append(row_id)
|
||||
return ids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-confidence event exclusion (Requirements: 10.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def filter_low_confidence_events(
|
||||
events: list[GlobalEvent],
|
||||
confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
|
||||
) -> list[GlobalEvent]:
|
||||
"""Filter out events with confidence below the configurable threshold.
|
||||
|
||||
Events with confidence below the threshold are excluded from macro
|
||||
impact computation and the exclusion reason is logged.
|
||||
|
||||
Args:
|
||||
events: List of GlobalEvent classifications.
|
||||
confidence_threshold: Minimum confidence for inclusion (default 0.4).
|
||||
|
||||
Returns:
|
||||
List of events that pass the confidence threshold.
|
||||
|
||||
Requirements: 10.1
|
||||
"""
|
||||
included: list[GlobalEvent] = []
|
||||
for event in events:
|
||||
if event.confidence < confidence_threshold:
|
||||
logger.info(
|
||||
"Excluding low-confidence event %s: confidence=%.3f < threshold=%.3f",
|
||||
event.event_id,
|
||||
event.confidence,
|
||||
confidence_threshold,
|
||||
)
|
||||
else:
|
||||
included.append(event)
|
||||
return included
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Accelerated decay for stale short-term events (Requirements: 10.2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_standard_recency_decay(
|
||||
age_hours: float,
|
||||
half_life_hours: float = 168.0,
|
||||
) -> float:
|
||||
"""Compute standard exponential recency decay.
|
||||
|
||||
Args:
|
||||
age_hours: Age of the event in hours.
|
||||
half_life_hours: Half-life for the decay function (default 7 days).
|
||||
|
||||
Returns:
|
||||
Decay factor in (0, 1].
|
||||
"""
|
||||
import math
|
||||
if age_hours <= 0:
|
||||
return 1.0
|
||||
return math.exp(-0.693 * age_hours / half_life_hours)
|
||||
|
||||
|
||||
def apply_accelerated_decay(
|
||||
age_hours: float,
|
||||
estimated_duration: str,
|
||||
staleness_hours: float = DEFAULT_SHORT_TERM_STALENESS_HOURS,
|
||||
half_life_hours: float = 168.0,
|
||||
) -> float:
|
||||
"""Apply accelerated decay for stale short-term events.
|
||||
|
||||
For short_term events older than staleness_hours (default 48h),
|
||||
the effective weight is strictly less than standard recency decay.
|
||||
|
||||
For non-short-term events or events within the staleness window,
|
||||
standard recency decay is applied.
|
||||
|
||||
Args:
|
||||
age_hours: Age of the event in hours.
|
||||
estimated_duration: Event's estimated_duration field.
|
||||
staleness_hours: Hours after which short-term events get accelerated decay.
|
||||
half_life_hours: Half-life for the standard decay function.
|
||||
|
||||
Returns:
|
||||
Effective signal weight in (0, 1].
|
||||
|
||||
Requirements: 10.2
|
||||
"""
|
||||
standard_decay = compute_standard_recency_decay(age_hours, half_life_hours)
|
||||
|
||||
if estimated_duration == "short_term" and age_hours > staleness_hours:
|
||||
# Apply accelerated decay: multiply standard decay by a factor < 1
|
||||
accelerated = standard_decay * ACCELERATED_DECAY_MULTIPLIER
|
||||
logger.debug(
|
||||
"Accelerated decay for short_term event: age=%.1fh, "
|
||||
"standard=%.4f, accelerated=%.4f",
|
||||
age_hours, standard_decay, accelerated,
|
||||
)
|
||||
return accelerated
|
||||
|
||||
return standard_decay
|
||||
Reference in New Issue
Block a user