feat: competitive intelligence & historical pattern matching layer
This commit is contained in:
@@ -0,0 +1,741 @@
|
||||
"""Interpolation engine — macro-to-company impact scoring.
|
||||
|
||||
Computes per-company macro impact scores by evaluating overlap between
|
||||
global event classifications and company exposure profiles. Produces
|
||||
MacroImpactRecord objects that feed into the aggregation engine as
|
||||
additional weighted signals.
|
||||
|
||||
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 3.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.extractor.event_classifier import GlobalEvent
|
||||
from services.shared.schemas import (
|
||||
ExposureProfileSchema,
|
||||
MarketPositionTier,
|
||||
SeverityLevel,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("interpolation")
|
||||
|
||||
# ---------------------------------------------------------------------------
# Tunable defaults
# ---------------------------------------------------------------------------

# Default minimum event confidence used by filter_low_confidence_events.
DEFAULT_CONFIDENCE_THRESHOLD = 0.4
# Default age (hours) after which a short_term event is treated as stale.
DEFAULT_SHORT_TERM_STALENESS_HOURS = 48
# Extra factor multiplied onto the standard recency decay for stale events.
ACCELERATED_DECAY_MULTIPLIER = 0.5

# ---------------------------------------------------------------------------
# Scoring tables
# ---------------------------------------------------------------------------

# Multiplier applied to the weighted overlap sum, keyed by severity value.
SEVERITY_WEIGHTS: dict[str, float] = {
    SeverityLevel.CRITICAL.value: 1.0,
    SeverityLevel.HIGH.value: 0.75,
    SeverityLevel.MODERATE.value: 0.5,
    SeverityLevel.LOW.value: 0.25,
}

# Relative weight of each overlap component in the scoring formula (sum: 1.0).
GEO_WEIGHT = 0.35
SUPPLY_WEIGHT = 0.25
COMMODITY_WEIGHT = 0.25
SECTOR_WEIGHT = 0.15

# Score multiplier per market position tier, applied to international events:
# values below 1 dampen the score, values above 1 amplify it.
RESILIENCE_MODIFIERS: dict[str, float] = {
    MarketPositionTier.GLOBAL_LEADER.value: 0.7,
    MarketPositionTier.MULTINATIONAL.value: 0.85,
    MarketPositionTier.REGIONAL.value: 1.0,
    MarketPositionTier.DOMESTIC.value: 1.2,
}

# ---------------------------------------------------------------------------
# Event type polarity
# ---------------------------------------------------------------------------

# Event types counted as negative factors.
_NEGATIVE_EVENT_TYPES = frozenset(
    (
        "supply_disruption",
        "cost_increase",
        "regulatory_pressure",
        "geopolitical_risk",
        "trade_barrier",
    )
)

# Event types counted as positive factors.
_POSITIVE_EVENT_TYPES = frozenset(("demand_shift",))

# Event types counted on both the positive and the negative side.
_AMBIGUOUS_EVENT_TYPES = frozenset(("commodity_shock", "currency_impact"))

# ---------------------------------------------------------------------------
# Inputs for inferred (default) exposure profiles
# ---------------------------------------------------------------------------

# Market cap bucket -> market position tier value.
_CAP_TO_TIER: dict[str, str] = {
    "large_cap": MarketPositionTier.GLOBAL_LEADER.value,
    "mid_cap": MarketPositionTier.MULTINATIONAL.value,
    "small_cap": MarketPositionTier.REGIONAL.value,
    "micro_cap": MarketPositionTier.DOMESTIC.value,
}

# GICS sector name -> assumed geographic revenue mix (region code -> share).
_SECTOR_DEFAULT_GEO: dict[str, dict[str, float]] = {
    "Information Technology": {
        "US": 0.45, "CN": 0.15, "EU": 0.15, "JP": 0.10, "KR": 0.15,
    },
    "Health Care": {"US": 0.50, "EU": 0.25, "JP": 0.10, "CN": 0.15},
    "Financials": {"US": 0.55, "EU": 0.20, "GB": 0.15, "JP": 0.10},
    "Energy": {"US": 0.30, "SA": 0.20, "RU": 0.15, "CA": 0.15, "AE": 0.20},
    "Materials": {"US": 0.25, "CN": 0.25, "AU": 0.20, "BR": 0.15, "IN": 0.15},
    "Industrials": {"US": 0.40, "DE": 0.15, "CN": 0.15, "JP": 0.15, "KR": 0.15},
    "Consumer Discretionary": {
        "US": 0.45, "CN": 0.20, "EU": 0.15, "JP": 0.10, "IN": 0.10,
    },
    "Consumer Staples": {
        "US": 0.45, "EU": 0.20, "CN": 0.15, "IN": 0.10, "BR": 0.10,
    },
    "Communication Services": {
        "US": 0.50, "CN": 0.15, "EU": 0.15, "JP": 0.10, "IN": 0.10,
    },
    "Utilities": {"US": 0.70, "EU": 0.15, "JP": 0.15},
    "Real Estate": {"US": 0.60, "CN": 0.15, "EU": 0.15, "JP": 0.10},
}

# Revenue mix used when the sector has no entry in _SECTOR_DEFAULT_GEO.
_DEFAULT_GEO = {"US": 0.50, "EU": 0.20, "CN": 0.15, "JP": 0.15}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MacroImpactRecord dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class MacroImpactRecord:
    """A computed macro impact score for a specific company-event pair.

    Attributes:
        event_id: Identifier of the global event that was scored.
        company_id: Identifier of the company the score applies to.
        ticker: Company ticker; the scoring functions in this module leave
            it empty.
        macro_impact_score: Impact magnitude in [0, 1].
        impact_direction: One of "neutral", "positive", "negative" or
            "mixed". "neutral" is the default and is what the scoring
            functions emit when the event and the profile do not overlap.
        contributing_factors: Human-readable "name:value" strings describing
            which overlaps and event types produced the score.
        confidence: Confidence in the score, in [0, 1].
        computed_at: Timezone-aware (UTC) timestamp of the computation.
    """

    event_id: str = ""
    company_id: str = ""
    ticker: str = ""
    macro_impact_score: float = 0.0  # [0, 1]
    impact_direction: str = "neutral"  # neutral|positive|negative|mixed
    # default_factory gives every record its own list instead of a shared one.
    contributing_factors: list[str] = field(default_factory=list)
    confidence: float = 0.5  # [0, 1]
    computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlap computation functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_geographic_overlap(
    event_regions: list[str],
    revenue_mix: dict[str, float],
) -> float:
    """Return the share of company revenue earned in event-affected regions.

    Region codes are compared case-insensitively. The revenue share of every
    entry in ``revenue_mix`` whose region is among ``event_regions`` is
    added up.

    Args:
        event_regions: Region codes from the global event.
        revenue_mix: Company's geographic_revenue_mix (region -> pct).

    Returns:
        The summed revenue share of the matching regions, clamped to
        [0, 1]; 0.0 when either argument is empty.
    """
    if not (event_regions and revenue_mix):
        return 0.0

    affected = {code.upper() for code in event_regions}
    exposed_share = sum(
        (share for code, share in revenue_mix.items() if code.upper() in affected),
        0.0,
    )
    return min(max(exposed_share, 0.0), 1.0)
|
||||
|
||||
|
||||
def compute_supply_chain_overlap(
    event_regions: list[str],
    supply_regions: list[str],
) -> float:
    """Return the fraction of supply chain regions hit by the event.

    Both inputs are de-duplicated case-insensitively before comparison, so
    the denominator is the number of distinct supply chain regions.

    Args:
        event_regions: Region codes from the global event.
        supply_regions: Company's supply_chain_regions.

    Returns:
        Intersection ratio in [0, 1]; 0.0 when either argument is empty.
    """
    if not (event_regions and supply_regions):
        return 0.0

    company_regions = {code.upper() for code in supply_regions}
    hit_regions = company_regions.intersection(code.upper() for code in event_regions)
    return len(hit_regions) / len(company_regions)
|
||||
|
||||
|
||||
def compute_commodity_overlap(
    event_commodities: list[str],
    company_commodities: list[str],
) -> float:
    """Return the fraction of the company's key commodities hit by the event.

    Commodity identifiers are lower-cased and de-duplicated before
    comparison, so the denominator is the number of distinct company
    commodities.

    Args:
        event_commodities: Commodity identifiers from the global event.
        company_commodities: Company's key_input_commodities.

    Returns:
        Intersection ratio in [0, 1]; 0.0 when either argument is empty.
    """
    if not (event_commodities and company_commodities):
        return 0.0

    company_inputs = {name.lower() for name in company_commodities}
    hit_inputs = company_inputs.intersection(name.lower() for name in event_commodities)
    return len(hit_inputs) / len(company_inputs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resilience modifier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_resilience_modifier(
    raw_score: float,
    tier: str,
    event_is_international: bool = True,
) -> float:
    """Scale a raw impact score by the company's market position tier.

    International events are multiplied by the tier's entry in
    RESILIENCE_MODIFIERS (1.0 when the tier is not listed). Events that are
    not international are only clamped.

    Args:
        raw_score: The raw impact score before resilience adjustment.
        tier: Market position tier value.
        event_is_international: Whether the event affects multiple countries.

    Returns:
        The adjusted score clamped to [0, 1].
    """
    adjusted = raw_score
    if event_is_international:
        adjusted = raw_score * RESILIENCE_MODIFIERS.get(tier, 1.0)
    return min(max(adjusted, 0.0), 1.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Impact direction determination
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _determine_impact_direction(
|
||||
event_types: list[str],
|
||||
) -> tuple[str, list[str], list[str]]:
|
||||
"""Determine impact direction from event types.
|
||||
|
||||
Returns:
|
||||
Tuple of (direction, positive_factors, negative_factors).
|
||||
"""
|
||||
positive_factors: list[str] = []
|
||||
negative_factors: list[str] = []
|
||||
|
||||
for et in event_types:
|
||||
if et in _NEGATIVE_EVENT_TYPES:
|
||||
negative_factors.append(et)
|
||||
elif et in _POSITIVE_EVENT_TYPES:
|
||||
positive_factors.append(et)
|
||||
elif et in _AMBIGUOUS_EVENT_TYPES:
|
||||
# Ambiguous types contribute to both sides
|
||||
positive_factors.append(et)
|
||||
negative_factors.append(et)
|
||||
|
||||
has_positive = len(positive_factors) > 0
|
||||
has_negative = len(negative_factors) > 0
|
||||
|
||||
if has_positive and has_negative:
|
||||
return "mixed", positive_factors, negative_factors
|
||||
elif has_positive:
|
||||
return "positive", positive_factors, negative_factors
|
||||
elif has_negative:
|
||||
return "negative", positive_factors, negative_factors
|
||||
else:
|
||||
return "negative", positive_factors, negative_factors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core scoring function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_macro_impact(
    event: GlobalEvent,
    profile: ExposureProfileSchema,
) -> MacroImpactRecord:
    """Compute the macro impact of a global event on a company.

    Scoring formula:
        raw_score = severity_weight * (
            0.35 * geographic_overlap +
            0.25 * supply_chain_overlap +
            0.25 * commodity_overlap +
            0.15 * sector_match
        )
        final_score = apply_resilience_modifier(raw_score, tier, is_international)

    ExposureProfileSchema carries no sector, so ``sector_match`` is always
    0.0 here. Callers that know the company's sector should call
    compute_macro_impact_with_sector instead.

    Args:
        event: The classified global event.
        profile: The company's exposure profile.

    Returns:
        A MacroImpactRecord with the computed score and metadata.
    """
    # An empty sector disables sector matching, which reproduces the
    # sector-less formula without a second copy of the scoring code.
    return compute_macro_impact_with_sector(event, profile, company_sector="")
|
||||
|
||||
|
||||
def compute_macro_impact_with_sector(
    event: GlobalEvent,
    profile: ExposureProfileSchema,
    company_sector: str = "",
) -> MacroImpactRecord:
    """Compute macro impact with explicit sector matching.

    Scoring formula:
        raw_score = severity_weight * (
            GEO_WEIGHT * geographic_overlap +
            SUPPLY_WEIGHT * supply_chain_overlap +
            COMMODITY_WEIGHT * commodity_overlap +
            SECTOR_WEIGHT * sector_match
        )
        final_score = apply_resilience_modifier(raw_score, tier, is_international)

    Args:
        event: The classified global event.
        profile: The company's exposure profile.
        company_sector: The company's GICS sector name. An empty string
            disables sector matching.

    Returns:
        A MacroImpactRecord with the computed score and metadata. When no
        component overlaps, the record has score 0.0, direction "neutral",
        confidence 0.0 and no contributing factors.
    """
    now = datetime.now(timezone.utc)

    geo_overlap = compute_geographic_overlap(
        event.affected_regions,
        profile.geographic_revenue_mix,
    )
    supply_overlap = compute_supply_chain_overlap(
        event.affected_regions,
        profile.supply_chain_regions,
    )
    commodity_overlap = compute_commodity_overlap(
        event.affected_commodities,
        profile.key_input_commodities,
    )

    # Sector match is all-or-nothing: 1.0 if any affected sector equals the
    # company's sector (case- and surrounding-whitespace-insensitive).
    sector_match = 0.0
    if company_sector and event.affected_sectors:
        wanted_sector = company_sector.lower().strip()
        if any(s.lower().strip() == wanted_sector for s in event.affected_sectors):
            sector_match = 1.0

    contributing: list[str] = []
    if geo_overlap > 0:
        contributing.append(f"geographic_overlap:{geo_overlap:.3f}")
    if supply_overlap > 0:
        contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}")
    if commodity_overlap > 0:
        contributing.append(f"commodity_overlap:{commodity_overlap:.3f}")
    if sector_match > 0:
        contributing.append(f"sector_match:{company_sector}")

    total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match
    if total_overlap == 0.0:
        # Nothing links the event to this company: report a neutral record.
        return MacroImpactRecord(
            event_id=event.event_id,
            company_id=profile.company_id,
            ticker="",
            macro_impact_score=0.0,
            impact_direction="neutral",
            contributing_factors=[],
            confidence=0.0,
            computed_at=now,
        )

    # SEVERITY_WEIGHTS is keyed by the enum *values*, so normalise an enum
    # member the same way the tier is normalised below. Unknown severities
    # fall back to the lowest weight.
    severity = event.severity
    if isinstance(severity, SeverityLevel):
        severity = severity.value
    severity_weight = SEVERITY_WEIGHTS.get(severity, 0.25)

    raw_score = severity_weight * (
        GEO_WEIGHT * geo_overlap
        + SUPPLY_WEIGHT * supply_overlap
        + COMMODITY_WEIGHT * commodity_overlap
        + SECTOR_WEIGHT * sector_match
    )

    # An event listing more than one affected region is treated as
    # international.
    is_international = len(event.affected_regions) > 1

    tier = profile.market_position_tier
    if isinstance(tier, MarketPositionTier):
        tier = tier.value
    final_score = apply_resilience_modifier(raw_score, tier, is_international)

    direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types)

    all_factors = list(contributing)
    if pos_factors:
        all_factors.append(f"positive_types:{','.join(pos_factors)}")
    if neg_factors:
        all_factors.append(f"negative_types:{','.join(neg_factors)}")

    # Confidence scales the event's own confidence by overlap strength; the
    # +0.3 keeps weak overlaps from driving confidence to zero.
    confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0)

    return MacroImpactRecord(
        event_id=event.event_id,
        company_id=profile.company_id,
        ticker="",
        macro_impact_score=round(min(final_score, 1.0), 6),
        impact_direction=direction,
        contributing_factors=all_factors,
        confidence=round(confidence, 6),
        computed_at=now,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default profile builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_default_profile(
    sector: str,
    industry: str,
    market_cap_bucket: str,
) -> ExposureProfileSchema:
    """Build an inferred ExposureProfile for a company without a manual one.

    The geographic revenue mix comes from the sector defaults (falling back
    to _DEFAULT_GEO), and the market position tier from the market cap
    bucket:
        large_cap → global_leader
        mid_cap → multinational
        small_cap → regional
        micro_cap → domestic
    Unknown buckets map to regional.

    Args:
        sector: GICS sector name.
        industry: Industry name, forwarded to the commodity inference.
        market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap.

    Returns:
        An ExposureProfileSchema with source='inferred', confidence 0.5 and
        an empty company_id.
    """
    tier_value = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value)
    revenue_mix = _SECTOR_DEFAULT_GEO.get(sector, _DEFAULT_GEO)

    # Without better data, assume the supply chain spans the revenue regions.
    region_codes = list(revenue_mix)

    export_share_by_tier = {
        MarketPositionTier.GLOBAL_LEADER.value: 0.5,
        MarketPositionTier.MULTINATIONAL.value: 0.35,
        MarketPositionTier.REGIONAL.value: 0.15,
        MarketPositionTier.DOMESTIC.value: 0.05,
    }

    return ExposureProfileSchema(
        company_id="",
        # Copy so the caller cannot mutate the shared module-level defaults.
        geographic_revenue_mix=dict(revenue_mix),
        supply_chain_regions=region_codes,
        key_input_commodities=_infer_commodities(sector, industry),
        # The first three revenue regions stand in for regulatory jurisdictions.
        regulatory_jurisdictions=region_codes[:3],
        market_position_tier=MarketPositionTier(tier_value),
        export_dependency_pct=export_share_by_tier.get(tier_value, 0.15),
        source="inferred",
        confidence=0.5,
        version=1,
    )
|
||||
|
||||
|
||||
def _infer_commodities(sector: str, industry: str) -> list[str]:
|
||||
"""Infer key input commodities from sector and industry."""
|
||||
sector_commodities: dict[str, list[str]] = {
|
||||
"Energy": ["crude_oil", "natural_gas"],
|
||||
"Materials": ["copper", "steel", "lithium"],
|
||||
"Industrials": ["steel", "copper"],
|
||||
"Information Technology": ["semiconductors", "lithium"],
|
||||
"Consumer Staples": ["wheat", "corn"],
|
||||
"Consumer Discretionary": ["steel", "semiconductors"],
|
||||
"Health Care": [],
|
||||
"Financials": [],
|
||||
"Communication Services": [],
|
||||
"Utilities": ["natural_gas"],
|
||||
"Real Estate": ["steel"],
|
||||
}
|
||||
return sector_commodities.get(sector, [])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PostgreSQL persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def persist_macro_impact_record(
    pool: asyncpg.Pool,
    record: MacroImpactRecord,
) -> str:
    """Insert one MacroImpactRecord into the macro_impact_records table.

    The event and company identifiers are cast to ``uuid`` by the statement
    and the contributing factors are stored as a JSON array.

    Args:
        pool: Connection pool used to run the INSERT.
        record: The record to store.

    Returns:
        The inserted row's id, as a string.
    """
    insert_sql = """INSERT INTO macro_impact_records
           (event_id, company_id, ticker, macro_impact_score,
            impact_direction, contributing_factors, confidence, computed_at)
           VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6::jsonb, $7, $8)
           RETURNING id"""
    factors_json = json.dumps(record.contributing_factors)

    new_row_id = await pool.fetchval(
        insert_sql,
        record.event_id,
        record.company_id,
        record.ticker,
        record.macro_impact_score,
        record.impact_direction,
        factors_json,
        record.confidence,
        record.computed_at,
    )
    logger.info(
        "Persisted macro impact record for event=%s company=%s score=%.4f direction=%s",
        record.event_id,
        record.company_id,
        record.macro_impact_score,
        record.impact_direction,
    )
    return str(new_row_id)
|
||||
|
||||
|
||||
async def persist_macro_impact_records(
    pool: asyncpg.Pool,
    records: list[MacroImpactRecord],
) -> list[str]:
    """Persist every record with a positive score, one at a time.

    Records whose macro_impact_score is not greater than 0.0 are skipped.

    Args:
        pool: Connection pool used for the inserts.
        records: Records to store.

    Returns:
        The ids of the inserted rows, in input order.
    """
    return [
        await persist_macro_impact_record(pool, item)
        for item in records
        if item.macro_impact_score > 0.0
    ]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-confidence event exclusion (Requirements: 10.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def filter_low_confidence_events(
    events: list[GlobalEvent],
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
) -> list[GlobalEvent]:
    """Drop events whose confidence is below the threshold.

    Each dropped event is logged at INFO level with its confidence and the
    threshold, so the exclusion can be traced later.

    Args:
        events: List of GlobalEvent classifications.
        confidence_threshold: Minimum confidence for inclusion (default 0.4).

    Returns:
        The events that pass the confidence threshold, in input order.

    Requirements: 10.1
    """
    kept: list[GlobalEvent] = []
    for candidate in events:
        if not (candidate.confidence < confidence_threshold):
            kept.append(candidate)
            continue
        logger.info(
            "Excluding low-confidence event %s: confidence=%.3f < threshold=%.3f",
            candidate.event_id,
            candidate.confidence,
            confidence_threshold,
        )
    return kept
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Accelerated decay for stale short-term events (Requirements: 10.2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_standard_recency_decay(
    age_hours: float,
    half_life_hours: float = 168.0,
) -> float:
    """Compute standard exponential recency decay.

    The weight halves every ``half_life_hours``:
    ``weight = 0.5 ** (age_hours / half_life_hours)``.

    Args:
        age_hours: Age of the event in hours. Zero or negative ages are
            treated as "brand new" and get full weight.
        half_life_hours: Half-life for the decay function (default 7 days).

    Returns:
        Decay factor in (0, 1]. For extremely old events the value can
        underflow to 0.0.

    Raises:
        ValueError: If ``half_life_hours`` is not positive while
            ``age_hours`` is positive.
    """
    if age_hours <= 0:
        return 1.0
    if half_life_hours <= 0:
        # A zero half-life would divide by zero and a negative one would
        # make the weight grow with age; neither is a meaningful decay.
        raise ValueError("half_life_hours must be positive")
    # An exact power of one half, not exp(-0.693 * t): the rounded ln 2
    # leaves the weight slightly above 0.5 at one half-life.
    return 0.5 ** (age_hours / half_life_hours)
|
||||
|
||||
|
||||
def apply_accelerated_decay(
    age_hours: float,
    estimated_duration: str,
    staleness_hours: float = DEFAULT_SHORT_TERM_STALENESS_HOURS,
    half_life_hours: float = 168.0,
) -> float:
    """Return the recency weight, decayed faster for stale short-term events.

    An event is stale when its estimated_duration is "short_term" and it is
    strictly older than ``staleness_hours``. Stale events get the standard
    recency decay multiplied by ACCELERATED_DECAY_MULTIPLIER; all other
    events get the standard recency decay unchanged.

    Args:
        age_hours: Age of the event in hours.
        estimated_duration: Event's estimated_duration field.
        staleness_hours: Hours after which short-term events get accelerated decay.
        half_life_hours: Half-life for the standard decay function.

    Returns:
        Effective signal weight.

    Requirements: 10.2
    """
    baseline = compute_standard_recency_decay(age_hours, half_life_hours)

    is_stale_short_term = (
        estimated_duration == "short_term" and age_hours > staleness_hours
    )
    if not is_stale_short_term:
        return baseline

    faster = baseline * ACCELERATED_DECAY_MULTIPLIER
    logger.debug(
        "Accelerated decay for short_term event: age=%.1fh, "
        "standard=%.4f, accelerated=%.4f",
        age_hours, baseline, faster,
    )
    return faster
|
||||
@@ -1,4 +1,12 @@
|
||||
"""Aggregation worker entrypoint - polls Redis for aggregation jobs."""
|
||||
"""Aggregation worker entrypoint - polls Redis for aggregation jobs.
|
||||
|
||||
After computing trend summaries for a ticker, the worker also triggers
|
||||
competitive signal propagation for the ticker's competitors when the
|
||||
competitive layer is enabled. This ensures that document intelligence
|
||||
for one company produces competitive signals for related companies.
|
||||
|
||||
Requirements: 4.1, 5.1, 9.4
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@@ -8,8 +16,9 @@ import logging
|
||||
import asyncpg
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from services.aggregation.worker import aggregate_company
|
||||
from services.shared.config import load_config
|
||||
from services.aggregation.signal_propagation import propagate_signals
|
||||
from services.aggregation.worker import aggregate_company, fetch_competitive_enabled
|
||||
from services.shared.config import CompetitiveConfig, load_config
|
||||
from services.shared.logging import inject_trace_context, setup_logging
|
||||
from services.shared.redis_keys import (
|
||||
QUEUE_AGGREGATION,
|
||||
@@ -20,6 +29,92 @@ from services.shared.redis_keys import (
|
||||
logger = logging.getLogger("aggregation_main")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Recent document intelligence lookup
#
# Returns up to 50 of a ticker's impact records, newest document first,
# restricted to valid intelligence on documents that were not rejected.
# Run after aggregation to feed competitive signal propagation.
# ---------------------------------------------------------------------------

_RECENT_INTELLIGENCE_QUERY = """
    SELECT
        di.document_id,
        dir.catalyst_type,
        dir.impact_score
    FROM document_impact_records dir
    JOIN document_intelligence di ON di.id = dir.intelligence_id
    JOIN documents d ON d.id = di.document_id
    WHERE dir.ticker = $1
      AND di.validation_status = 'valid'
      AND d.status != 'rejected'
    ORDER BY d.published_at DESC
    LIMIT 50
"""

# Number of propagation attempts that have failed in a row; compared with the
# configured threshold to raise an operator alert (Requirement 9.4).
_propagation_consecutive_failures = 0
|
||||
|
||||
|
||||
async def _trigger_signal_propagation(
    pool: asyncpg.Pool,
    ticker: str,
    competitive_config: CompetitiveConfig,
) -> int:
    """Propagate a ticker's recent document signals to its competitors.

    Loads the ticker's recent document intelligence rows and calls
    propagate_signals once per row with a positive impact score. A failing
    row is logged and skipped. Once the module-wide count of consecutive
    failures reaches ``competitive_config.propagation_failure_threshold``,
    a critical alert is logged and the remaining rows are abandoned.

    Args:
        pool: Connection pool for the lookup and for propagation.
        ticker: Ticker whose documents are the signal source.
        competitive_config: Competitive-layer settings passed through to
            propagate_signals.

    Returns:
        The total number of competitive signals produced.
    """
    global _propagation_consecutive_failures

    recent_rows = await pool.fetch(_RECENT_INTELLIGENCE_QUERY, ticker)
    if not recent_rows:
        return 0

    produced = 0
    for entry in recent_rows:
        doc_id = str(entry["document_id"])
        catalyst = entry["catalyst_type"] or "other"
        score = float(entry["impact_score"] or 0.0)

        # Rows without a positive impact carry nothing worth propagating.
        if score <= 0.0:
            continue

        try:
            propagated = await propagate_signals(
                pool=pool,
                ticker=ticker,
                catalyst_type=catalyst,
                impact_score=score,
                document_id=doc_id,
                config=competitive_config,
            )
            produced += len(propagated)

            # Any success ends the current failure streak.
            _propagation_consecutive_failures = 0
        except Exception:
            _propagation_consecutive_failures += 1
            logger.exception(
                "Signal propagation failed for %s doc %s/%s",
                ticker, doc_id, catalyst,
            )
            threshold = competitive_config.propagation_failure_threshold
            if _propagation_consecutive_failures >= threshold:
                logger.critical(
                    "ALERT: Sustained signal propagation failures (%d consecutive). "
                    "Continuing with company-specific + macro signals only. "
                    "Operator action required.",
                    _propagation_consecutive_failures,
                )
                # Give up on the remaining rows for this ticker.
                break

    return produced
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("aggregation", level=config.log_level, json_output=config.json_logs)
|
||||
@@ -28,6 +123,7 @@ async def main() -> None:
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_AGGREGATION)
|
||||
rec_queue = queue_key(QUEUE_RECOMMENDATION)
|
||||
competitive_config = config.competitive
|
||||
logger.info("Aggregation worker started, polling %s", queue)
|
||||
|
||||
try:
|
||||
@@ -49,6 +145,32 @@ async def main() -> None:
|
||||
ticker, len(summaries),
|
||||
)
|
||||
|
||||
# Trigger competitive signal propagation after aggregation
|
||||
# (Requirement 4.1): When new document intelligence is
|
||||
# produced for a company, propagate signals to competitors.
|
||||
# Check toggle state from DB (same pattern as macro toggle).
|
||||
competitive_enabled = competitive_config.competitive_enabled
|
||||
db_toggle = await fetch_competitive_enabled(pool)
|
||||
if db_toggle is not None:
|
||||
competitive_enabled = db_toggle
|
||||
|
||||
if competitive_enabled:
|
||||
try:
|
||||
sig_count = await _trigger_signal_propagation(
|
||||
pool, ticker, competitive_config,
|
||||
)
|
||||
if sig_count > 0:
|
||||
logger.info(
|
||||
"Propagated %d competitive signals for %s",
|
||||
sig_count, ticker,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Signal propagation failed for %s — "
|
||||
"continuing with company+macro signals only",
|
||||
ticker,
|
||||
)
|
||||
|
||||
# Enqueue recommendation job for each window that produced a trend
|
||||
for summary in summaries:
|
||||
if summary.trend_strength > 0:
|
||||
|
||||
@@ -0,0 +1,414 @@
|
||||
"""Historical pattern mining for competitive intelligence.
|
||||
|
||||
Queries document_impact_records joined with trend_windows to find how
|
||||
similar catalyst types resolved historically for a company or its
|
||||
competitors. Produces HistoricalPattern objects consumed by the signal
|
||||
propagation engine and the aggregation worker.
|
||||
|
||||
Requirements: 3.1, 3.2, 3.3, 3.4, 3.5, 11.1, 11.2, 11.3, 11.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.shared.config import CompetitiveConfig
|
||||
from services.shared.schemas import MAJOR_DECISION_CATALYSTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_HORIZONS = ["1d", "7d", "30d"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class HistoricalPattern:
    """Aggregate view of how one catalyst type resolved in past trend windows.

    Attributes:
        source_ticker: Ticker whose documents carried the catalyst.
        target_ticker: Ticker whose trend windows were inspected.
        catalyst_type: Catalyst category the statistics describe.
        time_horizon: Horizon label, one of ``1d``, ``7d`` or ``30d``.
        sample_count: Number of samples behind the statistics.
        bullish_pct: Share of samples that resolved bullish, in [0, 1].
        bearish_pct: Share of samples that resolved bearish, in [0, 1].
        avg_strength: Mean trend strength of the samples, in [0, 1].
        avg_time_to_resolution: Mean time to resolution, in days.
        pattern_confidence: Confidence in the pattern, in [0, 1].
        data_start: Start of the period covered by the samples.
        data_end: End of the period covered by the samples.
        tier: ``major_corporate_decision`` or ``routine_signal``.
        insufficient_data: Flag marking an under-sampled pattern.
    """

    source_ticker: str
    target_ticker: str
    catalyst_type: str
    time_horizon: str
    sample_count: int
    bullish_pct: float
    bearish_pct: float
    avg_strength: float
    avg_time_to_resolution: float
    pattern_confidence: float
    data_start: datetime
    data_end: datetime
    tier: str
    insufficient_data: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Catalyst tier classification (Req 11.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def classify_catalyst_tier(catalyst_type: str) -> str:
    """Classify a catalyst type into one of the two pattern tiers.

    Args:
        catalyst_type: Catalyst category to classify.

    Returns:
        ``"major_corporate_decision"`` when *catalyst_type* is a member of
        ``MAJOR_DECISION_CATALYSTS``; ``"routine_signal"`` for anything else.
    """
    is_major = catalyst_type in MAJOR_DECISION_CATALYSTS
    return "major_corporate_decision" if is_major else "routine_signal"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pattern confidence (Req 3.3, 11.2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compute_pattern_confidence(
    sample_count: int,
    outcome_consistency: float,
    data_recency_days: float,
    tier: str,
    config: Optional[CompetitiveConfig] = None,
) -> float:
    """Compute a pattern confidence score that is always within [0, 1].

    Base formula::

        sample_factor * 0.4 + consistency * 0.4 + recency_factor * 0.2

    followed, in this order, by the major-decision multiplier (Req 11.2),
    a clamp to [0, 1], the insufficient-data cap of 0.25 (Req 3.4) and the
    staleness penalty (Req 9.2).

    Args:
        sample_count: Number of samples behind the pattern; 20 or more
            saturates the sample factor.
        outcome_consistency: Dominant outcome share; callers in this module
            pass ``max(bullish_pct, bearish_pct)``.
        data_recency_days: Age in days of the most recent sample.
        tier: ``"major_corporate_decision"`` or ``"routine_signal"``.
        config: Thresholds and multipliers; a default ``CompetitiveConfig``
            is built when omitted.

    Returns:
        The confidence score, clamped to [0, 1].
    """
    cfg = config or CompetitiveConfig()

    # --- component factors ---
    sample_factor = min(sample_count / 20.0, 1.0)

    if data_recency_days <= cfg.staleness_recent_days:
        recency_factor = 1.0
    elif data_recency_days <= cfg.staleness_window_days:
        recency_factor = 0.7
    else:
        recency_factor = 0.4

    confidence = (
        sample_factor * 0.4
        + outcome_consistency * 0.4
        + recency_factor * 0.2
    )

    # Major-decision multiplier (Req 11.2)
    if tier == "major_corporate_decision":
        confidence *= cfg.major_decision_weight_multiplier

    # Clamp before the cap and penalty so both act on an in-range value.
    confidence = min(max(confidence, 0.0), 1.0)

    # Insufficient data cap (Req 3.4)
    if sample_count < cfg.min_pattern_samples:
        confidence = min(confidence, 0.25)

    # Staleness decay (Req 9.2)
    if data_recency_days > cfg.staleness_window_days:
        confidence *= cfg.staleness_decay_penalty

    # The penalty is configurable; clamp again so a penalty outside [0, 1]
    # cannot push the result out of the documented range.
    return min(max(confidence, 0.0), 1.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lookback helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _lookback_days(tier: str, config: Optional[CompetitiveConfig] = None) -> int:
    """Pick the history lookback window, in days, that applies to *tier*.

    The ``major_corporate_decision`` tier uses
    ``major_decision_lookback_days``; every other tier uses
    ``routine_lookback_days``. A default ``CompetitiveConfig`` is built when
    *config* is omitted.
    """
    settings = config or CompetitiveConfig()
    return (
        settings.major_decision_lookback_days
        if tier == "major_corporate_decision"
        else settings.routine_lookback_days
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL: self-company pattern query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Parameters: $1 ticker, $2 catalyst type, $3/$4 publication-time bounds,
# $5 trend window label, $6 length of the post-publication search window.
# NOTE: $6 is cast to ``interval``; asyncpg encodes interval parameters from
# ``datetime.timedelta`` values, not from text.
#
# Rows are ordered newest document first and, within one document, by trend
# window generation time, so the first row seen for a ``dir_id`` is the window
# generated soonest after publication.
_SELF_PATTERN_QUERY = """
WITH matched_docs AS (
    SELECT
        dir.id AS dir_id,
        d.published_at,
        dir.sentiment
    FROM document_impact_records dir
    JOIN document_intelligence di ON di.id = dir.intelligence_id
    JOIN documents d ON d.id = di.document_id
    WHERE dir.ticker = $1
      AND dir.catalyst_type = $2
      AND di.validation_status = 'valid'
      AND d.status != 'rejected'
      AND d.published_at >= $3
      AND d.published_at <= $4
)
SELECT
    md.dir_id,
    md.published_at,
    md.sentiment,
    tw.trend_direction,
    tw.trend_strength,
    tw.generated_at,
    tw."window" AS tw_window
FROM matched_docs md
JOIN trend_windows tw
    ON tw.entity_type = 'company'
   AND tw.entity_id = $1
   AND tw."window" = $5
   AND tw.generated_at >= md.published_at
   AND tw.generated_at <= md.published_at + $6::interval
ORDER BY md.published_at DESC, tw.generated_at ASC
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL: cross-company pattern query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Parameters: $1 source ticker, $2 catalyst type, $3/$4 publication-time
# bounds, $5 target ticker, $6 trend window label, $7 length of the
# post-publication search window.
# NOTE: $7 is cast to ``interval``; asyncpg encodes interval parameters from
# ``datetime.timedelta`` values, not from text.
#
# Rows are ordered newest document first and, within one document, by trend
# window generation time, so the first row seen for a ``dir_id`` is the window
# generated soonest after publication.
_CROSS_PATTERN_QUERY = """
WITH matched_docs AS (
    SELECT
        dir.id AS dir_id,
        d.published_at,
        dir.sentiment
    FROM document_impact_records dir
    JOIN document_intelligence di ON di.id = dir.intelligence_id
    JOIN documents d ON d.id = di.document_id
    WHERE dir.ticker = $1
      AND dir.catalyst_type = $2
      AND di.validation_status = 'valid'
      AND d.status != 'rejected'
      AND d.published_at >= $3
      AND d.published_at <= $4
)
SELECT
    md.dir_id,
    md.published_at,
    md.sentiment,
    tw.trend_direction,
    tw.trend_strength,
    tw.generated_at,
    tw."window" AS tw_window
FROM matched_docs md
JOIN trend_windows tw
    ON tw.entity_type = 'company'
   AND tw.entity_id = $5
   AND tw."window" = $6
   AND tw.generated_at >= md.published_at
   AND tw.generated_at <= md.published_at + $7::interval
ORDER BY md.published_at DESC, tw.generated_at ASC
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Horizon → interval mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Supported horizon labels and the interval text describing how long after
# publication the pattern queries look for trend windows.
_HORIZON_INTERVALS: dict[str, str] = {"1d": "1 day", "7d": "7 days", "30d": "30 days"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build HistoricalPattern from query rows
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_pattern(
    rows: list[asyncpg.Record],
    source_ticker: str,
    target_ticker: str,
    catalyst_type: str,
    horizon: str,
    tier: str,
    config: Optional[CompetitiveConfig] = None,
) -> Optional[HistoricalPattern]:
    """Aggregate query rows into a single HistoricalPattern.

    Each document impact record (``dir_id``) contributes one sample: the
    joined trend window with the earliest ``generated_at``, i.e. the one
    generated soonest after publication. The choice is made here rather than
    inferred from row order, so it does not depend on how the query sorts
    windows that belong to the same document.

    Args:
        rows: Records exposing ``dir_id``, ``published_at``,
            ``trend_direction``, ``trend_strength`` and ``generated_at``.
        source_ticker: Ticker the matched documents are about.
        target_ticker: Ticker whose trend windows were joined.
        catalyst_type: Catalyst category of the matched documents.
        horizon: Horizon label stored on the pattern.
        tier: Catalyst tier, forwarded to the confidence computation.
        config: Thresholds; a default ``CompetitiveConfig`` when omitted.

    Returns:
        The aggregated pattern, or ``None`` when *rows* is empty or no row
        carries a publication timestamp.
    """
    if not rows:
        return None

    cfg = config or CompetitiveConfig()

    # One sample per dir_id, keeping the earliest-generated trend window.
    # dict preserves first-seen order, so samples stay in row order.
    closest: dict[str, asyncpg.Record] = {}
    for row in rows:
        key = str(row["dir_id"])
        kept = closest.get(key)
        if kept is None:
            closest[key] = row
            continue
        candidate_gen = row["generated_at"]
        kept_gen = kept["generated_at"]
        if candidate_gen is not None and (kept_gen is None or candidate_gen < kept_gen):
            closest[key] = row
    unique_rows = list(closest.values())
    sample_count = len(unique_rows)

    # Date range of the underlying documents. Without any publication
    # timestamp neither the range nor the recency can be derived.
    published_dates = [r["published_at"] for r in unique_rows if r["published_at"] is not None]
    if not published_dates:
        logger.warning(
            "No publication timestamps for %s→%s/%s/%s; skipping pattern",
            source_ticker, target_ticker, catalyst_type, horizon,
        )
        return None
    data_start = min(published_dates)
    data_end = max(published_dates)

    bullish = sum(1 for r in unique_rows if r["trend_direction"] == "bullish")
    bearish = sum(1 for r in unique_rows if r["trend_direction"] == "bearish")
    bullish_pct = bullish / sample_count
    bearish_pct = bearish / sample_count

    strengths = [float(r["trend_strength"]) for r in unique_rows if r["trend_strength"] is not None]
    avg_strength = sum(strengths) / len(strengths) if strengths else 0.0

    # avg_time_to_resolution: average days between published_at and generated_at
    resolutions = [
        max((r["generated_at"] - r["published_at"]).total_seconds() / 86400.0, 0.0)
        for r in unique_rows
        if r["published_at"] and r["generated_at"]
    ]
    avg_time_to_resolution = sum(resolutions) / len(resolutions) if resolutions else 0.0

    # Recency: days since the most recent data point
    now = datetime.now(timezone.utc)
    data_recency_days = (now - data_end).total_seconds() / 86400.0

    outcome_consistency = max(bullish_pct, bearish_pct)
    confidence = compute_pattern_confidence(
        sample_count, outcome_consistency, data_recency_days, tier, cfg,
    )

    return HistoricalPattern(
        source_ticker=source_ticker,
        target_ticker=target_ticker,
        catalyst_type=catalyst_type,
        time_horizon=horizon,
        sample_count=sample_count,
        bullish_pct=bullish_pct,
        bearish_pct=bearish_pct,
        avg_strength=min(max(avg_strength, 0.0), 1.0),
        avg_time_to_resolution=avg_time_to_resolution,
        pattern_confidence=confidence,
        data_start=data_start,
        data_end=data_end,
        tier=tier,
        insufficient_data=sample_count < cfg.min_pattern_samples,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _horizon_timedelta(horizon: str) -> Optional[timedelta]:
    """Return the search window for *horizon* as a timedelta, or ``None``.

    The pattern queries cast their window parameter to ``interval``, and
    asyncpg encodes interval parameters from ``datetime.timedelta`` values.
    The ``"<N> day(s)"`` text registered in ``_HORIZON_INTERVALS`` is
    therefore converted before it is bound. ``None`` is returned for a
    horizon that is not registered or whose text is not a day count.
    """
    spec = _HORIZON_INTERVALS.get(horizon)
    if spec is None:
        return None
    amount, _, unit = spec.partition(" ")
    if not unit.startswith("day"):
        return None
    try:
        return timedelta(days=float(amount))
    except ValueError:
        return None


async def find_self_patterns(
    pool: asyncpg.Pool,
    ticker: str,
    catalyst_type: str,
    horizons: Optional[list[str]] = None,
    config: Optional[CompetitiveConfig] = None,
) -> list[HistoricalPattern]:
    """Find historical patterns for the same company-catalyst pair.

    Queries document_impact_records joined with trend_windows for the
    given ticker and catalyst_type across configurable time horizons.

    Args:
        pool: Connection pool; one connection is used for all horizons.
        ticker: Company whose documents and trend windows are examined.
        catalyst_type: Catalyst category to match.
        horizons: Horizon labels to evaluate; ``DEFAULT_HORIZONS`` when
            omitted or empty.
        config: Lookback and confidence settings; a default
            ``CompetitiveConfig`` when omitted.

    Returns:
        One pattern per horizon that produced rows. Unknown horizons are
        skipped with a warning; a failing query is logged and skipped.

    Requirements: 3.1, 3.2, 3.5, 11.3
    """
    cfg = config or CompetitiveConfig()
    horizons = horizons or DEFAULT_HORIZONS
    tier = classify_catalyst_tier(catalyst_type)
    lookback = _lookback_days(tier, cfg)

    now = datetime.now(timezone.utc)
    cutoff = now - timedelta(days=lookback)

    patterns: list[HistoricalPattern] = []
    async with pool.acquire() as conn:
        for horizon in horizons:
            window = _horizon_timedelta(horizon)
            if window is None:
                logger.warning("Unknown horizon %s, skipping", horizon)
                continue
            try:
                rows = await conn.fetch(
                    _SELF_PATTERN_QUERY,
                    ticker,         # $1
                    catalyst_type,  # $2
                    cutoff,         # $3
                    now,            # $4
                    horizon,        # $5
                    window,         # $6 (interval -> timedelta)
                )
            except Exception:
                # Best effort: one failing horizon must not hide the others.
                logger.exception(
                    "Error querying self-patterns for %s/%s/%s",
                    ticker, catalyst_type, horizon,
                )
                continue

            pattern = _build_pattern(
                rows, ticker, ticker, catalyst_type, horizon, tier, cfg,
            )
            if pattern is not None:
                patterns.append(pattern)

    return patterns


async def find_cross_company_patterns(
    pool: asyncpg.Pool,
    source_ticker: str,
    target_ticker: str,
    catalyst_type: str,
    horizons: Optional[list[str]] = None,
    config: Optional[CompetitiveConfig] = None,
) -> list[HistoricalPattern]:
    """Find cross-company historical patterns.

    Queries documents about *source_ticker* with the given catalyst_type,
    then looks at trend_windows for *target_ticker* within each horizon
    after the document was published.

    Args:
        pool: Connection pool; one connection is used for all horizons.
        source_ticker: Company the matched documents are about.
        target_ticker: Company whose trend windows are examined.
        catalyst_type: Catalyst category to match.
        horizons: Horizon labels to evaluate; ``DEFAULT_HORIZONS`` when
            omitted or empty.
        config: Lookback and confidence settings; a default
            ``CompetitiveConfig`` when omitted.

    Returns:
        One pattern per horizon that produced rows. Unknown horizons are
        skipped with a warning; a failing query is logged and skipped.

    Requirements: 4.2, 11.5
    """
    cfg = config or CompetitiveConfig()
    horizons = horizons or DEFAULT_HORIZONS
    tier = classify_catalyst_tier(catalyst_type)
    lookback = _lookback_days(tier, cfg)

    now = datetime.now(timezone.utc)
    cutoff = now - timedelta(days=lookback)

    patterns: list[HistoricalPattern] = []
    async with pool.acquire() as conn:
        for horizon in horizons:
            window = _horizon_timedelta(horizon)
            if window is None:
                logger.warning("Unknown horizon %s, skipping", horizon)
                continue
            try:
                rows = await conn.fetch(
                    _CROSS_PATTERN_QUERY,
                    source_ticker,  # $1
                    catalyst_type,  # $2
                    cutoff,         # $3
                    now,            # $4
                    target_ticker,  # $5
                    horizon,        # $6
                    window,         # $7 (interval -> timedelta)
                )
            except Exception:
                # Best effort: one failing horizon must not hide the others.
                logger.exception(
                    "Error querying cross-patterns for %s→%s/%s/%s",
                    source_ticker, target_ticker, catalyst_type, horizon,
                )
                continue

            pattern = _build_pattern(
                rows, source_ticker, target_ticker, catalyst_type,
                horizon, tier, cfg,
            )
            if pattern is not None:
                patterns.append(pattern)

    return patterns
|
||||
@@ -0,0 +1,416 @@
|
||||
"""Trend projection module — forward-looking trend estimates.
|
||||
|
||||
Computes TrendProjection objects by combining current trend momentum,
|
||||
macro signal decay trajectories, and upcoming catalyst outlook.
|
||||
Projections are persisted alongside trend_window records.
|
||||
|
||||
Requirements: 12.1, 12.2, 12.3, 12.4, 12.5, 12.9
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.shared.schemas import TrendDirection, TrendSummary
|
||||
|
||||
logger = logging.getLogger("projection")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TrendProjection dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Labels accepted for a projected direction and for a projection horizon.
VALID_DIRECTIONS = {"bullish", "bearish", "mixed", "neutral"}
VALID_HORIZONS = {"1d", "7d", "30d"}

# Projections whose confidence falls below this value are flagged low-confidence.
DEFAULT_CONFIDENCE_THRESHOLD = 0.3

# Half-life, in days, of a macro signal's impact, keyed by estimated_duration:
# short-term impact halves every day, medium-term every week, long-term every
# month.
DECAY_HALF_LIFE_DAYS: dict[str, float] = {
    "short_term": 1.0,
    "medium_term": 7.0,
    "long_term": 30.0,
}
|
||||
|
||||
|
||||
@dataclass
class TrendProjection:
    """Forward-looking estimate of where a company's trend is heading.

    Attributes:
        projected_direction: ``bullish``, ``bearish``, ``mixed`` or ``neutral``.
        projected_strength: Projected trend strength, in [0, 1].
        projected_confidence: Confidence in the projection, in [0, 1].
        projection_horizon: Horizon label, ``1d``, ``7d`` or ``30d``.
        driving_factors: Human-readable reasons behind the projection.
        macro_contribution_pct: Weight given to macro signals, in [0, 1].
        diverges_from_current: Whether the projected direction differs from
            the current trend direction.
        computed_at: Timezone-aware UTC timestamp; defaults to creation time.
        low_confidence: Whether the projection is flagged as low confidence.
    """

    projected_direction: str = "neutral"
    projected_strength: float = 0.5
    projected_confidence: float = 0.5
    projection_horizon: str = "7d"
    driving_factors: list[str] = field(default_factory=list)
    macro_contribution_pct: float = 0.0
    diverges_from_current: bool = False
    computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    low_confidence: bool = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Macro impact row type (lightweight, avoids circular import with worker)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class MacroEventInfo:
    """Minimal per-event macro data consumed by the projection computation.

    Attributes:
        event_id: Identifier of the macro event.
        macro_impact_score: Impact score of the event.
        impact_direction: Direction label of the impact; defaults to
            ``neutral``.
        confidence: Confidence attached to the event.
        estimated_duration: Duration label used to pick a decay half-life.
        severity: Severity label used to pick a severity weight.
        event_age_hours: Hours elapsed since the event was published.
    """

    event_id: str = ""
    macro_impact_score: float = 0.0
    impact_direction: str = "neutral"
    confidence: float = 0.5
    estimated_duration: str = "short_term"
    severity: str = "low"
    event_age_hours: float = 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Projection horizon mapping from trend window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Trend window label -> projection horizon. Windows shorter than a day share
# the 1d horizon and the 90d window is capped at the 30d horizon.
_WINDOW_TO_HORIZON: dict[str, str] = {
    "intraday": "1d", "1d": "1d",
    "7d": "7d",
    "30d": "30d", "90d": "30d",
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Momentum computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compute_trend_momentum(
    current_strength: float,
    current_direction: str,
    previous_strength: float | None = None,
    previous_direction: str | None = None,
) -> float:
    """Estimate trend momentum as the change in signed trend strength.

    Signed strength is the strength multiplied by +1 for ``bullish``, -1 for
    ``bearish`` and 0 for any other direction.

    Returns:
        A value rounded to 6 decimals:

        - With a previous observation: current signed strength minus the
          previous signed strength, clamped to [-1, 1]. Positive means
          strengthening bullish or weakening bearish; negative the opposite.
        - Without one (either previous value is ``None``): half of the
          current signed strength, used as a heuristic.
    """

    def sign_of(direction: str) -> float:
        # Same mapping as the module's direction-to-sign helper.
        if direction == "bullish":
            return 1.0
        if direction == "bearish":
            return -1.0
        return 0.0

    signed_now = sign_of(current_direction) * current_strength

    if previous_strength is None or previous_direction is None:
        return round(signed_now * 0.5, 6)

    signed_before = sign_of(previous_direction) * previous_strength
    change = signed_now - signed_before
    return round(max(-1.0, min(1.0, change)), 6)
|
||||
|
||||
|
||||
def _direction_sign(direction: str) -> float:
    """Return +1.0 for ``bullish``, -1.0 for ``bearish`` and 0.0 otherwise."""
    if direction == "bearish":
        return -1.0
    return 1.0 if direction == "bullish" else 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Macro signal decay projection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Multiplier applied to a macro event's projected impact, by severity label.
_SEVERITY_WEIGHT: dict[str, float] = {"critical": 1.0, "high": 0.75, "moderate": 0.5, "low": 0.25}
|
||||
|
||||
|
||||
def project_macro_decay(
    events: list[MacroEventInfo],
    horizon_days: float,
) -> tuple[float, str]:
    """Project the aggregate macro signal after decay over the horizon.

    For each event the remaining impact at the end of the horizon is::

        macro_impact_score * 2 ** (-age_at_horizon / half_life) * severity_weight

    where ``age_at_horizon`` is the event's current age plus *horizon_days*
    and the half-life comes from ``DECAY_HALF_LIFE_DAYS`` (7 days for an
    unknown ``estimated_duration``). Unknown severities weigh 0.25.

    Impacts are summed per direction; an event whose direction is neither
    ``positive`` nor ``negative`` is split evenly between the two sides.

    Args:
        events: Active macro events.
        horizon_days: Length of the projection horizon, in days.

    Returns:
        ``(projected_macro_strength, projected_macro_direction)``. Strength
        is the total remaining impact capped at 1.0 and rounded to 6
        decimals. Direction is ``bullish`` or ``bearish`` when one side
        outweighs the other by more than 20%, ``mixed`` when both sides are
        present without such a margin, and ``neutral`` otherwise, including
        when there are no events or no remaining impact.
    """
    if not events:
        return 0.0, "neutral"

    positive_weight = 0.0
    negative_weight = 0.0

    for ev in events:
        half_life = DECAY_HALF_LIFE_DAYS.get(ev.estimated_duration, 7.0)
        # Age of the event at the end of the horizon, in days.
        future_age_days = ev.event_age_hours / 24.0 + horizon_days

        # Fraction of the original impact still present at that age; a
        # non-positive half-life means the impact is treated as fully decayed.
        if half_life > 0:
            future_factor = math.pow(2.0, -future_age_days / half_life)
        else:
            future_factor = 0.0

        severity_w = _SEVERITY_WEIGHT.get(ev.severity, 0.25)
        projected_impact = ev.macro_impact_score * future_factor * severity_w

        if ev.impact_direction == "positive":
            positive_weight += projected_impact
        elif ev.impact_direction == "negative":
            negative_weight += projected_impact
        else:
            # mixed/neutral: split evenly
            positive_weight += projected_impact * 0.5
            negative_weight += projected_impact * 0.5

    total = positive_weight + negative_weight
    if total == 0.0:
        return 0.0, "neutral"

    strength = min(total, 1.0)

    if positive_weight > negative_weight * 1.2:
        direction = "bullish"
    elif negative_weight > positive_weight * 1.2:
        direction = "bearish"
    elif positive_weight > 0 and negative_weight > 0:
        direction = "mixed"
    else:
        direction = "neutral"

    return round(strength, 6), direction
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Horizon days mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Length of each projection horizon, in days.
_HORIZON_DAYS: dict[str, float] = {"1d": 1.0, "7d": 7.0, "30d": 30.0}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core projection computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_projection(
    summary: TrendSummary,
    macro_events: list[MacroEventInfo] | None = None,
    macro_enabled: bool = True,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    previous_strength: float | None = None,
    previous_direction: str | None = None,
    upcoming_catalysts: list[str] | None = None,
) -> TrendProjection:
    """Build a forward-looking projection from the current trend summary.

    The projection blends four inputs: trend momentum, the macro signal
    remaining after decay over the horizon, upcoming catalysts, and the
    current trend as baseline.

    Args:
        summary: Current trend summary; its direction, strength, confidence
            and window are read.
        macro_events: Active macro events; ignored when the macro layer is
            disabled.
        macro_enabled: Whether the macro layer participates.
        confidence_threshold: Projections below this confidence are flagged
            ``low_confidence``.
        previous_strength: Strength of the previous window, if known.
        previous_direction: Direction of the previous window, if known.
        upcoming_catalysts: Known upcoming catalysts. The first three are
            listed as driving factors; each one adds 0.02 to the projected
            strength, up to 0.1 in total.

    Returns:
        A TrendProjection with direction, strength, confidence, horizon,
        driving factors (never empty), macro weight and divergence flag.
    """
    computed_at = datetime.now(timezone.utc)
    direction_now = summary.trend_direction.value
    strength_now = summary.trend_strength
    confidence_now = summary.confidence

    horizon = _WINDOW_TO_HORIZON.get(summary.window.value, "7d")
    horizon_days = _HORIZON_DAYS.get(horizon, 7.0)

    factors: list[str] = []

    # 1. Trend momentum; only a noticeable change is reported as a factor.
    momentum = compute_trend_momentum(
        strength_now, direction_now, previous_strength, previous_direction,
    )
    if abs(momentum) > 0.05:
        momentum_label = "Positive" if momentum > 0 else "Negative"
        factors.append(
            f"{momentum_label} momentum ({momentum:+.3f}) in recent trend strength"
        )

    # 2. Macro signal remaining at the end of the horizon.
    macro_strength, macro_direction = 0.0, "neutral"
    if macro_enabled and macro_events:
        macro_strength, macro_direction = project_macro_decay(macro_events, horizon_days)
        if macro_strength > 0:
            factors.append(
                f"Macro signals project {macro_direction} impact "
                f"(strength {macro_strength:.3f}) over {horizon}"
            )

    # 3. Upcoming catalysts: list the first three, boost by the full count.
    catalysts = upcoming_catalysts or []
    factors.extend(f"Upcoming catalyst: {name}" for name in catalysts[:3])
    catalyst_boost = min(len(catalysts) * 0.02, 0.1)

    # 4. Direction, strength and confidence.
    company_signed = _direction_sign(direction_now) * strength_now + momentum * 0.5
    base_confidence = confidence_now * 0.8  # a projection is less certain than the trend
    macro_contribution = 0.0

    if macro_enabled and macro_strength > 0:
        # Blend the company trajectory with the macro trajectory; macro
        # signals also add a small confidence boost.
        macro_weight = min(macro_strength * 0.4, 0.4)
        company_weight = 1.0 - macro_weight
        macro_signed = _direction_sign(macro_direction) * macro_strength
        blended_signed = company_weight * company_signed + macro_weight * macro_signed

        projected_strength = round(min(abs(blended_signed) + catalyst_boost, 1.0), 6)
        projected_direction = _signed_to_direction(blended_signed)
        macro_contribution = round(macro_weight, 6)
        projected_confidence = round(
            min(base_confidence + min(macro_strength * 0.15, 0.1), 1.0), 6,
        )
    else:
        # Company-only projection; a disabled macro layer costs confidence.
        company_strength = min(abs(company_signed), 1.0)
        projected_strength = round(min(company_strength + catalyst_boost, 1.0), 6)
        projected_direction = _signed_to_direction(company_signed)
        if macro_enabled:
            projected_confidence = round(base_confidence, 6)
        else:
            projected_confidence = round(base_confidence * 0.85, 6)

    # driving_factors must never be empty.
    if not factors:
        factors.append(
            f"Baseline trend continuation: {direction_now} at strength {strength_now:.3f}"
        )

    # 5. Divergence between the projected and the current direction.
    diverges = projected_direction != direction_now
    if diverges:
        factors.append(
            f"DIVERGENCE: Current trend is {direction_now}, "
            f"projection is {projected_direction}"
        )

    return TrendProjection(
        projected_direction=projected_direction,
        projected_strength=projected_strength,
        projected_confidence=projected_confidence,
        projection_horizon=horizon,
        driving_factors=factors,
        macro_contribution_pct=macro_contribution,
        diverges_from_current=diverges,
        computed_at=computed_at,
        low_confidence=projected_confidence < confidence_threshold,
    )
|
||||
|
||||
|
||||
def _signed_to_direction(signed_value: float) -> str:
    """Map a signed strength to a direction label.

    Values above 0.1 are ``bullish`` and below -0.1 ``bearish``. Inside that
    band, magnitudes above 0.02 are ``mixed`` and the rest ``neutral``.
    """
    if signed_value > 0.1:
        return "bullish"
    if signed_value < -0.1:
        return "bearish"
    return "mixed" if abs(signed_value) > 0.02 else "neutral"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PostgreSQL persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INSERT_PROJECTION = """
|
||||
INSERT INTO trend_projections (
|
||||
trend_window_id, projected_direction, projected_strength,
|
||||
projected_confidence, projection_horizon, driving_factors,
|
||||
macro_contribution_pct, diverges_from_current, computed_at
|
||||
) VALUES (
|
||||
$1::uuid, $2, $3, $4, $5, $6::jsonb, $7, $8, $9
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
|
||||
async def persist_trend_projection(
    pool: asyncpg.Pool,
    trend_window_id: str,
    projection: TrendProjection,
) -> str:
    """Insert *projection* into the trend_projections table.

    Args:
        pool: Connection pool used to run the insert.
        trend_window_id: Identifier of the trend window the projection
            belongs to; bound as a UUID.
        projection: Projection to store. ``driving_factors`` is serialised
            to JSON; ``low_confidence`` is not part of the insert.

    Returns:
        The id of the inserted row, as a string.
    """
    insert_args = (
        trend_window_id,
        projection.projected_direction,
        projection.projected_strength,
        projection.projected_confidence,
        projection.projection_horizon,
        json.dumps(projection.driving_factors),
        projection.macro_contribution_pct,
        projection.diverges_from_current,
        projection.computed_at,
    )
    inserted_id = await pool.fetchval(_INSERT_PROJECTION, *insert_args)

    logger.info(
        "Persisted trend projection for window=%s: direction=%s strength=%.3f confidence=%.3f diverges=%s",
        trend_window_id,
        projection.projected_direction,
        projection.projected_strength,
        projection.projected_confidence,
        projection.diverges_from_current,
    )
    return str(inserted_id)
|
||||
+226
-13
@@ -4,13 +4,13 @@ Aggregates company-level trend summaries into sector and market-level
|
||||
summaries, enabling top-down views of sentiment and risk across the
|
||||
portfolio.
|
||||
|
||||
Requirements: 6.3, 6.4, 6.5
|
||||
Requirements: 6.1, 6.2, 6.3, 6.4, 6.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import asyncpg
|
||||
@@ -42,6 +42,126 @@ class CompanyTrendRow:
|
||||
top_opposing_evidence: list[str]
|
||||
|
||||
|
||||
@dataclass
class SectorMacroImpact:
    """Macro impact aggregated over the companies of one sector.

    Feeds macro signals into the sector and market rollups.
    Requirements: 6.1, 6.2, 6.3

    Attributes:
        sector: Sector name.
        total_impact: Sum of macro_impact_score across the sector's companies.
        avg_impact: Average macro_impact_score.
        company_count: Number of companies affected.
        net_direction: Weighted direction: +1 positive, -1 negative, 0 mixed.
        event_ids: Identifiers of the contributing events.
    """

    sector: str
    total_impact: float
    avg_impact: float
    company_count: int
    net_direction: float
    event_ids: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
# Threshold for disproportionate sector impact (Requirement 6.3)
|
||||
SECTOR_CONCENTRATION_THRESHOLD = 0.60
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch sector-level macro impact aggregates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SECTOR_MACRO_IMPACT_QUERY = """
|
||||
SELECT
|
||||
c.sector,
|
||||
mir.event_id,
|
||||
mir.macro_impact_score,
|
||||
mir.impact_direction
|
||||
FROM macro_impact_records mir
|
||||
JOIN companies c ON c.id = mir.company_id AND c.active = TRUE
|
||||
WHERE mir.computed_at >= $1
|
||||
AND mir.computed_at <= $2
|
||||
ORDER BY c.sector, mir.macro_impact_score DESC
|
||||
"""
|
||||
|
||||
|
||||
async def fetch_sector_macro_impacts(
|
||||
pool: asyncpg.Pool,
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
) -> dict[str, SectorMacroImpact]:
|
||||
"""Fetch macro impact records aggregated by sector for a time range.
|
||||
|
||||
Returns a mapping of sector name to SectorMacroImpact.
|
||||
"""
|
||||
rows = await pool.fetch(_SECTOR_MACRO_IMPACT_QUERY, window_start, window_end)
|
||||
|
||||
# Accumulate per-sector
|
||||
sector_data: dict[str, dict] = {}
|
||||
direction_map = {"positive": 1.0, "negative": -1.0, "mixed": 0.0, "neutral": 0.0}
|
||||
|
||||
for row in rows:
|
||||
sector = str(row["sector"]) if row["sector"] else "Unknown"
|
||||
score = float(row["macro_impact_score"] or 0.0)
|
||||
direction = row["impact_direction"] or "neutral"
|
||||
event_id = str(row["event_id"])
|
||||
|
||||
if sector not in sector_data:
|
||||
sector_data[sector] = {
|
||||
"total": 0.0,
|
||||
"count": 0,
|
||||
"dir_sum": 0.0,
|
||||
"dir_count": 0,
|
||||
"event_ids": set(),
|
||||
}
|
||||
|
||||
d = sector_data[sector]
|
||||
d["total"] += score
|
||||
d["count"] += 1
|
||||
dir_val = direction_map.get(direction, 0.0)
|
||||
if dir_val != 0.0:
|
||||
d["dir_sum"] += dir_val
|
||||
d["dir_count"] += 1
|
||||
d["event_ids"].add(event_id)
|
||||
|
||||
result: dict[str, SectorMacroImpact] = {}
|
||||
for sector, d in sector_data.items():
|
||||
count = d["count"]
|
||||
avg = d["total"] / count if count > 0 else 0.0
|
||||
net_dir = d["dir_sum"] / d["dir_count"] if d["dir_count"] > 0 else 0.0
|
||||
result[sector] = SectorMacroImpact(
|
||||
sector=sector,
|
||||
total_impact=d["total"],
|
||||
avg_impact=avg,
|
||||
company_count=count,
|
||||
net_direction=net_dir,
|
||||
event_ids=sorted(d["event_ids"]),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sector macro concentration helper (Requirement 6.3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_sector_macro_concentration(
    sector_impacts: dict[str, SectorMacroImpact],
) -> list[tuple[str, float]]:
    """Compute the fraction of total macro impact concentrated in each sector.

    Args:
        sector_impacts: Mapping of sector name to its aggregated impact;
            only ``total_impact`` is read.

    Returns:
        A list of (sector, fraction) tuples sorted by fraction descending,
        or an empty list when the summed impact is not positive (including
        an empty mapping).  Sectors with fraction >
        SECTOR_CONCENTRATION_THRESHOLD are considered disproportionately
        affected.
    """
    total = sum(si.total_impact for si in sector_impacts.values())
    # Guard the division below; there is nothing to apportion.
    if total <= 0.0:
        return []

    fractions = [
        (sector, si.total_impact / total)
        for sector, si in sector_impacts.items()
    ]
    fractions.sort(key=lambda x: x[1], reverse=True)
    return fractions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch latest company trends for a given window
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -141,11 +261,22 @@ def rollup_trends(
|
||||
entity_id: str,
|
||||
window: str,
|
||||
reference_time: datetime,
|
||||
macro_impacts: dict[str, SectorMacroImpact] | None = None,
|
||||
) -> TrendSummary:
|
||||
"""Aggregate a list of company-level trends into a single rollup summary.
|
||||
|
||||
Each company trend is weighted by its confidence to produce a
|
||||
confidence-weighted average of direction, strength, and contradiction.
|
||||
|
||||
When macro_impacts is provided:
|
||||
- For sector rollups: incorporates the sector's macro signal into
|
||||
strength and confidence, weighted by constituent company exposure.
|
||||
- For market rollups: aggregates macro signals across all sectors and
|
||||
surfaces disproportionately affected sectors (>60% concentration)
|
||||
in material_risks or dominant_catalysts.
|
||||
|
||||
When macro_impacts is None or empty, produces identical output to
|
||||
the original company-only rollup.
|
||||
"""
|
||||
if not trends:
|
||||
return TrendSummary(
|
||||
@@ -204,16 +335,70 @@ def rollup_trends(
|
||||
avg_contradiction = weighted_contradiction / total_weight
|
||||
avg_confidence = total_weight / len(trends)
|
||||
|
||||
# --- Incorporate macro impact signals when available ---
|
||||
macro_strength_adj = 0.0
|
||||
macro_confidence_adj = 0.0
|
||||
macro_catalysts: list[str] = []
|
||||
macro_risks: list[str] = []
|
||||
|
||||
if macro_impacts:
|
||||
if entity_type == "sector":
|
||||
# Sector rollup: incorporate this sector's macro signal
|
||||
sector_macro = macro_impacts.get(entity_id)
|
||||
if sector_macro and sector_macro.total_impact > 0:
|
||||
# Weight macro contribution by avg impact and company breadth
|
||||
breadth = min(sector_macro.company_count / max(len(trends), 1), 1.0)
|
||||
macro_strength_adj = sector_macro.avg_impact * breadth * 0.3
|
||||
macro_confidence_adj = sector_macro.avg_impact * breadth * 0.1
|
||||
# Nudge direction based on macro net direction
|
||||
avg_direction += sector_macro.net_direction * macro_strength_adj * 0.5
|
||||
|
||||
elif entity_type == "market":
|
||||
# Market rollup: aggregate macro signals across all sectors
|
||||
total_macro = sum(si.total_impact for si in macro_impacts.values())
|
||||
if total_macro > 0:
|
||||
total_companies = sum(si.company_count for si in macro_impacts.values())
|
||||
breadth = min(total_companies / max(len(trends), 1), 1.0)
|
||||
avg_macro = total_macro / max(len(macro_impacts), 1)
|
||||
macro_strength_adj = avg_macro * breadth * 0.3
|
||||
macro_confidence_adj = avg_macro * breadth * 0.1
|
||||
|
||||
# Aggregate net direction across sectors
|
||||
dir_sum = sum(
|
||||
si.net_direction * si.total_impact
|
||||
for si in macro_impacts.values()
|
||||
)
|
||||
net_dir = dir_sum / total_macro if total_macro > 0 else 0.0
|
||||
avg_direction += net_dir * macro_strength_adj * 0.5
|
||||
|
||||
# Surface disproportionately affected sectors (Requirement 6.3)
|
||||
concentration = compute_sector_macro_concentration(macro_impacts)
|
||||
for sector, fraction in concentration:
|
||||
if fraction > SECTOR_CONCENTRATION_THRESHOLD:
|
||||
si = macro_impacts[sector]
|
||||
label = f"Macro: {sector} ({fraction:.0%} of macro impact)"
|
||||
if si.net_direction < 0:
|
||||
macro_risks.append(label)
|
||||
else:
|
||||
macro_catalysts.append(label)
|
||||
|
||||
# Apply macro adjustments to strength and confidence
|
||||
adj_strength = avg_strength + macro_strength_adj
|
||||
adj_confidence = avg_confidence + macro_confidence_adj
|
||||
|
||||
# Derive direction
|
||||
direction = _derive_rollup_direction(avg_direction, avg_contradiction)
|
||||
|
||||
# Top catalysts
|
||||
# Top catalysts (macro catalysts prepended when present)
|
||||
sorted_catalysts = sorted(catalyst_weights.items(), key=lambda x: x[1], reverse=True)
|
||||
catalysts = [c for c, _ in sorted_catalysts[:5]]
|
||||
catalysts = macro_catalysts + [c for c, _ in sorted_catalysts[:5]]
|
||||
catalysts = catalysts[:5]
|
||||
|
||||
# Top risks (deduplicated, by weight)
|
||||
# Top risks (macro risks prepended when present, deduplicated)
|
||||
sorted_risks = sorted(risk_set.items(), key=lambda x: x[1], reverse=True)
|
||||
risks = [r for r, _ in sorted_risks[:5]]
|
||||
base_risks = [r for r, _ in sorted_risks[:5]]
|
||||
risks = macro_risks + base_risks
|
||||
risks = risks[:5]
|
||||
|
||||
# Disagreement details
|
||||
disagreement = _build_rollup_disagreement(trends, entity_id)
|
||||
@@ -223,8 +408,8 @@ def rollup_trends(
|
||||
entity_id=entity_id,
|
||||
window=TrendWindow(window),
|
||||
trend_direction=direction,
|
||||
trend_strength=round(min(abs(avg_strength), 1.0), 4),
|
||||
confidence=round(max(0.0, min(avg_confidence, 1.0)), 4),
|
||||
trend_strength=round(min(abs(adj_strength), 1.0), 4),
|
||||
confidence=round(max(0.0, min(adj_confidence, 1.0)), 4),
|
||||
top_supporting_evidence=list(dict.fromkeys(all_supporting))[:10],
|
||||
top_opposing_evidence=list(dict.fromkeys(all_opposing))[:10],
|
||||
dominant_catalysts=catalysts,
|
||||
@@ -341,11 +526,14 @@ async def aggregate_sector(
|
||||
window: str,
|
||||
reference_time: datetime | None = None,
|
||||
since: datetime | None = None,
|
||||
macro_impacts: dict[str, SectorMacroImpact] | None = None,
|
||||
) -> TrendSummary:
|
||||
"""Compute and persist a sector-level rollup for one window.
|
||||
|
||||
Fetches the latest company trends, filters to the given sector,
|
||||
and rolls them up into a single sector summary.
|
||||
and rolls them up into a single sector summary. When macro_impacts
|
||||
is provided, incorporates macro signals weighted by constituent
|
||||
company exposure.
|
||||
"""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
@@ -355,7 +543,14 @@ async def aggregate_sector(
|
||||
all_trends = await fetch_latest_company_trends(pool, window, since)
|
||||
sector_trends = [t for t in all_trends if t.sector == sector]
|
||||
|
||||
summary = rollup_trends(sector_trends, "sector", sector, window, reference_time)
|
||||
# Fetch macro impacts if not provided
|
||||
if macro_impacts is None:
|
||||
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
|
||||
|
||||
summary = rollup_trends(
|
||||
sector_trends, "sector", sector, window, reference_time,
|
||||
macro_impacts=macro_impacts,
|
||||
)
|
||||
|
||||
if sector_trends:
|
||||
rollup_id = await persist_rollup(pool, summary)
|
||||
@@ -373,10 +568,13 @@ async def aggregate_market(
|
||||
window: str,
|
||||
reference_time: datetime | None = None,
|
||||
since: datetime | None = None,
|
||||
macro_impacts: dict[str, SectorMacroImpact] | None = None,
|
||||
) -> TrendSummary:
|
||||
"""Compute and persist a market-wide rollup for one window.
|
||||
|
||||
Aggregates all company trends regardless of sector.
|
||||
Aggregates all company trends regardless of sector. When macro_impacts
|
||||
is provided, aggregates macro signals across all sectors and surfaces
|
||||
disproportionately affected sectors in material_risks or dominant_catalysts.
|
||||
"""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
@@ -385,7 +583,14 @@ async def aggregate_market(
|
||||
|
||||
all_trends = await fetch_latest_company_trends(pool, window, since)
|
||||
|
||||
summary = rollup_trends(all_trends, "market", "all", window, reference_time)
|
||||
# Fetch macro impacts if not provided
|
||||
if macro_impacts is None:
|
||||
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
|
||||
|
||||
summary = rollup_trends(
|
||||
all_trends, "market", "all", window, reference_time,
|
||||
macro_impacts=macro_impacts,
|
||||
)
|
||||
|
||||
if all_trends:
|
||||
rollup_id = await persist_rollup(pool, summary)
|
||||
@@ -403,6 +608,7 @@ async def aggregate_all_sectors(
|
||||
window: str,
|
||||
reference_time: datetime | None = None,
|
||||
since: datetime | None = None,
|
||||
macro_impacts: dict[str, SectorMacroImpact] | None = None,
|
||||
) -> list[TrendSummary]:
|
||||
"""Compute sector rollups for every sector that has company trends."""
|
||||
if reference_time is None:
|
||||
@@ -412,6 +618,10 @@ async def aggregate_all_sectors(
|
||||
|
||||
all_trends = await fetch_latest_company_trends(pool, window, since)
|
||||
|
||||
# Fetch macro impacts once for all sectors if not provided
|
||||
if macro_impacts is None:
|
||||
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
|
||||
|
||||
# Group by sector
|
||||
sectors: dict[str, list[CompanyTrendRow]] = {}
|
||||
for t in all_trends:
|
||||
@@ -419,7 +629,10 @@ async def aggregate_all_sectors(
|
||||
|
||||
summaries: list[TrendSummary] = []
|
||||
for sector, trends in sectors.items():
|
||||
summary = rollup_trends(trends, "sector", sector, window, reference_time)
|
||||
summary = rollup_trends(
|
||||
trends, "sector", sector, window, reference_time,
|
||||
macro_impacts=macro_impacts,
|
||||
)
|
||||
if trends:
|
||||
_id = await persist_rollup(pool, summary)
|
||||
summaries.append(summary)
|
||||
|
||||
@@ -0,0 +1,306 @@
|
||||
"""Competitive signal propagation engine.
|
||||
|
||||
Evaluates incoming document intelligence, identifies competitors via
|
||||
the competitor_relationships table, queries historical cross-company
|
||||
patterns, and produces weighted competitive signals persisted to
|
||||
competitive_signal_records.
|
||||
|
||||
Also converts pattern and competitive signals into WeightedSignal
|
||||
objects for the aggregation engine.
|
||||
|
||||
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 9.1
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.aggregation.pattern_matcher import (
|
||||
HistoricalPattern,
|
||||
find_cross_company_patterns,
|
||||
)
|
||||
from services.aggregation.scoring import (
|
||||
ScoringConfig,
|
||||
WeightedSignal,
|
||||
compute_signal_weight,
|
||||
)
|
||||
from services.shared.config import CompetitiveConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
class CompetitiveSignalRecord:
    """A competitive signal produced by propagating a source event to a
    competitor based on historical cross-company patterns."""

    source_document_id: str  # document that triggered the propagation
    source_ticker: str  # company that received the catalyst
    target_ticker: str  # competitor the signal is attributed to
    catalyst_type: str
    pattern_confidence: float  # confidence of the pattern behind the signal
    signal_direction: str  # bullish | bearish
    signal_strength: float  # [0, 1]
    relationship_strength: float  # strength of the competitor relationship
    computed_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Active relationships in which $1 appears on either side.
_COMPETITOR_LOOKUP_QUERY = """
SELECT company_a_id, company_b_id, strength
FROM competitor_relationships
WHERE (company_a_id = $1 OR company_b_id = $1)
  AND active = TRUE
"""

# One row per propagated signal; parameters follow the column order.
_INSERT_SIGNAL_QUERY = """
INSERT INTO competitive_signal_records
    (source_document_id, source_ticker, target_ticker, catalyst_type,
     pattern_confidence, signal_direction, signal_strength,
     relationship_strength, computed_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# propagate_signals
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def propagate_signals(
    pool: asyncpg.Pool,
    ticker: str,
    catalyst_type: str,
    impact_score: float,
    document_id: str,
    config: Optional[CompetitiveConfig] = None,
) -> list[CompetitiveSignalRecord]:
    """Look up competitors, query cross-company patterns, produce weighted
    competitive signals, and persist them.

    Failures are logged rather than raised: a failed competitor lookup
    returns an empty list, a failed pattern query skips that competitor,
    and a failed insert still returns the computed records.

    Args:
        pool: asyncpg connection pool.
        ticker: Source company ticker that received the catalyst.
        catalyst_type: The catalyst type from document intelligence.
        impact_score: The source document's impact score.
        document_id: The source document ID.
        config: Optional competitive config overrides.

    Returns:
        List of CompetitiveSignalRecord objects that were produced.  They
        have been persisted unless the insert failed (which is logged).
    """
    cfg = config or CompetitiveConfig()
    now = datetime.now(timezone.utc)
    records: list[CompetitiveSignalRecord] = []

    # Step 1: Look up active competitors
    try:
        async with pool.acquire() as conn:
            rows = await conn.fetch(_COMPETITOR_LOOKUP_QUERY, ticker)
    except Exception:
        logger.exception("Failed to look up competitors for %s", ticker)
        return records

    if not rows:
        logger.debug("No active competitors found for %s", ticker)
        return records

    # Step 2: For each competitor, query cross-company patterns
    for row in rows:
        company_a = str(row["company_a_id"])
        company_b = str(row["company_b_id"])
        rel_strength = float(row["strength"])

        # Determine the competitor ticker (the other side of the relationship)
        # NOTE(review): this compares the relationship's company ids against
        # `ticker` and uses the other id as a ticker — confirm these columns
        # actually hold tickers.
        competitor_ticker = company_b if company_a == ticker else company_a

        # Threshold gating (Req 4.5)
        if rel_strength < cfg.propagation_strength_threshold:
            logger.info(
                "Skipping propagation %s→%s: relationship strength %.3f "
                "below threshold %.3f",
                ticker, competitor_ticker, rel_strength,
                cfg.propagation_strength_threshold,
            )
            continue

        # Query cross-company patterns
        try:
            patterns = await find_cross_company_patterns(
                pool, ticker, competitor_ticker, catalyst_type, config=cfg,
            )
        except Exception:
            logger.exception(
                "Failed to query cross-company patterns for %s→%s/%s",
                ticker, competitor_ticker, catalyst_type,
            )
            continue

        for pattern in patterns:
            # Confidence threshold gating (Req 9.1)
            if pattern.pattern_confidence < cfg.pattern_confidence_threshold:
                logger.info(
                    "Excluding pattern %s→%s/%s/%s: confidence %.3f "
                    "below threshold %.3f",
                    ticker, competitor_ticker, catalyst_type,
                    pattern.time_horizon, pattern.pattern_confidence,
                    cfg.pattern_confidence_threshold,
                )
                continue

            # Compute signal strength (Req 4.3), clamped to [0, 1]
            raw_strength = (
                pattern.avg_strength
                * rel_strength
                * pattern.pattern_confidence
                * impact_score
            )
            signal_strength = min(max(raw_strength, 0.0), 1.0)

            # Determine direction; an exact bullish/bearish tie is bearish
            direction = (
                "bullish" if pattern.bullish_pct > pattern.bearish_pct
                else "bearish"
            )

            record = CompetitiveSignalRecord(
                source_document_id=document_id,
                source_ticker=ticker,
                target_ticker=competitor_ticker,
                catalyst_type=catalyst_type,
                pattern_confidence=pattern.pattern_confidence,
                signal_direction=direction,
                signal_strength=signal_strength,
                relationship_strength=rel_strength,
                computed_at=now,
            )
            records.append(record)

    # Step 3: Persist all records
    if records:
        try:
            async with pool.acquire() as conn:
                await conn.executemany(
                    _INSERT_SIGNAL_QUERY,
                    [
                        (
                            r.source_document_id,
                            r.source_ticker,
                            r.target_ticker,
                            r.catalyst_type,
                            r.pattern_confidence,
                            r.signal_direction,
                            r.signal_strength,
                            r.relationship_strength,
                            r.computed_at,
                        )
                        for r in records
                    ],
                )
        except Exception:
            # Best-effort: the caller still receives the computed records.
            logger.exception(
                "Failed to persist %d competitive signal records", len(records),
            )

    return records
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_pattern_weighted_signals
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_pattern_weighted_signals(
    patterns: list[HistoricalPattern],
    competitive_signals: list[CompetitiveSignalRecord],
    reference_time: datetime,
    window: str,
    config: Optional[CompetitiveConfig] = None,
) -> list[WeightedSignal]:
    """Convert pattern and competitive signal objects to WeightedSignal
    objects for the aggregation engine.

    For HistoricalPattern objects:
    - sentiment_value = +1.0 if bullish_pct > bearish_pct else -1.0
    - impact_score = avg_strength * competitive_signal_weight
    - published_at = data_end (most recent data point for recency decay)
    - extraction_confidence = pattern_confidence

    For CompetitiveSignalRecord objects:
    - sentiment_value = +1.0 if direction == "bullish" else -1.0
    - impact_score = signal_strength * competitive_signal_weight
    - published_at = computed_at (for recency decay)
    - extraction_confidence = pattern_confidence

    Both kinds are weighted with source_credibility 1.0, novelty_score 0.5,
    no market context and a default ScoringConfig.

    Args:
        patterns: Self-company historical patterns.
        competitive_signals: Competitive signal records from propagation.
        reference_time: Aggregation anchor time for recency decay.
        window: Trend window identifier (e.g. "7d").
        config: Optional competitive config overrides.

    Returns:
        List of WeightedSignal objects ready for aggregation: one per
        pattern, followed by one per competitive signal.
    """
    cfg = config or CompetitiveConfig()
    scoring_cfg = ScoringConfig()
    signals: list[WeightedSignal] = []

    # Convert HistoricalPattern objects
    for pattern in patterns:
        sentiment_value = (
            1.0 if pattern.bullish_pct > pattern.bearish_pct else -1.0
        )
        impact = pattern.avg_strength * cfg.competitive_signal_weight

        weight = compute_signal_weight(
            published_at=pattern.data_end,
            reference_time=reference_time,
            window=window,
            source_credibility=1.0,  # patterns are derived from validated data
            novelty_score=0.5,
            extraction_confidence=pattern.pattern_confidence,
            market_ctx=None,
            config=scoring_cfg,
        )

        signals.append(WeightedSignal(
            # Patterns have no backing document, so a synthetic id is used.
            document_id=f"pattern:{pattern.source_ticker}:{pattern.catalyst_type}:{pattern.time_horizon}",
            weight=weight,
            sentiment_value=sentiment_value,
            impact_score=impact,
        ))

    # Convert CompetitiveSignalRecord objects
    for sig in competitive_signals:
        sentiment_value = 1.0 if sig.signal_direction == "bullish" else -1.0
        impact = sig.signal_strength * cfg.competitive_signal_weight

        weight = compute_signal_weight(
            published_at=sig.computed_at,
            reference_time=reference_time,
            window=window,
            source_credibility=1.0,
            novelty_score=0.5,
            extraction_confidence=sig.pattern_confidence,
            market_ctx=None,
            config=scoring_cfg,
        )

        signals.append(WeightedSignal(
            document_id=sig.source_document_id,
            weight=weight,
            sentiment_value=sentiment_value,
            impact_score=impact,
        ))

    return signals
|
||||
@@ -40,6 +40,17 @@ from services.shared.metrics import (
|
||||
AGGREGATION_SIGNALS_PROCESSED,
|
||||
AGGREGATION_WINDOWS_COMPUTED,
|
||||
)
|
||||
from services.aggregation.pattern_matcher import find_self_patterns
|
||||
from services.aggregation.projection import (
|
||||
MacroEventInfo,
|
||||
TrendProjection,
|
||||
compute_projection,
|
||||
persist_trend_projection,
|
||||
)
|
||||
from services.aggregation.signal_propagation import (
|
||||
CompetitiveSignalRecord,
|
||||
build_pattern_weighted_signals,
|
||||
)
|
||||
from services.shared.schemas import TrendDirection, TrendSummary, TrendWindow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -64,6 +75,10 @@ class AggregationConfig:
|
||||
windows: list[str] | None = None # None = all windows
|
||||
scoring: ScoringConfig | None = None
|
||||
max_evidence: int = MAX_EVIDENCE_REFS
|
||||
macro_signal_weight: float = 0.3 # relative weight of macro vs company signals
|
||||
macro_enabled: bool = True # runtime toggle state
|
||||
competitive_signal_weight: float = 0.2 # relative weight of pattern signals
|
||||
competitive_enabled: bool = True # runtime toggle state
|
||||
|
||||
def effective_windows(self) -> list[str]:
|
||||
if self.windows:
|
||||
@@ -154,6 +169,236 @@ async def fetch_impact_records(
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch macro toggle state from risk_configs
|
||||
#
|
||||
# MACRO LAYER TOGGLE BEHAVIOR (Requirements 11.2, 11.3, 11.4):
|
||||
# - The toggle state is read fresh from PostgreSQL at the start of each
|
||||
# aggregation cycle (no caching), so changes take effect immediately on
|
||||
# the next cycle.
|
||||
# - When disabled: ingestion and classification continue normally (historical
|
||||
# data is preserved), but interpolation and aggregation integration are
|
||||
# skipped — the aggregation engine produces trends using only company-
|
||||
# specific signals.
|
||||
# - When re-enabled: the engine resumes computing macro impact scores using
|
||||
# the most recent GlobalEvent classifications, including any events that
|
||||
# were ingested and classified while the layer was disabled.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Reads the 'macro_enabled' key (as text) from the most recently updated
# active risk_configs row.
_MACRO_TOGGLE_QUERY = """
SELECT config->>'macro_enabled' AS macro_enabled
FROM risk_configs
WHERE active = TRUE
ORDER BY updated_at DESC
LIMIT 1
"""


async def fetch_macro_enabled(pool: asyncpg.Pool) -> bool | None:
    """Check macro toggle state from risk_configs table.

    Args:
        pool: asyncpg connection pool.

    Returns:
        True when the stored value equals "true" (case-insensitive), False
        for any other stored value, or None if no active config row exists
        or the key is absent (caller should fall back to the
        AggregationConfig default).
    """
    row = await pool.fetchrow(_MACRO_TOGGLE_QUERY)
    if row is None or row["macro_enabled"] is None:
        return None
    # ``->>`` yields text, so the flag is compared as a string.
    return row["macro_enabled"].lower() == "true"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch competitive toggle state from risk_configs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Reads the 'competitive_enabled' key (as text) from the most recently
# updated active risk_configs row.
_COMPETITIVE_TOGGLE_QUERY = """
SELECT config->>'competitive_enabled' AS competitive_enabled
FROM risk_configs
WHERE active = TRUE
ORDER BY updated_at DESC
LIMIT 1
"""


async def fetch_competitive_enabled(pool: asyncpg.Pool) -> bool | None:
    """Check competitive toggle state from risk_configs table.

    Args:
        pool: asyncpg connection pool.

    Returns:
        True when the stored value equals "true" (case-insensitive), False
        for any other stored value, or None if no active config row exists
        or the key is absent (caller should fall back to the
        AggregationConfig default).
    """
    row = await pool.fetchrow(_COMPETITIVE_TOGGLE_QUERY)
    if row is None or row["competitive_enabled"] is None:
        return None
    # ``->>`` yields text, so the flag is compared as a string.
    return row["competitive_enabled"].lower() == "true"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch competitive signals targeting a ticker within a time window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Signals aimed at ticker $1 whose computed_at lies in [$2, $3], newest first.
_COMPETITIVE_SIGNALS_QUERY = """
SELECT source_document_id, source_ticker, target_ticker, catalyst_type,
       pattern_confidence, signal_direction, signal_strength,
       relationship_strength, computed_at
FROM competitive_signal_records
WHERE target_ticker = $1
  AND computed_at >= $2
  AND computed_at <= $3
ORDER BY computed_at DESC
"""


async def fetch_competitive_signals(
    pool: asyncpg.Pool,
    ticker: str,
    window_start: datetime,
    window_end: datetime,
) -> list[CompetitiveSignalRecord]:
    """Fetch competitive signal records targeting a ticker in a time range.

    Args:
        pool: asyncpg connection pool.
        ticker: Target ticker the signals were propagated to.
        window_start: Inclusive lower bound on ``computed_at``.
        window_end: Inclusive upper bound on ``computed_at``.

    Returns:
        One CompetitiveSignalRecord per row, ordered newest first.
    """
    rows = await pool.fetch(
        _COMPETITIVE_SIGNALS_QUERY, ticker, window_start, window_end,
    )
    return [
        CompetitiveSignalRecord(
            source_document_id=str(row["source_document_id"]),
            source_ticker=row["source_ticker"],
            target_ticker=row["target_ticker"],
            catalyst_type=row["catalyst_type"],
            pattern_confidence=float(row["pattern_confidence"]),
            signal_direction=row["signal_direction"],
            signal_strength=float(row["signal_strength"]),
            relationship_strength=float(row["relationship_strength"]),
            computed_at=row["computed_at"],
        )
        for row in rows
    ]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch macro impact records for a ticker within a time window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MACRO_IMPACT_QUERY = """
|
||||
SELECT
|
||||
mir.event_id,
|
||||
mir.company_id,
|
||||
mir.ticker,
|
||||
mir.macro_impact_score,
|
||||
mir.impact_direction,
|
||||
mir.contributing_factors,
|
||||
mir.confidence,
|
||||
mir.computed_at,
|
||||
ge.source_document_id,
|
||||
d.published_at AS event_published_at
|
||||
FROM macro_impact_records mir
|
||||
JOIN global_events ge ON ge.id = mir.event_id
|
||||
JOIN documents d ON d.id = ge.source_document_id
|
||||
WHERE mir.ticker = $1
|
||||
AND mir.computed_at >= $2
|
||||
AND mir.computed_at <= $3
|
||||
ORDER BY mir.computed_at DESC
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MacroImpactRow:
|
||||
"""Parsed row from the macro impact query."""
|
||||
|
||||
event_id: str
|
||||
company_id: str
|
||||
ticker: str
|
||||
macro_impact_score: float
|
||||
impact_direction: str
|
||||
contributing_factors: list[str]
|
||||
confidence: float
|
||||
computed_at: datetime
|
||||
source_document_id: str
|
||||
event_published_at: datetime
|
||||
|
||||
|
||||
def _parse_macro_impact_row(row: Any) -> MacroImpactRow:
|
||||
"""Convert an asyncpg Record to a MacroImpactRow."""
|
||||
factors = row["contributing_factors"]
|
||||
if isinstance(factors, str):
|
||||
factors = json.loads(factors)
|
||||
|
||||
return MacroImpactRow(
|
||||
event_id=str(row["event_id"]),
|
||||
company_id=str(row["company_id"]),
|
||||
ticker=row["ticker"],
|
||||
macro_impact_score=float(row["macro_impact_score"] or 0.0),
|
||||
impact_direction=row["impact_direction"] or "neutral",
|
||||
contributing_factors=factors if isinstance(factors, list) else [],
|
||||
confidence=float(row["confidence"] or 0.5),
|
||||
computed_at=row["computed_at"],
|
||||
source_document_id=str(row["source_document_id"]),
|
||||
event_published_at=row["event_published_at"],
|
||||
)
|
||||
|
||||
|
||||
async def fetch_macro_impact_records(
    pool: asyncpg.Pool,
    ticker: str,
    window_start: datetime,
    window_end: datetime,
) -> list[MacroImpactRow]:
    """Load the macro impact rows for ``ticker`` between two timestamps.

    Runs ``_MACRO_IMPACT_QUERY`` with ``ticker``, ``window_start`` and
    ``window_end`` as its three bind parameters and parses every returned
    record with ``_parse_macro_impact_row``.
    """
    records = await pool.fetch(
        _MACRO_IMPACT_QUERY, ticker, window_start, window_end,
    )
    return list(map(_parse_macro_impact_row, records))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convert macro impact records to WeightedSignals
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DIRECTION_TO_SENTIMENT: dict[str, float] = {
|
||||
"positive": 1.0,
|
||||
"negative": -1.0,
|
||||
"mixed": 0.0,
|
||||
"neutral": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def build_macro_weighted_signals(
    macro_impacts: list[MacroImpactRow],
    reference_time: datetime,
    window: str,
    macro_signal_weight: float = 0.3,
    config: ScoringConfig | None = None,
) -> list[WeightedSignal]:
    """Turn macro impact rows into WeightedSignal objects, one per row.

    For every row:
      - ``document_id`` is the row's ``source_document_id`` (evidence tracing)
      - ``weight`` comes from ``compute_signal_weight``, using the event's
        publication time, a fixed novelty of 0.5, and the row's confidence
        as both source credibility and extraction confidence
      - ``sentiment_value`` is looked up from ``impact_direction``
        (0.0 for labels not in the mapping)
      - ``impact_score`` is ``macro_impact_score * macro_signal_weight``

    A default ``ScoringConfig()`` is used when ``config`` is not supplied.
    """
    scoring = config or ScoringConfig()
    return [
        WeightedSignal(
            document_id=record.source_document_id,
            weight=compute_signal_weight(
                published_at=record.event_published_at,
                reference_time=reference_time,
                window=window,
                source_credibility=record.confidence,
                novelty_score=0.5,
                extraction_confidence=record.confidence,
                config=scoring,
            ),
            sentiment_value=_DIRECTION_TO_SENTIMENT.get(
                record.impact_direction, 0.0,
            ),
            impact_score=record.macro_impact_score * macro_signal_weight,
        )
        for record in macro_impacts
    ]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build weighted signals from impact records
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -544,6 +789,61 @@ async def persist_trend_evidence(
|
||||
return len(rows)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build MacroEventInfo objects for projection computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MACRO_EVENT_INFO_QUERY = """
|
||||
SELECT
|
||||
mir.event_id,
|
||||
mir.macro_impact_score,
|
||||
mir.impact_direction,
|
||||
mir.confidence,
|
||||
ge.estimated_duration,
|
||||
ge.severity,
|
||||
d.published_at AS event_published_at
|
||||
FROM macro_impact_records mir
|
||||
JOIN global_events ge ON ge.id = mir.event_id
|
||||
JOIN documents d ON d.id = ge.source_document_id
|
||||
WHERE mir.ticker = $1
|
||||
AND mir.computed_at >= $2
|
||||
AND mir.computed_at <= $3
|
||||
ORDER BY mir.computed_at DESC
|
||||
"""
|
||||
|
||||
|
||||
async def _build_macro_event_infos(
    pool: asyncpg.Pool,
    ticker: str,
    window_start: datetime,
    reference_time: datetime,
) -> list[MacroEventInfo]:
    """Fetch macro impact records and build MacroEventInfo objects for projection.

    Runs ``_MACRO_EVENT_INFO_QUERY`` with ``ticker``, ``window_start`` and
    ``reference_time`` as its bind parameters. Each row becomes one
    MacroEventInfo whose ``event_age_hours`` is the time from the source
    document's publication to ``reference_time`` in hours, never negative,
    and 0.0 when no publication time is available.

    Missing column values fall back to: score ``0.0``, direction
    ``"neutral"``, confidence ``0.5``, estimated duration ``"short_term"``,
    severity ``"low"``.
    """
    rows = await pool.fetch(
        _MACRO_EVENT_INFO_QUERY, ticker, window_start, reference_time,
    )
    infos: list[MacroEventInfo] = []
    for row in rows:
        published_at = row["event_published_at"]
        age_hours = 0.0
        if published_at is not None:
            # Clamp at zero so documents dated after reference_time do not
            # produce a negative age.
            age_hours = max(
                (reference_time - published_at).total_seconds() / 3600.0, 0.0,
            )

        # Only a missing confidence gets the 0.5 default. A plain ``or``
        # would also replace a genuine 0.0 with 0.5.
        confidence = row["confidence"]

        infos.append(
            MacroEventInfo(
                event_id=str(row["event_id"]),
                macro_impact_score=float(row["macro_impact_score"] or 0.0),
                impact_direction=row["impact_direction"] or "neutral",
                confidence=0.5 if confidence is None else float(confidence),
                estimated_duration=row["estimated_duration"] or "short_term",
                severity=row["severity"] or "low",
                event_age_hours=age_hours,
            )
        )
    return infos
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main aggregation entry point for a single ticker + window
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -563,8 +863,10 @@ async def aggregate_company_window(
|
||||
2. Fetch document impact records from PostgreSQL.
|
||||
3. Fetch market context for the ticker.
|
||||
4. Build weighted signals using the scoring module.
|
||||
5. Assemble the TrendSummary.
|
||||
6. Persist to trend_windows table.
|
||||
5. Check macro toggle and fetch/merge macro signals if enabled.
|
||||
6. Check competitive toggle and fetch/merge pattern/competitive signals if enabled.
|
||||
7. Assemble the TrendSummary.
|
||||
8. Persist to trend_windows table.
|
||||
|
||||
Returns the assembled TrendSummary.
|
||||
"""
|
||||
@@ -589,7 +891,83 @@ async def aggregate_company_window(
|
||||
impacts, reference_time, window, market_ctx, scoring_cfg,
|
||||
)
|
||||
|
||||
# 4. Assemble trend summary with evidence details
|
||||
# 4. Check macro toggle and merge macro signals
|
||||
# (Requirement 11.2, 11.3, 11.4): Toggle state is read from the DB on
|
||||
# every aggregation cycle. When disabled, macro signals are skipped but
|
||||
# ingestion/classification continue independently — so when re-enabled,
|
||||
# the most recent classifications (including those ingested while disabled)
|
||||
# are immediately available for impact computation.
|
||||
macro_enabled = cfg.macro_enabled
|
||||
db_toggle = await fetch_macro_enabled(pool)
|
||||
if db_toggle is not None:
|
||||
macro_enabled = db_toggle
|
||||
|
||||
if macro_enabled:
|
||||
macro_impacts = await fetch_macro_impact_records(
|
||||
pool, ticker, window_start, reference_time,
|
||||
)
|
||||
if macro_impacts:
|
||||
macro_signals = build_macro_weighted_signals(
|
||||
macro_impacts,
|
||||
reference_time,
|
||||
window,
|
||||
macro_signal_weight=cfg.macro_signal_weight,
|
||||
config=scoring_cfg,
|
||||
)
|
||||
signals = signals + macro_signals
|
||||
logger.info(
|
||||
"Merged %d macro signals for %s/%s",
|
||||
len(macro_signals), ticker, window,
|
||||
)
|
||||
|
||||
# 5. Check competitive toggle and merge pattern/competitive signals
|
||||
# (Requirements 5.1-5.6): Same toggle pattern as macro layer. When
|
||||
# disabled, pattern mining remains queryable but aggregation skips
|
||||
# competitive signals — no degradation of existing behavior.
|
||||
competitive_enabled = cfg.competitive_enabled
|
||||
db_competitive_toggle = await fetch_competitive_enabled(pool)
|
||||
if db_competitive_toggle is not None:
|
||||
competitive_enabled = db_competitive_toggle
|
||||
|
||||
if competitive_enabled:
|
||||
try:
|
||||
# Get unique catalyst types from the impact records
|
||||
catalyst_types = {imp.catalyst_type for imp in impacts}
|
||||
|
||||
# Query self-company historical patterns for each catalyst type
|
||||
all_patterns = []
|
||||
for cat_type in catalyst_types:
|
||||
patterns = await find_self_patterns(pool, ticker, cat_type)
|
||||
all_patterns.extend(patterns)
|
||||
|
||||
# Fetch competitive signals targeting this ticker
|
||||
comp_signals = await fetch_competitive_signals(
|
||||
pool, ticker, window_start, reference_time,
|
||||
)
|
||||
|
||||
# Convert to WeightedSignal objects
|
||||
if all_patterns or comp_signals:
|
||||
pattern_weighted = build_pattern_weighted_signals(
|
||||
patterns=all_patterns,
|
||||
competitive_signals=comp_signals,
|
||||
reference_time=reference_time,
|
||||
window=window,
|
||||
)
|
||||
signals = signals + pattern_weighted
|
||||
logger.info(
|
||||
"Merged %d pattern/competitive signals for %s/%s "
|
||||
"(patterns=%d, competitive=%d)",
|
||||
len(pattern_weighted), ticker, window,
|
||||
len(all_patterns), len(comp_signals),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to fetch pattern/competitive signals for %s/%s — "
|
||||
"continuing with company+macro signals only",
|
||||
ticker, window,
|
||||
)
|
||||
|
||||
# 6. Assemble trend summary with evidence details
|
||||
assembled = assemble_trend_with_evidence(
|
||||
ticker=ticker,
|
||||
window=window,
|
||||
@@ -601,10 +979,10 @@ async def aggregate_company_window(
|
||||
)
|
||||
summary = assembled.summary
|
||||
|
||||
# 5. Persist trend window
|
||||
# 7. Persist trend window
|
||||
trend_id = await persist_trend_summary(pool, summary)
|
||||
|
||||
# 6. Persist evidence mappings
|
||||
# 8. Persist evidence mappings
|
||||
evidence_count = await persist_trend_evidence(
|
||||
pool, trend_id,
|
||||
assembled.supporting_evidence,
|
||||
@@ -617,6 +995,33 @@ async def aggregate_company_window(
|
||||
summary.trend_strength, summary.confidence, len(signals), evidence_count,
|
||||
)
|
||||
|
||||
# 9. Compute and persist trend projection
|
||||
try:
|
||||
macro_event_infos: list[MacroEventInfo] = []
|
||||
if macro_enabled:
|
||||
macro_event_infos = await _build_macro_event_infos(
|
||||
pool, ticker, window_start, reference_time,
|
||||
)
|
||||
|
||||
projection = compute_projection(
|
||||
summary=summary,
|
||||
macro_events=macro_event_infos if macro_event_infos else None,
|
||||
macro_enabled=macro_enabled,
|
||||
upcoming_catalysts=summary.dominant_catalysts[:3] if summary.dominant_catalysts else None,
|
||||
)
|
||||
await persist_trend_projection(pool, trend_id, projection)
|
||||
logger.info(
|
||||
"Persisted projection for %s/%s: direction=%s strength=%.3f confidence=%.3f diverges=%s",
|
||||
ticker, window, projection.projected_direction,
|
||||
projection.projected_strength, projection.projected_confidence,
|
||||
projection.diverges_from_current,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to compute/persist projection for trend %s (%s/%s) — continuing",
|
||||
trend_id, ticker, window,
|
||||
)
|
||||
|
||||
# Prometheus metrics
|
||||
AGGREGATION_WINDOWS_COMPUTED.labels(window=window).inc()
|
||||
AGGREGATION_SIGNALS_PROCESSED.labels(window=window).inc(len(signals))
|
||||
|
||||
Reference in New Issue
Block a user