Files
stonks-oracle/services/aggregation/interpolation.py
T
Celes Renata 4e010bc048
ci/woodpecker/push/test Pipeline was successful
ci/woodpecker/push/build-1 Pipeline was successful
ci/woodpecker/push/build-2 Pipeline was successful
ci/woodpecker/push/build-3 Pipeline was successful
ci/woodpecker/push/finalize Pipeline was successful
Build and Push / lint-and-test (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.adapters.broker_adapter name:broker-adapter]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.aggregation.worker name:aggregation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.extractor.worker name:extractor]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.ingestion.worker name:ingestion]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.lake_publisher.worker name:lake-publisher]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.parser.worker name:parser]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.recommendation.worker name:recommendation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.scheduler.app name:scheduler]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.api.app:app --host 0.0.0.0 --port 8000 name:query-api]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.risk.app:app --host 0.0.0.0 --port 8000 name:risk]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.symbol_registry.app:app --host 0.0.0.0 --port 8000 name:symbol-registry]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.trading.app:app --host 0.0.0.0 --port 8000 name:trading-engine]) (push) Has been cancelled
Build and Push / build-dashboard (push) Has been cancelled
Build and Push / build-superset (push) Has been cancelled
Build and Push / integration-test (push) Has been cancelled
Build and Push / beta-gate (push) Has been cancelled
feat: signal math upgrade — probabilistic, regime-aware scoring pipeline
Implement full probabilistic signal processing pipeline gated behind
probabilistic_scoring_enabled feature flag in risk_configs:

- Bayesian log-likelihood accumulator with Beta posterior and entropy
- Regime detector (trend-following, panic, mean-reversion, uncertainty)
- Source accuracy tracker with per-source historical prediction accuracy
- Sigmoid confidence gate replacing binary gate
- Information gain surprise weighting for rare events
- Adaptive recency decay with event-specific half-lives
- Regime multiplier replacing market context multiplier
- Weighted disagreement entropy for contradiction detection
- Multiplicative macro exposure with conditional integration
- Graph-distance attenuated competitive signal propagation
- Exponentially weighted momentum with volatility scaling
- Expected value recommendation gate

All changes backward-compatible: flag=false preserves exact current behavior.
New outputs stored in existing JSONB columns (no schema changes except
source_accuracy table via migration 034).

Tests: 26 property-based tests (14 correctness properties), 99 unit tests,
1789 total tests passing with zero regressions.
2026-04-29 11:41:48 +00:00

958 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Interpolation engine — macro-to-company impact scoring.
Computes per-company macro impact scores by evaluating overlap between
global event classifications and company exposure profiles. Produces
MacroImpactRecord objects that feed into the aggregation engine as
additional weighted signals.
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 3.2
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
import asyncpg
from services.extractor.event_classifier import GlobalEvent
from services.shared.schemas import (
ExposureProfileSchema,
MarketPositionTier,
SeverityLevel,
)
logger = logging.getLogger("interpolation")
# ---------------------------------------------------------------------------
# Default configuration constants
# ---------------------------------------------------------------------------
# Events whose confidence is strictly below this value are excluded from
# macro impact computation (default threshold of filter_low_confidence_events).
DEFAULT_CONFIDENCE_THRESHOLD = 0.4
# Age in hours after which a "short_term" event receives accelerated decay
# (default staleness window of apply_accelerated_decay).
DEFAULT_SHORT_TERM_STALENESS_HOURS = 48
ACCELERATED_DECAY_MULTIPLIER = 0.5  # applied on top of standard recency decay
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Severity weights: the exposure score is multiplied by the weight of the
# event's severity. Severities missing from this table fall back to 0.25
# (the LOW weight) at the lookup sites in the compute_macro_impact* functions.
SEVERITY_WEIGHTS: dict[str, float] = {
    SeverityLevel.CRITICAL.value: 1.0,
    SeverityLevel.HIGH.value: 0.75,
    SeverityLevel.MODERATE.value: 0.5,
    SeverityLevel.LOW.value: 0.25,
}
# Component weights in the scoring formula. They sum to 1.0, so the linear
# exposure stays within [0, 1] whenever every overlap is within [0, 1].
GEO_WEIGHT = 0.35
SUPPLY_WEIGHT = 0.25
COMMODITY_WEIGHT = 0.25
SECTOR_WEIGHT = 0.15
# Resilience modifiers for international (multi-region) events, keyed by
# market position tier value. Values below 1.0 dampen the impact score and
# values above 1.0 amplify it; tiers missing from the table use 1.0.
RESILIENCE_MODIFIERS: dict[str, float] = {
    MarketPositionTier.GLOBAL_LEADER.value: 0.7,
    MarketPositionTier.MULTINATIONAL.value: 0.85,
    MarketPositionTier.REGIONAL.value: 1.0,
    MarketPositionTier.DOMESTIC.value: 1.2,
}
# Event types that are typically negative.
# Event types found in none of the three sets below are ignored when the
# impact direction is derived.
_NEGATIVE_EVENT_TYPES = frozenset({
    "supply_disruption",
    "cost_increase",
    "regulatory_pressure",
    "geopolitical_risk",
    "trade_barrier",
})
# Event types that can be positive
_POSITIVE_EVENT_TYPES = frozenset({
    "demand_shift",
})
# Event types that can go either way. They count toward both the positive
# and the negative side, so their presence makes the direction "mixed".
_AMBIGUOUS_EVENT_TYPES = frozenset({
    "commodity_shock",
    "currency_impact",
})
# Market cap bucket → market position tier mapping for default profiles.
# Unknown buckets fall back to the regional tier in build_default_profile.
_CAP_TO_TIER: dict[str, str] = {
    "large_cap": MarketPositionTier.GLOBAL_LEADER.value,
    "mid_cap": MarketPositionTier.MULTINATIONAL.value,
    "small_cap": MarketPositionTier.REGIONAL.value,
    "micro_cap": MarketPositionTier.DOMESTIC.value,
}
# Sector-based default geographic revenue mixes (region code -> revenue share).
# Each mix sums to 1.0. Region codes are matched exactly (case-insensitively)
# against an event's affected regions, so an event tagged "DE" does not
# overlap an "EU" entry. The keys also become the default supply-chain regions.
_SECTOR_DEFAULT_GEO: dict[str, dict[str, float]] = {
    "Information Technology": {"US": 0.45, "CN": 0.15, "EU": 0.15, "JP": 0.10, "KR": 0.15},
    "Health Care": {"US": 0.50, "EU": 0.25, "JP": 0.10, "CN": 0.15},
    "Financials": {"US": 0.55, "EU": 0.20, "GB": 0.15, "JP": 0.10},
    "Energy": {"US": 0.30, "SA": 0.20, "RU": 0.15, "CA": 0.15, "AE": 0.20},
    "Materials": {"US": 0.25, "CN": 0.25, "AU": 0.20, "BR": 0.15, "IN": 0.15},
    "Industrials": {"US": 0.40, "DE": 0.15, "CN": 0.15, "JP": 0.15, "KR": 0.15},
    "Consumer Discretionary": {"US": 0.45, "CN": 0.20, "EU": 0.15, "JP": 0.10, "IN": 0.10},
    "Consumer Staples": {"US": 0.45, "EU": 0.20, "CN": 0.15, "IN": 0.10, "BR": 0.10},
    "Communication Services": {"US": 0.50, "CN": 0.15, "EU": 0.15, "JP": 0.10, "IN": 0.10},
    "Utilities": {"US": 0.70, "EU": 0.15, "JP": 0.15},
    "Real Estate": {"US": 0.60, "CN": 0.15, "EU": 0.15, "JP": 0.10},
}
# Fallback revenue mix (also sums to 1.0) for sectors not listed above.
_DEFAULT_GEO = {"US": 0.50, "EU": 0.20, "CN": 0.15, "JP": 0.15}
# ---------------------------------------------------------------------------
# MacroImpactRecord dataclass
# ---------------------------------------------------------------------------
@dataclass
class MacroImpactRecord:
    """A computed macro impact score for a specific company-event pair.

    Produced by ``compute_macro_impact`` / ``compute_macro_impact_with_sector``
    and stored by ``persist_macro_impact_record``.
    """

    event_id: str = ""
    company_id: str = ""
    # The scoring functions in this module leave this empty; presumably the
    # caller fills it in before use — TODO confirm.
    ticker: str = ""
    macro_impact_score: float = 0.0  # [0, 1]
    # One of "positive", "negative", "mixed" or "neutral". "neutral" is the
    # default and is also used for zero-overlap results.
    impact_direction: str = "neutral"
    # Human-readable factor strings such as "geographic_overlap:0.450".
    contributing_factors: list[str] = field(default_factory=list)
    confidence: float = 0.5  # [0, 1]
    # Timezone-aware UTC timestamp of when the score was computed.
    computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
# ---------------------------------------------------------------------------
# Overlap computation functions
# ---------------------------------------------------------------------------
def compute_geographic_overlap(
    event_regions: list[str],
    revenue_mix: dict[str, float],
) -> float:
    """Return the share of company revenue earned in the event's regions.

    Region codes are compared case-insensitively. The revenue shares of
    every matching entry in *revenue_mix* are accumulated, and the total is
    clamped to ``[0, 1]``.

    Args:
        event_regions: Region codes affected by the global event.
        revenue_mix: Company's ``geographic_revenue_mix`` (region -> share).

    Returns:
        The clamped revenue share in ``[0, 1]``; ``0.0`` when either input
        is empty.
    """
    if not (event_regions and revenue_mix):
        return 0.0

    hit_codes = {code.upper() for code in event_regions}
    exposed_share = 0.0
    for code, share in revenue_mix.items():
        if code.upper() not in hit_codes:
            continue
        exposed_share += share
    return min(max(exposed_share, 0.0), 1.0)
def compute_supply_chain_overlap(
    event_regions: list[str],
    supply_regions: list[str],
) -> float:
    """Return the fraction of the company's supply regions hit by the event.

    Both region lists are upper-cased and de-duplicated before comparison,
    so the denominator is the number of *distinct* supply regions.

    Args:
        event_regions: Region codes affected by the global event.
        supply_regions: Company's ``supply_chain_regions``.

    Returns:
        Intersection ratio in ``[0, 1]``; ``0.0`` when either list is empty.
    """
    if not (event_regions and supply_regions):
        return 0.0

    affected = {code.upper() for code in event_regions}
    sourcing = {code.upper() for code in supply_regions}
    shared = sourcing.intersection(affected)
    return len(shared) / len(sourcing)
def compute_commodity_overlap(
    event_commodities: list[str],
    company_commodities: list[str],
) -> float:
    """Return the fraction of the company's key inputs affected by the event.

    Commodity identifiers are lower-cased and de-duplicated before the
    comparison, so the denominator counts *distinct* company commodities.

    Args:
        event_commodities: Commodity identifiers from the global event.
        company_commodities: Company's ``key_input_commodities``.

    Returns:
        Intersection ratio in ``[0, 1]``; ``0.0`` when either list is empty.
    """
    if not (event_commodities and company_commodities):
        return 0.0

    shocked = {name.lower() for name in event_commodities}
    inputs = {name.lower() for name in company_commodities}
    return len(inputs.intersection(shocked)) / len(inputs)
# ---------------------------------------------------------------------------
# Resilience modifier
# ---------------------------------------------------------------------------
def apply_resilience_modifier(
    raw_score: float,
    tier: str,
    event_is_international: bool = True,
) -> float:
    """Scale *raw_score* by the market-position resilience factor.

    Only international events are scaled. The factor comes from
    ``RESILIENCE_MODIFIERS`` (for example 0.7 for a global leader and 1.2
    for a domestic company), and tiers missing from that table use 1.0.
    Domestic-only events keep the raw score. In both cases the result is
    clamped to ``[0, 1]``.

    Args:
        raw_score: Impact score before the resilience adjustment.
        tier: Market position tier value.
        event_is_international: Whether the event affects multiple regions.

    Returns:
        The adjusted score clamped to ``[0, 1]``.
    """
    adjusted = (
        raw_score * RESILIENCE_MODIFIERS.get(tier, 1.0)
        if event_is_international
        else raw_score
    )
    return min(max(adjusted, 0.0), 1.0)
# ---------------------------------------------------------------------------
# Impact direction determination
# ---------------------------------------------------------------------------
def _determine_impact_direction(
event_types: list[str],
) -> tuple[str, list[str], list[str]]:
"""Determine impact direction from event types.
Returns:
Tuple of (direction, positive_factors, negative_factors).
"""
positive_factors: list[str] = []
negative_factors: list[str] = []
for et in event_types:
if et in _NEGATIVE_EVENT_TYPES:
negative_factors.append(et)
elif et in _POSITIVE_EVENT_TYPES:
positive_factors.append(et)
elif et in _AMBIGUOUS_EVENT_TYPES:
# Ambiguous types contribute to both sides
positive_factors.append(et)
negative_factors.append(et)
has_positive = len(positive_factors) > 0
has_negative = len(negative_factors) > 0
if has_positive and has_negative:
return "mixed", positive_factors, negative_factors
elif has_positive:
return "positive", positive_factors, negative_factors
elif has_negative:
return "negative", positive_factors, negative_factors
else:
return "negative", positive_factors, negative_factors
# ---------------------------------------------------------------------------
# Core scoring function
# ---------------------------------------------------------------------------
def _compute_multiplicative_exposure(
    geo_overlap: float,
    supply_overlap: float,
    commodity_overlap: float,
    sector_match: float,
) -> float:
    """Compute multiplicative compounding exposure.

    Formula: 1 - Π_k(1 - w_k · O_k)

    Each dimension independently removes a fraction of the remaining
    "unexposed" mass, so exposure on several dimensions compounds. For
    overlaps in [0, 1] the result is at least the largest single weighted
    term w_k · O_k. By the union bound it never exceeds the linear
    weighted sum Σ w_k · O_k.

    With the current weights (0.35, 0.25, 0.25, 0.15) the maximum, reached
    when every overlap is 1.0, is 1 - 0.65 · 0.75 · 0.75 · 0.85 ≈ 0.689.

    Returns:
        A value in [0, ≈0.689] for overlaps in [0, 1].

    Requirements: 10.1, 10.4, 10.7
    """
    # Probability-style "none of the dimensions hit" term.
    product = (
        (1.0 - GEO_WEIGHT * geo_overlap)
        * (1.0 - SUPPLY_WEIGHT * supply_overlap)
        * (1.0 - COMMODITY_WEIGHT * commodity_overlap)
        * (1.0 - SECTOR_WEIGHT * sector_match)
    )
    return 1.0 - product
def _compute_linear_exposure(
    geo_overlap: float,
    supply_overlap: float,
    commodity_overlap: float,
    sector_match: float,
) -> float:
    """Return the original heuristic exposure: a weighted sum of overlaps.

    Formula: w_geo·O_geo + w_supply·O_supply + w_commodity·O_commodity + w_sector·O_sector

    The weights sum to 1.0, so for overlaps in [0, 1] the result is in [0, 1].
    """
    geo_term = GEO_WEIGHT * geo_overlap
    supply_term = SUPPLY_WEIGHT * supply_overlap
    commodity_term = COMMODITY_WEIGHT * commodity_overlap
    sector_term = SECTOR_WEIGHT * sector_match
    return geo_term + supply_term + commodity_term + sector_term
def compute_macro_impact(
    event: GlobalEvent,
    profile: ExposureProfileSchema,
    *,
    probabilistic: bool = False,
) -> MacroImpactRecord:
    """Compute the macro impact of a global event on a company.

    When ``probabilistic=False`` (default), uses the linear weighted-sum:
        raw_score = severity_weight * (
            0.35 * geographic_overlap +
            0.25 * supply_chain_overlap +
            0.25 * commodity_overlap +
            0.15 * sector_match
        )
    When ``probabilistic=True``, uses multiplicative compounding exposure:
        raw_score = severity_weight * (1 - Π_k(1 - w_k · O_k))
    In both modes, the resilience modifier is applied after the raw score.

    ``ExposureProfileSchema`` carries no sector, so ``sector_match`` is
    always 0.0 on this path. The computation is therefore exactly that of
    ``compute_macro_impact_with_sector`` with an empty ``company_sector``.
    This function delegates to it so the two scoring paths cannot drift
    apart. Callers that know the company's GICS sector should call
    ``compute_macro_impact_with_sector`` directly.

    Args:
        event: The classified global event.
        profile: The company's exposure profile.
        probabilistic: Use multiplicative formula when True.

    Returns:
        A MacroImpactRecord with the computed score and metadata. A neutral,
        zero-confidence record is returned when no overlap exists.

    Requirements: 10.1, 10.2, 10.3, 10.4, 10.5, 10.6
    """
    return compute_macro_impact_with_sector(
        event,
        profile,
        "",
        probabilistic=probabilistic,
    )
def compute_macro_impact_with_sector(
    event: GlobalEvent,
    profile: ExposureProfileSchema,
    company_sector: str = "",
    *,
    probabilistic: bool = False,
) -> MacroImpactRecord:
    """Compute macro impact with explicit sector matching.

    The exposure combines four overlap components, each in [0, 1]:

    * geographic overlap: revenue share earned in the event's regions;
    * supply-chain overlap: fraction of distinct supply regions affected;
    * commodity overlap: fraction of distinct key inputs affected;
    * sector match: 1.0 when ``company_sector`` equals one of
      ``event.affected_sectors`` (case- and surrounding-whitespace-insensitive),
      else 0.0.

    When ``probabilistic=False`` the components are combined with the
    original linear weighted sum. When ``probabilistic=True`` they are
    combined with multiplicative compounding ``1 - Π_k(1 - w_k · O_k)``.
    The exposure is scaled by the severity weight and then by the
    market-position resilience modifier. The modifier only takes effect for
    events spanning more than one region.

    If every component is zero, a neutral record with score 0.0,
    confidence 0.0 and no contributing factors is returned.

    Args:
        event: The classified global event.
        profile: The company's exposure profile.
        company_sector: The company's GICS sector name ("" disables the
            sector component).
        probabilistic: Use multiplicative formula when True.

    Returns:
        A MacroImpactRecord with the computed score and metadata. ``ticker``
        is left empty.

    Requirements: 10.1, 10.2, 10.3, 10.4, 10.5, 10.6
    """
    now = datetime.now(timezone.utc)

    # --- Overlap components ---------------------------------------------
    geo_overlap = compute_geographic_overlap(
        event.affected_regions,
        profile.geographic_revenue_mix,
    )
    supply_overlap = compute_supply_chain_overlap(
        event.affected_regions,
        profile.supply_chain_regions,
    )
    commodity_overlap = compute_commodity_overlap(
        event.affected_commodities,
        profile.key_input_commodities,
    )

    # Sector match is all-or-nothing.
    sector_match = 0.0
    if company_sector and event.affected_sectors:
        wanted = company_sector.lower().strip()
        if any(s.lower().strip() == wanted for s in event.affected_sectors):
            sector_match = 1.0

    contributing: list[str] = []
    if geo_overlap > 0:
        contributing.append(f"geographic_overlap:{geo_overlap:.3f}")
    if supply_overlap > 0:
        contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}")
    if commodity_overlap > 0:
        contributing.append(f"commodity_overlap:{commodity_overlap:.3f}")
    if sector_match > 0:
        contributing.append(f"sector_match:{company_sector}")

    total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match
    if total_overlap == 0.0:
        # No exposure on any dimension: neutral, zero-confidence record.
        return MacroImpactRecord(
            event_id=event.event_id,
            company_id=profile.company_id,
            ticker="",
            macro_impact_score=0.0,
            impact_direction="neutral",
            contributing_factors=[],
            confidence=0.0,
            computed_at=now,
        )

    # --- Severity-weighted exposure ---------------------------------------
    # Accept either a plain severity string or an Enum member, mirroring the
    # tier normalisation below. Without this, a non-str Enum severity would
    # silently miss SEVERITY_WEIGHTS and be scored with the LOW weight.
    severity = getattr(event.severity, "value", event.severity)
    severity_weight = SEVERITY_WEIGHTS.get(severity, 0.25)

    if probabilistic:
        exposure = _compute_multiplicative_exposure(
            geo_overlap, supply_overlap, commodity_overlap, sector_match,
        )
    else:
        exposure = _compute_linear_exposure(
            geo_overlap, supply_overlap, commodity_overlap, sector_match,
        )
    raw_score = severity_weight * exposure

    # --- Resilience ---------------------------------------------------------
    # Only multi-region events count as international; ``or ()`` tolerates a
    # missing (None) region list, which the overlap helpers already accept.
    is_international = len(event.affected_regions or ()) > 1
    tier = profile.market_position_tier
    if isinstance(tier, MarketPositionTier):
        tier = tier.value
    final_score = apply_resilience_modifier(raw_score, tier, is_international)

    # --- Direction and factors -----------------------------------------------
    direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types)
    all_factors = list(contributing)
    if pos_factors:
        all_factors.append(f"positive_types:{','.join(pos_factors)}")
    if neg_factors:
        all_factors.append(f"negative_types:{','.join(neg_factors)}")

    # Confidence: event confidence scaled by (total_overlap + 0.3), where
    # that factor is capped at 1.0; the product is capped at 1.0 as well.
    confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0)

    return MacroImpactRecord(
        event_id=event.event_id,
        company_id=profile.company_id,
        ticker="",
        macro_impact_score=round(min(final_score, 1.0), 6),
        impact_direction=direction,
        contributing_factors=all_factors,
        confidence=round(confidence, 6),
        computed_at=now,
    )
# ---------------------------------------------------------------------------
# Default profile builder
# ---------------------------------------------------------------------------
def build_default_profile(
    sector: str,
    industry: str,
    market_cap_bucket: str,
) -> ExposureProfileSchema:
    """Return an inferred ExposureProfile for a company with no manual profile.

    Defaults are derived as follows:

    * market position tier from *market_cap_bucket*:
      large_cap → global_leader, mid_cap → multinational,
      small_cap → regional, micro_cap → domestic, anything else → regional;
    * geographic revenue mix from the per-sector table, or a generic mix for
      unknown sectors. Its regions double as the supply-chain regions, and
      its first three regions become the regulatory jurisdictions;
    * key input commodities from ``_infer_commodities(sector, industry)``;
    * export dependency by tier: 0.5 / 0.35 / 0.15 / 0.05.

    Args:
        sector: GICS sector name.
        industry: Industry name, forwarded to the commodity inference.
        market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap.

    Returns:
        An ExposureProfileSchema with an empty ``company_id``,
        ``source='inferred'``, ``confidence=0.5`` and ``version=1``.
    """
    tier_value = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value)
    revenue_mix = _SECTOR_DEFAULT_GEO.get(sector, _DEFAULT_GEO)
    regions = list(revenue_mix)

    export_share_by_tier = {
        MarketPositionTier.GLOBAL_LEADER.value: 0.5,
        MarketPositionTier.MULTINATIONAL.value: 0.35,
        MarketPositionTier.REGIONAL.value: 0.15,
        MarketPositionTier.DOMESTIC.value: 0.05,
    }

    return ExposureProfileSchema(
        company_id="",
        # Copy so the shared module-level table is never handed out.
        geographic_revenue_mix=dict(revenue_mix),
        supply_chain_regions=regions,
        key_input_commodities=_infer_commodities(sector, industry),
        regulatory_jurisdictions=regions[:3],
        market_position_tier=MarketPositionTier(tier_value),
        export_dependency_pct=export_share_by_tier.get(tier_value, 0.15),
        source="inferred",
        confidence=0.5,
        version=1,
    )
def _infer_commodities(sector: str, industry: str) -> list[str]:
"""Infer key input commodities from sector and industry."""
sector_commodities: dict[str, list[str]] = {
"Energy": ["crude_oil", "natural_gas"],
"Materials": ["copper", "steel", "lithium"],
"Industrials": ["steel", "copper"],
"Information Technology": ["semiconductors", "lithium"],
"Consumer Staples": ["wheat", "corn"],
"Consumer Discretionary": ["steel", "semiconductors"],
"Health Care": [],
"Financials": [],
"Communication Services": [],
"Utilities": ["natural_gas"],
"Real Estate": ["steel"],
}
return sector_commodities.get(sector, [])
# ---------------------------------------------------------------------------
# Conditional macro signal integration (Requirements: 11.1-11.5)
# ---------------------------------------------------------------------------
def compute_conditional_macro_modifier(
    company_strength: float,
    company_direction: str,
    macro_impact: float,
    macro_direction: str,
) -> float:
    """Return the multiplicative macro modifier for conditional integration.

    When both company and macro signals exist, macro acts as a modifier:
        S_adjusted = S_company · clamp(1 + M_macro · sign_alignment, 0.5, 1.5)

    ``sign_alignment`` is +1 when the company and macro directions point the
    same way and -1 when they point opposite ways. It is 0 when either
    direction is neutral, mixed or unrecognised, which leaves the modifier
    at 1.0 for a finite ``macro_impact``.

    Args:
        company_strength: Company-level signal strength. Accepted for
            interface compatibility; the current formula does not use it.
        company_direction: Company trend direction (bullish/bearish/neutral/mixed).
        macro_impact: Normalized macro impact score in [0, 1].
        macro_direction: Macro impact direction (positive/negative/mixed/neutral).

    Returns:
        The multiplicative modifier in [0.5, 1.5].

    Requirements: 11.1, 11.2
    """
    polarity = {
        "bullish": 1,
        "positive": 1,
        "bearish": -1,
        "negative": -1,
    }
    # Product of the two signs: +1 aligned, -1 opposed, 0 if either is 0.
    sign_alignment = float(
        polarity.get(company_direction, 0) * polarity.get(macro_direction, 0)
    )
    raw_modifier = 1.0 + macro_impact * sign_alignment
    return max(0.5, min(1.5, raw_modifier))
def integrate_macro_signals(
    company_signals: list,
    macro_signals: list,
    company_direction: str,
    macro_impacts: list,
    ticker: str = "",
    *,
    probabilistic: bool = False,
    macro_signal_weight: float = 0.3,
) -> tuple[list, float]:
    """Merge macro-layer signals into the company-layer signal list.

    Heuristic mode (``probabilistic=False``): the two lists are concatenated,
    company signals first, and the modifier is 1.0.

    Probabilistic mode:

    * both present: the macro layer is folded into one multiplicative
      modifier from ``compute_conditional_macro_modifier``. The modifier is
      built from the mean ``macro_impact_score`` over *macro_impacts* and the
      direction with the largest summed score. Shallow copies of the company
      signals, with ``impact_score`` scaled by the modifier, are returned.
      The macro signals are not included, and the originals are not mutated.
    * macro only: the macro signals are returned unchanged (modifier 1.0).
    * company only, or neither: the company signals are returned unchanged
      (modifier 1.0).

    NOTE(review): ``macro_signal_weight`` only appears in the macro-only log
    line; it is not applied to the returned signals here. Presumably the
    weighting happens where macro signals are built — confirm.

    Args:
        company_signals: WeightedSignal list from company layer.
        macro_signals: WeightedSignal list from macro layer.
        company_direction: Derived company trend direction string.
        macro_impacts: Objects exposing ``macro_impact_score`` and
            ``impact_direction`` (missing attributes default to 0.0 / "neutral").
        ticker: Ticker symbol for logging.
        probabilistic: Use conditional modifier when True.
        macro_signal_weight: Weight reported for the macro-only fallback.

    Returns:
        Tuple of (merged_signals, macro_modifier_applied). The second element
        is 1.0 whenever no modifier was applied.

    Requirements: 11.1, 11.2, 11.3, 11.4, 11.5
    """
    if not probabilistic:
        # Heuristic mode: simple additive merge (current behaviour).
        return [*company_signals, *macro_signals], 1.0

    have_company = len(company_signals) > 0
    have_macro = len(macro_signals) > 0

    if not have_macro:
        # Company-only (or empty) input: no modification (Req 11.4).
        logger.info("Company-only signals for %s: macro modifier=1.0", ticker)
        return list(company_signals), 1.0

    if not have_company:
        # Macro-only fallback (Req 11.3).
        logger.info(
            "Macro-only fallback for %s: using additive merge with weight %.2f",
            ticker, macro_signal_weight,
        )
        return list(macro_signals), 1.0

    # Both layers present: summarise the macro impacts.
    impact_sum = 0.0
    weight_by_direction: dict[str, float] = {}
    for impact in macro_impacts:
        score = getattr(impact, "macro_impact_score", 0.0)
        label = getattr(impact, "impact_direction", "neutral")
        impact_sum += score
        weight_by_direction[label] = weight_by_direction.get(label, 0.0) + score
    avg_macro_impact = impact_sum / len(macro_impacts) if macro_impacts else impact_sum

    # Dominant direction by total score; ties go to the first one seen.
    macro_direction = (
        max(weight_by_direction, key=weight_by_direction.get)
        if weight_by_direction
        else "neutral"
    )

    modifier = compute_conditional_macro_modifier(
        company_strength=0.0,  # not used in current formula
        company_direction=company_direction,
        macro_impact=avg_macro_impact,
        macro_direction=macro_direction,
    )
    logger.info(
        "Macro modifier for %s: %.4f (avg_impact=%.4f, macro_dir=%s, company_dir=%s)",
        ticker, modifier, avg_macro_impact, macro_direction, company_direction,
    )

    # Scale shallow copies so caller-owned signal objects stay untouched.
    from copy import copy

    scaled: list = []
    for original in company_signals:
        clone = copy(original)
        clone.impact_score = original.impact_score * modifier
        scaled.append(clone)
    return scaled, modifier
# ---------------------------------------------------------------------------
# PostgreSQL persistence
# ---------------------------------------------------------------------------
async def persist_macro_impact_record(
    pool: asyncpg.Pool,
    record: MacroImpactRecord,
) -> str:
    """Insert one MacroImpactRecord into ``macro_impact_records``.

    The query casts ``event_id`` and ``company_id`` to ``uuid``. The
    contributing factors are serialised to JSON and cast to ``jsonb``. An
    INFO line is logged once the insert returns.

    Args:
        pool: asyncpg pool used for a single ``fetchval`` call.
        record: The record to store.

    Returns:
        The generated row ``id``, converted to ``str``.
    """
    factors_json = json.dumps(record.contributing_factors)
    new_row_id = await pool.fetchval(
        """INSERT INTO macro_impact_records
            (event_id, company_id, ticker, macro_impact_score,
             impact_direction, contributing_factors, confidence, computed_at)
        VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6::jsonb, $7, $8)
        RETURNING id""",
        record.event_id,
        record.company_id,
        record.ticker,
        record.macro_impact_score,
        record.impact_direction,
        factors_json,
        record.confidence,
        record.computed_at,
    )
    logger.info(
        "Persisted macro impact record for event=%s company=%s score=%.4f direction=%s",
        record.event_id,
        record.company_id,
        record.macro_impact_score,
        record.impact_direction,
    )
    return str(new_row_id)
async def persist_macro_impact_records(
    pool: asyncpg.Pool,
    records: list[MacroImpactRecord],
) -> list[str]:
    """Persist every record with a positive score and return the new row ids.

    Records whose ``macro_impact_score`` is not strictly positive are
    skipped, which includes the neutral zero-overlap results. Inserts run
    one at a time in input order through ``persist_macro_impact_record``.
    No enclosing transaction is opened here.
    """
    saved: list[str] = []
    for rec in records:
        if not rec.macro_impact_score > 0.0:
            continue
        saved.append(await persist_macro_impact_record(pool, rec))
    return saved
# ---------------------------------------------------------------------------
# Low-confidence event exclusion (Requirements: 10.1)
# ---------------------------------------------------------------------------
def filter_low_confidence_events(
    events: list[GlobalEvent],
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
) -> list[GlobalEvent]:
    """Drop events whose confidence is below *confidence_threshold*.

    Each dropped event is logged at INFO level with its confidence and the
    threshold. An event exactly at the threshold is kept, and input order
    is preserved.

    Args:
        events: List of GlobalEvent classifications.
        confidence_threshold: Minimum confidence for inclusion (default 0.4).

    Returns:
        The events that pass the confidence threshold.

    Requirements: 10.1
    """
    kept: list[GlobalEvent] = []
    for candidate in events:
        # Written as "not <" so the keep rule is the exact complement of the
        # exclusion rule (e.g. a NaN confidence is kept).
        if not candidate.confidence < confidence_threshold:
            kept.append(candidate)
            continue
        logger.info(
            "Excluding low-confidence event %s: confidence=%.3f < threshold=%.3f",
            candidate.event_id,
            candidate.confidence,
            confidence_threshold,
        )
    return kept
# ---------------------------------------------------------------------------
# Accelerated decay for stale short-term events (Requirements: 10.2)
# ---------------------------------------------------------------------------
def compute_standard_recency_decay(
    age_hours: float,
    half_life_hours: float = 168.0,
) -> float:
    """Compute standard exponential recency decay.

    decay = 0.5 ** (age_hours / half_life_hours), so the weight halves
    exactly every ``half_life_hours``. This equals
    exp(-ln 2 · age_hours / half_life_hours).

    Args:
        age_hours: Age of the event in hours.
        half_life_hours: Half-life for the decay function (default 7 days).

    Returns:
        For a positive half-life, a decay factor in [0, 1]: 1.0 for
        non-positive ages, 0.5 at one half-life, and underflowing to 0.0
        for very large ages.
    """
    if age_hours <= 0:
        return 1.0
    # Exact half-life semantics. A truncated ln 2 (0.693) would leave about
    # 0.50007 rather than 0.5 at one half-life.
    return 0.5 ** (age_hours / half_life_hours)
def apply_accelerated_decay(
    age_hours: float,
    estimated_duration: str,
    staleness_hours: float = DEFAULT_SHORT_TERM_STALENESS_HOURS,
    half_life_hours: float = 168.0,
) -> float:
    """Return the recency weight, decaying stale short-term events faster.

    The standard recency decay is computed first. For events whose
    ``estimated_duration`` is ``"short_term"`` and whose age exceeds
    *staleness_hours* (default 48h), that decay is multiplied by
    ``ACCELERATED_DECAY_MULTIPLIER``. The result is therefore strictly below
    the standard decay whenever the standard decay is positive. All other
    events get the standard decay unchanged.

    Args:
        age_hours: Age of the event in hours.
        estimated_duration: Event's estimated_duration field.
        staleness_hours: Hours after which short-term events get accelerated decay.
        half_life_hours: Half-life for the standard decay function.

    Returns:
        Effective signal weight in [0, 1].

    Requirements: 10.2
    """
    baseline = compute_standard_recency_decay(age_hours, half_life_hours)
    stale_short_term = estimated_duration == "short_term" and age_hours > staleness_hours
    if not stale_short_term:
        return baseline

    faded = baseline * ACCELERATED_DECAY_MULTIPLIER
    logger.debug(
        "Accelerated decay for short_term event: age=%.1fh, "
        "standard=%.4f, accelerated=%.4f",
        age_hours, baseline, faded,
    )
    return faded