feat: competitive intelligence & historical pattern matching layer

This commit is contained in:
Celes Renata
2026-04-14 19:42:48 +00:00
parent b478022ba3
commit f7a11d14ea
203 changed files with 20155 additions and 97 deletions
+741
View File
@@ -0,0 +1,741 @@
"""Interpolation engine — macro-to-company impact scoring.
Computes per-company macro impact scores by evaluating overlap between
global event classifications and company exposure profiles. Produces
MacroImpactRecord objects that feed into the aggregation engine as
additional weighted signals.
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 3.2
"""
from __future__ import annotations
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
import asyncpg
from services.extractor.event_classifier import GlobalEvent
from services.shared.schemas import (
ExposureProfileSchema,
MarketPositionTier,
SeverityLevel,
)
logger = logging.getLogger("interpolation")
# ---------------------------------------------------------------------------
# Default configuration constants
# ---------------------------------------------------------------------------
# Default minimum event confidence for filter_low_confidence_events; events
# strictly below this value are excluded.
DEFAULT_CONFIDENCE_THRESHOLD = 0.4
# Default age (hours) after which apply_accelerated_decay treats a
# "short_term" event as stale.
DEFAULT_SHORT_TERM_STALENESS_HOURS = 48
ACCELERATED_DECAY_MULTIPLIER = 0.5 # applied on top of standard recency decay
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Severity weights, keyed by SeverityLevel value. The scoring functions fall
# back to 0.25 when an event's severity is not a key of this mapping.
SEVERITY_WEIGHTS: dict[str, float] = {
SeverityLevel.CRITICAL.value: 1.0,
SeverityLevel.HIGH.value: 0.75,
SeverityLevel.MODERATE.value: 0.5,
SeverityLevel.LOW.value: 0.25,
}
# Component weights in the scoring formula (the four weights sum to 1.0)
GEO_WEIGHT = 0.35
SUPPLY_WEIGHT = 0.25
COMMODITY_WEIGHT = 0.25
SECTOR_WEIGHT = 0.15
# Resilience modifiers for international events, keyed by MarketPositionTier
# value. Values below 1.0 dampen the raw score, values above 1.0 amplify it;
# apply_resilience_modifier uses 1.0 for a tier that is not listed here.
RESILIENCE_MODIFIERS: dict[str, float] = {
MarketPositionTier.GLOBAL_LEADER.value: 0.7,
MarketPositionTier.MULTINATIONAL.value: 0.85,
MarketPositionTier.REGIONAL.value: 1.0,
MarketPositionTier.DOMESTIC.value: 1.2,
}
# Event types that are typically negative
_NEGATIVE_EVENT_TYPES = frozenset({
"supply_disruption",
"cost_increase",
"regulatory_pressure",
"geopolitical_risk",
"trade_barrier",
})
# Event types that can be positive
_POSITIVE_EVENT_TYPES = frozenset({
"demand_shift",
})
# Event types that can go either way; _determine_impact_direction counts
# them on both the positive and the negative side.
_AMBIGUOUS_EVENT_TYPES = frozenset({
"commodity_shock",
"currency_impact",
})
# Market cap bucket → market position tier mapping for default profiles.
# build_default_profile falls back to REGIONAL for an unknown bucket.
_CAP_TO_TIER: dict[str, str] = {
"large_cap": MarketPositionTier.GLOBAL_LEADER.value,
"mid_cap": MarketPositionTier.MULTINATIONAL.value,
"small_cap": MarketPositionTier.REGIONAL.value,
"micro_cap": MarketPositionTier.DOMESTIC.value,
}
# Sector-based default geographic revenue mixes.
# Outer keys are sector names; each inner mapping is region code -> revenue
# fraction, and every mix below sums to 1.0.
_SECTOR_DEFAULT_GEO: dict[str, dict[str, float]] = {
"Information Technology": {"US": 0.45, "CN": 0.15, "EU": 0.15, "JP": 0.10, "KR": 0.15},
"Health Care": {"US": 0.50, "EU": 0.25, "JP": 0.10, "CN": 0.15},
"Financials": {"US": 0.55, "EU": 0.20, "GB": 0.15, "JP": 0.10},
"Energy": {"US": 0.30, "SA": 0.20, "RU": 0.15, "CA": 0.15, "AE": 0.20},
"Materials": {"US": 0.25, "CN": 0.25, "AU": 0.20, "BR": 0.15, "IN": 0.15},
"Industrials": {"US": 0.40, "DE": 0.15, "CN": 0.15, "JP": 0.15, "KR": 0.15},
"Consumer Discretionary": {"US": 0.45, "CN": 0.20, "EU": 0.15, "JP": 0.10, "IN": 0.10},
"Consumer Staples": {"US": 0.45, "EU": 0.20, "CN": 0.15, "IN": 0.10, "BR": 0.10},
"Communication Services": {"US": 0.50, "CN": 0.15, "EU": 0.15, "JP": 0.10, "IN": 0.10},
"Utilities": {"US": 0.70, "EU": 0.15, "JP": 0.15},
"Real Estate": {"US": 0.60, "CN": 0.15, "EU": 0.15, "JP": 0.10},
}
# Fallback mix used by build_default_profile when the sector is not a key of
# _SECTOR_DEFAULT_GEO (also sums to 1.0).
_DEFAULT_GEO = {"US": 0.50, "EU": 0.20, "CN": 0.15, "JP": 0.15}
# ---------------------------------------------------------------------------
# MacroImpactRecord dataclass
# ---------------------------------------------------------------------------
@dataclass
class MacroImpactRecord:
    """Macro impact score computed for one (event, company) pair.

    Every field carries a default, so a bare ``MacroImpactRecord()`` is a
    valid "no impact" record.
    """

    # Identifier of the global event that was scored.
    event_id: str = ""
    # Identifier of the company the score applies to.
    company_id: str = ""
    # Company ticker symbol; empty when the producer does not supply one.
    ticker: str = ""
    # Impact magnitude in [0, 1].
    macro_impact_score: float = 0.0
    # One of: neutral | positive | negative | mixed.
    impact_direction: str = "neutral"
    # Free-form tags describing what drove the score.
    contributing_factors: list[str] = field(default_factory=list)
    # Confidence in the score, in [0, 1].
    confidence: float = 0.5
    # Timezone-aware (UTC) timestamp taken when the record is created.
    computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
# ---------------------------------------------------------------------------
# Overlap computation functions
# ---------------------------------------------------------------------------
def compute_geographic_overlap(
    event_regions: list[str],
    revenue_mix: dict[str, float],
) -> float:
    """Return the share of company revenue located in the event's regions.

    Region codes are compared case-insensitively. The revenue fractions of
    every region in ``revenue_mix`` that the event also affects are added up.

    Args:
        event_regions: Region codes from the global event.
        revenue_mix: Company's geographic_revenue_mix (region -> pct).

    Returns:
        The summed revenue fraction clamped to [0, 1]; 0.0 when either
        argument is empty.
    """
    if not (event_regions and revenue_mix):
        return 0.0

    affected = {code.upper() for code in event_regions}
    exposed = 0.0
    for code, share in revenue_mix.items():
        if code.upper() not in affected:
            continue
        exposed += share
    return min(max(exposed, 0.0), 1.0)
def compute_supply_chain_overlap(
    event_regions: list[str],
    supply_regions: list[str],
) -> float:
    """Return the fraction of supply-chain regions hit by the event.

    Both lists are upper-cased and de-duplicated before comparison, so the
    denominator is the number of *distinct* supply-chain regions.

    Args:
        event_regions: Region codes from the global event.
        supply_regions: Company's supply_chain_regions.

    Returns:
        Intersection ratio in [0, 1]; 0.0 when either list is empty.
    """
    if not event_regions or not supply_regions:
        return 0.0

    supply_codes = {code.upper() for code in supply_regions}
    hit_codes = supply_codes.intersection(code.upper() for code in event_regions)
    return len(hit_codes) / len(supply_codes)
def compute_commodity_overlap(
    event_commodities: list[str],
    company_commodities: list[str],
) -> float:
    """Return the fraction of the company's commodities hit by the event.

    Identifiers are lower-cased and de-duplicated before comparison, so the
    denominator is the number of *distinct* company commodities.

    Args:
        event_commodities: Commodity identifiers from the global event.
        company_commodities: Company's key_input_commodities.

    Returns:
        Intersection ratio in [0, 1]; 0.0 when either list is empty.
    """
    if not event_commodities or not company_commodities:
        return 0.0

    company_inputs = {name.lower() for name in company_commodities}
    shared = company_inputs.intersection(name.lower() for name in event_commodities)
    return len(shared) / len(company_inputs)
# ---------------------------------------------------------------------------
# Resilience modifier
# ---------------------------------------------------------------------------
def apply_resilience_modifier(
    raw_score: float,
    tier: str,
    event_is_international: bool = True,
) -> float:
    """Scale a raw impact score by the company's market-position resilience.

    International events are multiplied by the tier's entry in
    RESILIENCE_MODIFIERS (1.0 when the tier is not listed). Domestic-only
    events are left unscaled. Either way the result is clamped.

    Args:
        raw_score: The raw impact score before resilience adjustment.
        tier: Market position tier value.
        event_is_international: Whether the event affects multiple countries.

    Returns:
        The (possibly scaled) score clamped to [0, 1].
    """
    adjusted = raw_score
    if event_is_international:
        adjusted = raw_score * RESILIENCE_MODIFIERS.get(tier, 1.0)
    return min(max(adjusted, 0.0), 1.0)
# ---------------------------------------------------------------------------
# Impact direction determination
# ---------------------------------------------------------------------------
def _determine_impact_direction(
    event_types: list[str],
) -> tuple[str, list[str], list[str]]:
    """Classify an event's overall direction from its event types.

    A type listed in _NEGATIVE_EVENT_TYPES counts as negative only. Otherwise
    a type in _POSITIVE_EVENT_TYPES counts as positive, and a type in
    _AMBIGUOUS_EVENT_TYPES counts on both sides. Unrecognised types are
    ignored.

    Returns:
        Tuple of (direction, positive_factors, negative_factors), where
        direction is "mixed", "positive" or "negative". When no type is
        recognised at all the direction is "negative".
    """
    positives: list[str] = []
    negatives: list[str] = []

    for kind in event_types:
        if kind in _NEGATIVE_EVENT_TYPES:
            negatives.append(kind)
            continue
        if kind in _POSITIVE_EVENT_TYPES:
            positives.append(kind)
            continue
        if kind in _AMBIGUOUS_EVENT_TYPES:
            # Ambiguous types are recorded on both sides.
            positives.append(kind)
            negatives.append(kind)

    if positives and negatives:
        direction = "mixed"
    elif positives:
        direction = "positive"
    else:
        # Covers "negatives only" as well as "nothing recognised".
        direction = "negative"
    return direction, positives, negatives
# ---------------------------------------------------------------------------
# Core scoring function
# ---------------------------------------------------------------------------
def compute_macro_impact(
    event: GlobalEvent,
    profile: ExposureProfileSchema,
) -> MacroImpactRecord:
    """Compute the macro impact of a global event on a company.

    This is the sector-agnostic entry point: ExposureProfileSchema does not
    carry the company's sector, so the sector term of the formula is always
    zero here. Callers that know the sector should use
    compute_macro_impact_with_sector instead.

    Scoring formula::

        raw_score = severity_weight * (
            0.35 * geographic_overlap +
            0.25 * supply_chain_overlap +
            0.25 * commodity_overlap +
            0.15 * sector_match          # always 0.0 in this function
        )
        final_score = apply_resilience_modifier(raw_score, tier, is_international)

    Args:
        event: The classified global event.
        profile: The company's exposure profile.

    Returns:
        A MacroImpactRecord with the computed score and metadata. When the
        event and the profile share no region or commodity, the record has
        score 0.0, direction "neutral" and confidence 0.0.
    """
    # An empty company_sector skips sector matching entirely, which is exactly
    # what this function used to do with its own copy of the scoring code.
    # Delegating keeps the two entry points from drifting apart.
    return compute_macro_impact_with_sector(event, profile, company_sector="")
def compute_macro_impact_with_sector(
    event: GlobalEvent,
    profile: ExposureProfileSchema,
    company_sector: str = "",
) -> MacroImpactRecord:
    """Compute macro impact with explicit sector matching.

    Scoring formula::

        raw_score = severity_weight * (
            GEO_WEIGHT * geographic_overlap
            + SUPPLY_WEIGHT * supply_chain_overlap
            + COMMODITY_WEIGHT * commodity_overlap
            + SECTOR_WEIGHT * sector_match
        )
        final_score = apply_resilience_modifier(raw_score, tier, is_international)

    ``sector_match`` is 1.0 when ``company_sector`` equals any entry of
    ``event.affected_sectors`` (case-insensitive, surrounding whitespace
    ignored) and 0.0 otherwise.

    Args:
        event: The classified global event.
        profile: The company's exposure profile.
        company_sector: The company's GICS sector name. Empty disables
            sector matching.

    Returns:
        A MacroImpactRecord with the computed score and metadata. When no
        overlap of any kind exists, the record has score 0.0, direction
        "neutral", no contributing factors and confidence 0.0.
    """
    now = datetime.now(timezone.utc)

    # Overlap components, each in [0, 1].
    geo_overlap = compute_geographic_overlap(
        event.affected_regions,
        profile.geographic_revenue_mix,
    )
    supply_overlap = compute_supply_chain_overlap(
        event.affected_regions,
        profile.supply_chain_regions,
    )
    commodity_overlap = compute_commodity_overlap(
        event.affected_commodities,
        profile.key_input_commodities,
    )

    sector_match = 0.0
    if company_sector and event.affected_sectors:
        company_sector_lower = company_sector.lower().strip()
        if any(
            es.lower().strip() == company_sector_lower
            for es in event.affected_sectors
        ):
            sector_match = 1.0

    contributing: list[str] = []
    if geo_overlap > 0:
        contributing.append(f"geographic_overlap:{geo_overlap:.3f}")
    if supply_overlap > 0:
        contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}")
    if commodity_overlap > 0:
        contributing.append(f"commodity_overlap:{commodity_overlap:.3f}")
    if sector_match > 0:
        contributing.append(f"sector_match:{company_sector}")

    # All components are non-negative, so a zero sum means "no exposure".
    total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match
    if total_overlap == 0.0:
        return MacroImpactRecord(
            event_id=event.event_id,
            company_id=profile.company_id,
            ticker="",
            macro_impact_score=0.0,
            impact_direction="neutral",
            contributing_factors=[],
            confidence=0.0,
            computed_at=now,
        )

    # SEVERITY_WEIGHTS is keyed by the enum *values*. Normalise an enum member
    # to its value (as is done for the tier below) so that it cannot silently
    # fall through to the 0.25 default.
    severity = event.severity
    if isinstance(severity, SeverityLevel):
        severity = severity.value
    severity_weight = SEVERITY_WEIGHTS.get(severity, 0.25)

    raw_score = severity_weight * (
        GEO_WEIGHT * geo_overlap
        + SUPPLY_WEIGHT * supply_overlap
        + COMMODITY_WEIGHT * commodity_overlap
        + SECTOR_WEIGHT * sector_match
    )

    # An event touching more than one region is treated as international.
    is_international = len(event.affected_regions) > 1

    tier = profile.market_position_tier
    if isinstance(tier, MarketPositionTier):
        tier = tier.value
    final_score = apply_resilience_modifier(raw_score, tier, is_international)

    direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types)
    all_factors = list(contributing)
    if pos_factors:
        all_factors.append(f"positive_types:{','.join(pos_factors)}")
    if neg_factors:
        all_factors.append(f"negative_types:{','.join(neg_factors)}")

    # Confidence: event confidence scaled by overlap strength (+0.3 floor),
    # with both the factor and the product capped at 1.0.
    confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0)

    return MacroImpactRecord(
        event_id=event.event_id,
        company_id=profile.company_id,
        ticker="",
        macro_impact_score=round(min(final_score, 1.0), 6),
        impact_direction=direction,
        contributing_factors=all_factors,
        confidence=round(confidence, 6),
        computed_at=now,
    )
# ---------------------------------------------------------------------------
# Default profile builder
# ---------------------------------------------------------------------------
def build_default_profile(
    sector: str,
    industry: str,
    market_cap_bucket: str,
) -> ExposureProfileSchema:
    """Build an inferred ExposureProfile for a company without a manual one.

    The geographic revenue mix comes from the sector defaults (or the generic
    fallback mix), and the market-cap bucket selects the position tier:
        large_cap → global_leader
        mid_cap → multinational
        small_cap → regional
        micro_cap → domestic
    An unknown bucket is treated as regional.

    Args:
        sector: GICS sector name.
        industry: Industry name (forwarded to the commodity inference helper).
        market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap.

    Returns:
        An ExposureProfileSchema with source='inferred', an empty company_id,
        confidence 0.5 and version 1.
    """
    tier = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value)
    geo_mix = _SECTOR_DEFAULT_GEO.get(sector, _DEFAULT_GEO)

    # Supply-chain regions mirror the revenue regions; the first three of
    # them double as regulatory jurisdictions.
    region_codes = list(geo_mix)

    export_by_tier = {
        MarketPositionTier.GLOBAL_LEADER.value: 0.5,
        MarketPositionTier.MULTINATIONAL.value: 0.35,
        MarketPositionTier.REGIONAL.value: 0.15,
        MarketPositionTier.DOMESTIC.value: 0.05,
    }

    return ExposureProfileSchema(
        company_id="",
        geographic_revenue_mix=dict(geo_mix),
        supply_chain_regions=region_codes,
        key_input_commodities=_infer_commodities(sector, industry),
        regulatory_jurisdictions=region_codes[:3],
        market_position_tier=MarketPositionTier(tier),
        export_dependency_pct=export_by_tier.get(tier, 0.15),
        source="inferred",
        confidence=0.5,
        version=1,
    )
def _infer_commodities(sector: str, industry: str) -> list[str]:
    """Return the default key input commodities for a sector.

    Only ``sector`` is consulted; ``industry`` is accepted but currently
    unused. Unknown sectors yield an empty list. The lookup table is rebuilt
    on every call, so the returned list is never shared between callers.
    """
    defaults_by_sector: dict[str, list[str]] = {
        "Energy": ["crude_oil", "natural_gas"],
        "Materials": ["copper", "steel", "lithium"],
        "Industrials": ["steel", "copper"],
        "Information Technology": ["semiconductors", "lithium"],
        "Consumer Staples": ["wheat", "corn"],
        "Consumer Discretionary": ["steel", "semiconductors"],
        "Health Care": [],
        "Financials": [],
        "Communication Services": [],
        "Utilities": ["natural_gas"],
        "Real Estate": ["steel"],
    }
    if sector in defaults_by_sector:
        return defaults_by_sector[sector]
    return []
# ---------------------------------------------------------------------------
# PostgreSQL persistence
# ---------------------------------------------------------------------------
async def persist_macro_impact_record(
    pool: asyncpg.Pool,
    record: MacroImpactRecord,
) -> str:
    """Insert one MacroImpactRecord into the macro_impact_records table.

    ``contributing_factors`` is serialised with json.dumps and bound to a
    jsonb parameter; event_id and company_id are cast to uuid by the query.

    Returns:
        The id returned by the INSERT, converted to str.
    """
    insert_sql = """INSERT INTO macro_impact_records
(event_id, company_id, ticker, macro_impact_score,
impact_direction, contributing_factors, confidence, computed_at)
VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6::jsonb, $7, $8)
RETURNING id"""
    params = (
        record.event_id,
        record.company_id,
        record.ticker,
        record.macro_impact_score,
        record.impact_direction,
        json.dumps(record.contributing_factors),
        record.confidence,
        record.computed_at,
    )
    new_id = await pool.fetchval(insert_sql, *params)

    logger.info(
        "Persisted macro impact record for event=%s company=%s score=%.4f direction=%s",
        record.event_id,
        record.company_id,
        record.macro_impact_score,
        record.impact_direction,
    )
    return str(new_id)
async def persist_macro_impact_records(
    pool: asyncpg.Pool,
    records: list[MacroImpactRecord],
) -> list[str]:
    """Persist every record with a positive score, one INSERT at a time.

    Records whose macro_impact_score is not greater than 0.0 are skipped.

    Returns:
        The row ids of the records that were written, in input order.
    """
    persisted: list[str] = []
    for rec in records:
        if not rec.macro_impact_score > 0.0:
            continue
        persisted.append(await persist_macro_impact_record(pool, rec))
    return persisted
# ---------------------------------------------------------------------------
# Low-confidence event exclusion (Requirements: 10.1)
# ---------------------------------------------------------------------------
def filter_low_confidence_events(
    events: list[GlobalEvent],
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
) -> list[GlobalEvent]:
    """Drop events whose confidence is strictly below the threshold.

    Each dropped event is logged at INFO level together with its confidence
    and the threshold. Events exactly at the threshold are kept.

    Args:
        events: List of GlobalEvent classifications.
        confidence_threshold: Minimum confidence for inclusion (default 0.4).

    Returns:
        The events that pass the threshold, in their original order.

    Requirements: 10.1
    """
    kept: list[GlobalEvent] = []
    for candidate in events:
        if not candidate.confidence < confidence_threshold:
            kept.append(candidate)
            continue
        logger.info(
            "Excluding low-confidence event %s: confidence=%.3f < threshold=%.3f",
            candidate.event_id,
            candidate.confidence,
            confidence_threshold,
        )
    return kept
# ---------------------------------------------------------------------------
# Accelerated decay for stale short-term events (Requirements: 10.2)
# ---------------------------------------------------------------------------
def compute_standard_recency_decay(
    age_hours: float,
    half_life_hours: float = 168.0,
) -> float:
    """Compute standard exponential recency decay.

    The weight halves every ``half_life_hours``: an event exactly one
    half-life old gets 0.5, two half-lives 0.25, and so on.

    Args:
        age_hours: Age of the event in hours. Zero or negative ages are
            treated as "brand new".
        half_life_hours: Half-life for the decay function (default 7 days).

    Returns:
        Decay factor, 1.0 for ages <= 0 and decreasing towards 0 with age.
    """
    if age_hours <= 0:
        return 1.0
    # Exact half-life form. The earlier exp(-0.693 * age / half_life) used a
    # truncated ln(2), so the weight at one half-life was not exactly 0.5.
    return 0.5 ** (age_hours / half_life_hours)
def apply_accelerated_decay(
    age_hours: float,
    estimated_duration: str,
    staleness_hours: float = DEFAULT_SHORT_TERM_STALENESS_HOURS,
    half_life_hours: float = 168.0,
) -> float:
    """Return the recency weight, penalising stale short-term events.

    The standard recency decay is always computed first. If the event is
    marked "short_term" and is strictly older than ``staleness_hours``, that
    decay is additionally multiplied by ACCELERATED_DECAY_MULTIPLIER;
    otherwise the standard decay is returned unchanged.

    Args:
        age_hours: Age of the event in hours.
        estimated_duration: Event's estimated_duration field.
        staleness_hours: Hours after which short-term events get accelerated decay.
        half_life_hours: Half-life for the standard decay function.

    Returns:
        Effective signal weight.

    Requirements: 10.2
    """
    baseline = compute_standard_recency_decay(age_hours, half_life_hours)

    is_stale_short_term = (
        estimated_duration == "short_term" and age_hours > staleness_hours
    )
    if not is_stale_short_term:
        return baseline

    penalised = baseline * ACCELERATED_DECAY_MULTIPLIER
    logger.debug(
        "Accelerated decay for short_term event: age=%.1fh, "
        "standard=%.4f, accelerated=%.4f",
        age_hours, baseline, penalised,
    )
    return penalised
+125 -3
View File
@@ -1,4 +1,12 @@
"""Aggregation worker entrypoint - polls Redis for aggregation jobs."""
"""Aggregation worker entrypoint - polls Redis for aggregation jobs.
After computing trend summaries for a ticker, the worker also triggers
competitive signal propagation for the ticker's competitors when the
competitive layer is enabled. This ensures that document intelligence
for one company produces competitive signals for related companies.
Requirements: 4.1, 5.1, 9.4
"""
from __future__ import annotations
import asyncio
@@ -8,8 +16,9 @@ import logging
import asyncpg
import redis.asyncio as aioredis
from services.aggregation.worker import aggregate_company
from services.shared.config import load_config
from services.aggregation.signal_propagation import propagate_signals
from services.aggregation.worker import aggregate_company, fetch_competitive_enabled
from services.shared.config import CompetitiveConfig, load_config
from services.shared.logging import inject_trace_context, setup_logging
from services.shared.redis_keys import (
QUEUE_AGGREGATION,
@@ -20,6 +29,92 @@ from services.shared.redis_keys import (
logger = logging.getLogger("aggregation_main")
# ---------------------------------------------------------------------------
# Query to fetch recent document intelligence records for a ticker.
# Used to trigger signal propagation after aggregation completes.
# ---------------------------------------------------------------------------
_RECENT_INTELLIGENCE_QUERY = """
SELECT
di.document_id,
dir.catalyst_type,
dir.impact_score
FROM document_impact_records dir
JOIN document_intelligence di ON di.id = dir.intelligence_id
JOIN documents d ON d.id = di.document_id
WHERE dir.ticker = $1
AND di.validation_status = 'valid'
AND d.status != 'rejected'
ORDER BY d.published_at DESC
LIMIT 50
"""
# Track consecutive propagation failures for alerting (Requirement 9.4)
_propagation_consecutive_failures = 0
async def _trigger_signal_propagation(
    pool: asyncpg.Pool,
    ticker: str,
    competitive_config: CompetitiveConfig,
) -> int:
    """Propagate competitive signals for a ticker's recent documents.

    Runs _RECENT_INTELLIGENCE_QUERY for the ticker and calls propagate_signals
    once per returned row that has a positive impact score. A missing
    catalyst type is replaced by "other".

    Failures of individual propagate_signals calls are logged and counted in
    the module-level _propagation_consecutive_failures counter, which is reset
    to zero by any successful call. Once the counter reaches
    ``competitive_config.propagation_failure_threshold`` a CRITICAL alert is
    logged and the remaining rows for this ticker are skipped.

    Returns:
        The total number of competitive signal records produced.
    """
    global _propagation_consecutive_failures

    rows = await pool.fetch(_RECENT_INTELLIGENCE_QUERY, ticker)
    if not rows:
        return 0

    produced = 0
    for row in rows:
        doc_id = str(row["document_id"])
        catalyst = row["catalyst_type"] or "other"
        score = float(row["impact_score"] or 0.0)
        if score <= 0.0:
            # Nothing to propagate for documents without a positive impact.
            continue

        try:
            signal_records = await propagate_signals(
                pool=pool,
                ticker=ticker,
                catalyst_type=catalyst,
                impact_score=score,
                document_id=doc_id,
                config=competitive_config,
            )
            produced += len(signal_records)
            # A success ends any running streak of failures.
            _propagation_consecutive_failures = 0
        except Exception:
            _propagation_consecutive_failures += 1
            logger.exception(
                "Signal propagation failed for %s doc %s/%s",
                ticker, doc_id, catalyst,
            )
            threshold = competitive_config.propagation_failure_threshold
            if _propagation_consecutive_failures >= threshold:
                logger.critical(
                    "ALERT: Sustained signal propagation failures (%d consecutive). "
                    "Continuing with company-specific + macro signals only. "
                    "Operator action required.",
                    _propagation_consecutive_failures,
                )
                # Give up on the remaining documents of this ticker.
                break

    return produced
async def main() -> None:
config = load_config()
setup_logging("aggregation", level=config.log_level, json_output=config.json_logs)
@@ -28,6 +123,7 @@ async def main() -> None:
redis_client = aioredis.from_url(config.redis.url)
queue = queue_key(QUEUE_AGGREGATION)
rec_queue = queue_key(QUEUE_RECOMMENDATION)
competitive_config = config.competitive
logger.info("Aggregation worker started, polling %s", queue)
try:
@@ -49,6 +145,32 @@ async def main() -> None:
ticker, len(summaries),
)
# Trigger competitive signal propagation after aggregation
# (Requirement 4.1): When new document intelligence is
# produced for a company, propagate signals to competitors.
# Check toggle state from DB (same pattern as macro toggle).
competitive_enabled = competitive_config.competitive_enabled
db_toggle = await fetch_competitive_enabled(pool)
if db_toggle is not None:
competitive_enabled = db_toggle
if competitive_enabled:
try:
sig_count = await _trigger_signal_propagation(
pool, ticker, competitive_config,
)
if sig_count > 0:
logger.info(
"Propagated %d competitive signals for %s",
sig_count, ticker,
)
except Exception:
logger.exception(
"Signal propagation failed for %s"
"continuing with company+macro signals only",
ticker,
)
# Enqueue recommendation job for each window that produced a trend
for summary in summaries:
if summary.trend_strength > 0:
+414
View File
@@ -0,0 +1,414 @@
"""Historical pattern mining for competitive intelligence.
Queries document_impact_records joined with trend_windows to find how
similar catalyst types resolved historically for a company or its
competitors. Produces HistoricalPattern objects consumed by the signal
propagation engine and the aggregation worker.
Requirements: 3.1, 3.2, 3.3, 3.4, 3.5, 11.1, 11.2, 11.3, 11.5
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional
import asyncpg
from services.shared.config import CompetitiveConfig
from services.shared.schemas import MAJOR_DECISION_CATALYSTS
logger = logging.getLogger(__name__)
DEFAULT_HORIZONS = ["1d", "7d", "30d"]
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
@dataclass
class HistoricalPattern:
    """Aggregate view of how one catalyst type played out in the past.

    ``source_ticker`` is the company the catalyst documents were about and
    ``target_ticker`` the company whose trend windows were examined.
    """

    source_ticker: str
    target_ticker: str
    catalyst_type: str
    # Trend window the outcomes were measured on: 1d | 7d | 30d.
    time_horizon: str
    # Number of distinct historical occurrences behind the statistics.
    sample_count: int
    # Share of occurrences followed by a bullish trend, in [0, 1].
    bullish_pct: float
    # Share of occurrences followed by a bearish trend, in [0, 1].
    bearish_pct: float
    # Mean trend strength, in [0, 1].
    avg_strength: float
    # Mean delay between publication and the trend window, in days.
    avg_time_to_resolution: float
    # Confidence in the pattern, in [0, 1].
    pattern_confidence: float
    # Earliest and latest publication time among the samples.
    data_start: datetime
    data_end: datetime
    # major_corporate_decision | routine_signal
    tier: str
    # True when sample_count is below the configured minimum.
    insufficient_data: bool
# ---------------------------------------------------------------------------
# Catalyst tier classification (Req 11.1)
# ---------------------------------------------------------------------------
def classify_catalyst_tier(catalyst_type: str) -> str:
    """Map a catalyst type to its tier.

    Membership in MAJOR_DECISION_CATALYSTS yields
    ``"major_corporate_decision"``; every other value yields
    ``"routine_signal"``.
    """
    is_major = catalyst_type in MAJOR_DECISION_CATALYSTS
    return "major_corporate_decision" if is_major else "routine_signal"
# ---------------------------------------------------------------------------
# Pattern confidence (Req 3.3, 11.2)
# ---------------------------------------------------------------------------
def compute_pattern_confidence(
    sample_count: int,
    outcome_consistency: float,
    data_recency_days: float,
    tier: str,
    config: Optional[CompetitiveConfig] = None,
) -> float:
    """Score how much a historical pattern can be trusted, in [0, 1].

    Base formula:
        sample_factor * 0.4 + consistency * 0.4 + recency_factor * 0.2

    where sample_factor saturates at 20 samples and recency_factor steps
    through 1.0 / 0.7 / 0.4 as the data ages past the configured "recent" and
    "window" thresholds. The following adjustments are then applied in order:

    1. ``major_corporate_decision`` patterns are multiplied by the configured
       major-decision weight multiplier.
    2. The value is clamped to [0, 1].
    3. With fewer than the configured minimum samples it is capped at 0.25.
    4. Data older than the staleness window is multiplied by the configured
       staleness decay penalty.

    A default CompetitiveConfig is constructed when ``config`` is omitted.
    """
    cfg = config or CompetitiveConfig()

    sample_factor = min(sample_count / 20.0, 1.0)

    # outcome_consistency is used as-is (callers pass the dominant share).
    if data_recency_days <= cfg.staleness_recent_days:
        recency_factor = 1.0
    else:
        recency_factor = 0.7 if data_recency_days <= cfg.staleness_window_days else 0.4

    score = sample_factor * 0.4 + outcome_consistency * 0.4 + recency_factor * 0.2

    # Major-decision multiplier (Req 11.2), applied before clamping.
    if tier == "major_corporate_decision":
        score *= cfg.major_decision_weight_multiplier
    score = min(max(score, 0.0), 1.0)

    # Insufficient data cap (Req 3.4)
    if sample_count < cfg.min_pattern_samples:
        score = min(score, 0.25)

    # Staleness decay (Req 9.2)
    if data_recency_days > cfg.staleness_window_days:
        score *= cfg.staleness_decay_penalty

    return score
# ---------------------------------------------------------------------------
# Lookback helper
# ---------------------------------------------------------------------------
def _lookback_days(tier: str, config: Optional[CompetitiveConfig] = None) -> int:
    """Pick the history window (in days) that applies to a catalyst tier.

    ``major_corporate_decision`` uses the config's major-decision lookback;
    any other tier uses the routine lookback. A default CompetitiveConfig is
    constructed when ``config`` is omitted.
    """
    cfg = config or CompetitiveConfig()
    return (
        cfg.major_decision_lookback_days
        if tier == "major_corporate_decision"
        else cfg.routine_lookback_days
    )
# ---------------------------------------------------------------------------
# SQL: self-company pattern query
# ---------------------------------------------------------------------------
# Parameters: $1 ticker, $2 catalyst_type, $3/$4 published_at range (inclusive),
# $5 trend window, $6 interval after publication in which a trend window counts.
# A document can join to several trend windows. Rows are ordered newest
# document first and, within a document, earliest trend window first, so a
# consumer that keeps the first row per dir_id deterministically gets the
# window closest to publication (previously that order was unspecified).
_SELF_PATTERN_QUERY = """
WITH matched_docs AS (
SELECT
dir.id AS dir_id,
d.published_at,
dir.sentiment
FROM document_impact_records dir
JOIN document_intelligence di ON di.id = dir.intelligence_id
JOIN documents d ON d.id = di.document_id
WHERE dir.ticker = $1
AND dir.catalyst_type = $2
AND di.validation_status = 'valid'
AND d.status != 'rejected'
AND d.published_at >= $3
AND d.published_at <= $4
)
SELECT
md.dir_id,
md.published_at,
md.sentiment,
tw.trend_direction,
tw.trend_strength,
tw.generated_at,
tw."window" AS tw_window
FROM matched_docs md
JOIN trend_windows tw
ON tw.entity_type = 'company'
AND tw.entity_id = $1
AND tw."window" = $5
AND tw.generated_at >= md.published_at
AND tw.generated_at <= md.published_at + $6::interval
ORDER BY md.published_at DESC, tw.generated_at ASC
"""
# ---------------------------------------------------------------------------
# SQL: cross-company pattern query
# ---------------------------------------------------------------------------
# Parameters: $1 source ticker, $2 catalyst_type, $3/$4 published_at range
# (inclusive), $5 target entity id, $6 trend window, $7 interval after
# publication in which a trend window counts.
# A document can join to several trend windows. Rows are ordered newest
# document first and, within a document, earliest trend window first, so a
# consumer that keeps the first row per dir_id deterministically gets the
# window closest to publication (previously that order was unspecified).
_CROSS_PATTERN_QUERY = """
WITH matched_docs AS (
SELECT
dir.id AS dir_id,
d.published_at,
dir.sentiment
FROM document_impact_records dir
JOIN document_intelligence di ON di.id = dir.intelligence_id
JOIN documents d ON d.id = di.document_id
WHERE dir.ticker = $1
AND dir.catalyst_type = $2
AND di.validation_status = 'valid'
AND d.status != 'rejected'
AND d.published_at >= $3
AND d.published_at <= $4
)
SELECT
md.dir_id,
md.published_at,
md.sentiment,
tw.trend_direction,
tw.trend_strength,
tw.generated_at,
tw."window" AS tw_window
FROM matched_docs md
JOIN trend_windows tw
ON tw.entity_type = 'company'
AND tw.entity_id = $5
AND tw."window" = $6
AND tw.generated_at >= md.published_at
AND tw.generated_at <= md.published_at + $7::interval
ORDER BY md.published_at DESC, tw.generated_at ASC
"""
# ---------------------------------------------------------------------------
# Horizon → interval mapping
# ---------------------------------------------------------------------------
# Maps a horizon label ("1d" | "7d" | "30d") to a textual interval.
# NOTE(review): presumably these strings are bound to the `::interval`
# parameters of the pattern queries — confirm against the caller. asyncpg's
# interval codec normally expects a datetime.timedelta argument, so verify
# that a plain str is accepted there.
_HORIZON_INTERVALS: dict[str, str] = {
"1d": "1 day",
"7d": "7 days",
"30d": "30 days",
}
# ---------------------------------------------------------------------------
# Build HistoricalPattern from query rows
# ---------------------------------------------------------------------------
def _build_pattern(
rows: list[asyncpg.Record],
source_ticker: str,
target_ticker: str,
catalyst_type: str,
horizon: str,
tier: str,
config: Optional[CompetitiveConfig] = None,
) -> Optional[HistoricalPattern]:
"""Aggregate query rows into a single HistoricalPattern."""
if not rows:
return None
# De-duplicate by dir_id — keep the first (closest) trend_window per doc
seen: set[str] = set()
unique_rows: list[asyncpg.Record] = []
for r in rows:
rid = str(r["dir_id"])
if rid not in seen:
seen.add(rid)
unique_rows.append(r)
sample_count = len(unique_rows)
bullish = sum(1 for r in unique_rows if r["trend_direction"] == "bullish")
bearish = sum(1 for r in unique_rows if r["trend_direction"] == "bearish")
bullish_pct = bullish / sample_count
bearish_pct = bearish / sample_count
strengths = [float(r["trend_strength"]) for r in unique_rows if r["trend_strength"] is not None]
avg_strength = sum(strengths) / len(strengths) if strengths else 0.0
# avg_time_to_resolution: average days between published_at and generated_at
resolutions: list[float] = []
for r in unique_rows:
pub = r["published_at"]
gen = r["generated_at"]
if pub and gen:
delta = (gen - pub).total_seconds() / 86400.0
resolutions.append(max(delta, 0.0))
avg_time_to_resolution = sum(resolutions) / len(resolutions) if resolutions else 0.0
# Date range
published_dates = [r["published_at"] for r in unique_rows if r["published_at"] is not None]
data_start = min(published_dates)
data_end = max(published_dates)
# Recency: days since the most recent data point
now = datetime.now(timezone.utc)
data_recency_days = (now - data_end).total_seconds() / 86400.0 if data_end else 999.0
outcome_consistency = max(bullish_pct, bearish_pct)
confidence = compute_pattern_confidence(
sample_count, outcome_consistency, data_recency_days, tier, config,
)
insufficient_data = sample_count < (config or CompetitiveConfig()).min_pattern_samples
return HistoricalPattern(
source_ticker=source_ticker,
target_ticker=target_ticker,
catalyst_type=catalyst_type,
time_horizon=horizon,
sample_count=sample_count,
bullish_pct=bullish_pct,
bearish_pct=bearish_pct,
avg_strength=min(max(avg_strength, 0.0), 1.0),
avg_time_to_resolution=avg_time_to_resolution,
pattern_confidence=confidence,
data_start=data_start,
data_end=data_end,
tier=tier,
insufficient_data=insufficient_data,
)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def find_self_patterns(
    pool: asyncpg.Pool,
    ticker: str,
    catalyst_type: str,
    horizons: Optional[list[str]] = None,
    config: Optional[CompetitiveConfig] = None,
) -> list[HistoricalPattern]:
    """Collect same-company historical patterns for a catalyst type.

    For every requested horizon the self-pattern query is run for *ticker*
    and *catalyst_type* over the tier-specific lookback period, and the
    resulting rows are condensed into one HistoricalPattern.  Unknown
    horizons are skipped with a warning; a failing query is logged and that
    horizon is skipped.

    Requirements: 3.1, 3.2, 3.5, 11.3
    """
    settings = config or CompetitiveConfig()
    requested = horizons or DEFAULT_HORIZONS
    tier = classify_catalyst_tier(catalyst_type)
    lookback_days = _lookback_days(tier, settings)
    window_end = datetime.now(timezone.utc)
    window_start = window_end - timedelta(days=lookback_days)

    found: list[HistoricalPattern] = []
    async with pool.acquire() as conn:
        for horizon in requested:
            interval = _HORIZON_INTERVALS.get(horizon)
            if interval is None:
                logger.warning("Unknown horizon %s, skipping", horizon)
                continue

            # Positional arguments $1..$6 of the self-pattern query.
            params = (
                ticker,
                catalyst_type,
                window_start,
                window_end,
                horizon,
                interval,
            )
            try:
                rows = await conn.fetch(_SELF_PATTERN_QUERY, *params)
            except Exception:
                logger.exception(
                    "Error querying self-patterns for %s/%s/%s",
                    ticker, catalyst_type, horizon,
                )
                continue

            # Source and target are the same company for self-patterns.
            built = _build_pattern(
                rows, ticker, ticker, catalyst_type, horizon, tier, settings,
            )
            if built is not None:
                found.append(built)
    return found
async def find_cross_company_patterns(
    pool: asyncpg.Pool,
    source_ticker: str,
    target_ticker: str,
    catalyst_type: str,
    horizons: Optional[list[str]] = None,
    config: Optional[CompetitiveConfig] = None,
) -> list[HistoricalPattern]:
    """Find cross-company historical patterns.

    Queries documents about *source_ticker* with the given catalyst_type,
    then looks at trend_windows for *target_ticker* within each horizon
    after the document was published.

    Args:
        pool: asyncpg pool; one connection is held for all horizons.
        source_ticker: Ticker the catalyst documents are about.
        target_ticker: Entity id whose trend windows are inspected.
        catalyst_type: Catalyst type to match.
        horizons: Horizon labels to evaluate; defaults to DEFAULT_HORIZONS.
        config: Optional overrides; defaults to ``CompetitiveConfig()``.

    Returns:
        One pattern per horizon that produced rows.  Unknown horizons are
        skipped with a warning; a failing query is logged and skipped.

    Requirements: 4.2, 11.5
    """
    cfg = config or CompetitiveConfig()
    horizons = horizons or DEFAULT_HORIZONS
    tier = classify_catalyst_tier(catalyst_type)
    lookback = _lookback_days(tier, cfg)
    now = datetime.now(timezone.utc)
    cutoff = now - timedelta(days=lookback)

    patterns: list[HistoricalPattern] = []
    async with pool.acquire() as conn:
        for horizon in horizons:
            interval = _HORIZON_INTERVALS.get(horizon)
            if interval is None:
                logger.warning("Unknown horizon %s, skipping", horizon)
                continue
            try:
                rows = await conn.fetch(
                    _CROSS_PATTERN_QUERY,
                    source_ticker,  # $1
                    catalyst_type,  # $2
                    cutoff,         # $3
                    now,            # $4
                    target_ticker,  # $5
                    horizon,        # $6
                    interval,       # $7
                )
            except Exception:
                # Best-effort per horizon: log and move on.  The source and
                # target tickers are separated by "->" so they stay readable
                # in the log line instead of running together.
                logger.exception(
                    "Error querying cross-patterns for %s->%s/%s/%s",
                    source_ticker, target_ticker, catalyst_type, horizon,
                )
                continue
            pattern = _build_pattern(
                rows, source_ticker, target_ticker, catalyst_type,
                horizon, tier, cfg,
            )
            if pattern is not None:
                patterns.append(pattern)
    return patterns
+416
View File
@@ -0,0 +1,416 @@
"""Trend projection module — forward-looking trend estimates.
Computes TrendProjection objects by combining current trend momentum,
macro signal decay trajectories, and upcoming catalyst outlook.
Projections are persisted alongside trend_window records.
Requirements: 12.1, 12.2, 12.3, 12.4, 12.5, 12.9
"""
from __future__ import annotations
import json
import logging
import math
from dataclasses import dataclass, field
from datetime import datetime, timezone
import asyncpg
from services.shared.schemas import TrendDirection, TrendSummary
logger = logging.getLogger("projection")
# ---------------------------------------------------------------------------
# TrendProjection dataclass
# ---------------------------------------------------------------------------
# Direction labels used for projected directions in this module.
VALID_DIRECTIONS = {"bullish", "bearish", "mixed", "neutral"}
# Projection horizon labels used in this module.
VALID_HORIZONS = {"1d", "7d", "30d"}
# Default low-confidence threshold: compute_projection() flags a projection
# as low_confidence when its projected confidence is below this value.
DEFAULT_CONFIDENCE_THRESHOLD = 0.3
# Macro signal decay half-lives (in days) by estimated_duration.
# project_macro_decay() falls back to 7.0 for durations not listed here.
DECAY_HALF_LIFE_DAYS: dict[str, float] = {
    "short_term": 1.0,   # halve impact per day
    "medium_term": 7.0,  # halve impact per week
    "long_term": 30.0,   # halve impact per month
}
@dataclass
class TrendProjection:
    """Forward-looking trend projection for a company.

    Instances are produced by compute_projection() and written to the
    trend_projections table by persist_trend_projection().
    """

    projected_direction: str = "neutral"  # bullish|bearish|mixed|neutral
    projected_strength: float = 0.5  # [0, 1]
    projected_confidence: float = 0.5  # [0, 1]
    projection_horizon: str = "7d"  # 1d|7d|30d
    # Human-readable explanations of what drove the projection.
    driving_factors: list[str] = field(default_factory=list)
    macro_contribution_pct: float = 0.0  # [0, 1]
    # True when projected_direction differs from the current trend direction.
    diverges_from_current: bool = False
    # Timezone-aware UTC timestamp; defaults to the moment of construction.
    computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    # True when projected_confidence is below the caller's threshold.
    # Not part of the column list in _INSERT_PROJECTION.
    low_confidence: bool = False
# ---------------------------------------------------------------------------
# Macro impact row type (lightweight, avoids circular import with worker)
# ---------------------------------------------------------------------------
@dataclass
class MacroEventInfo:
    """Minimal macro event info needed for projection computation.

    project_macro_decay() reads macro_impact_score, impact_direction,
    estimated_duration, severity and event_age_hours; event_id and
    confidence are carried along but not read there.
    """

    event_id: str = ""
    macro_impact_score: float = 0.0
    # "positive" / "negative" are weighted one-sided by project_macro_decay();
    # any other value is split evenly between both sides.
    impact_direction: str = "neutral"
    confidence: float = 0.5
    # Key into DECAY_HALF_LIFE_DAYS (short_term|medium_term|long_term).
    estimated_duration: str = "short_term"
    # Key into _SEVERITY_WEIGHT (critical|high|moderate|low).
    severity: str = "low"
    event_age_hours: float = 0.0  # hours since event publication
# ---------------------------------------------------------------------------
# Projection horizon mapping from trend window
# ---------------------------------------------------------------------------
# Maps a trend window label to the projection horizon used for it.
# "intraday" shares the 1d horizon and "90d" is mapped to 30d, the longest
# horizon listed here.  compute_projection() falls back to "7d" for window
# labels missing from this table.
_WINDOW_TO_HORIZON: dict[str, str] = {
    "intraday": "1d",
    "1d": "1d",
    "7d": "7d",
    "30d": "30d",
    "90d": "30d",
}
# ---------------------------------------------------------------------------
# Momentum computation
# ---------------------------------------------------------------------------
def compute_trend_momentum(
    current_strength: float,
    current_direction: str,
    previous_strength: float | None = None,
    previous_direction: str | None = None,
) -> float:
    """Return the change in signed trend strength, rounded to 6 decimals.

    Signed strength is the strength multiplied by the direction sign
    (+1 bullish, -1 bearish, 0 otherwise).

    - With both previous values given, the result is current signed
      strength minus previous signed strength, clamped to [-1, 1].
      Positive means strengthening bullish or weakening bearish; negative
      means the opposite.
    - When either previous value is missing, the result is a heuristic:
      half of the current signed strength.
    """
    signed_now = _direction_sign(current_direction) * current_strength

    has_history = previous_strength is not None and previous_direction is not None
    if not has_history:
        # Heuristic: assume momentum proportional to current signed strength
        return round(signed_now * 0.5, 6)

    signed_before = _direction_sign(previous_direction) * previous_strength
    change = signed_now - signed_before
    clamped = max(-1.0, min(1.0, change))
    return round(clamped, 6)
def _direction_sign(direction: str) -> float:
    """Return +1.0 for "bullish", -1.0 for "bearish" and 0.0 otherwise."""
    if direction == "bullish":
        return 1.0
    return -1.0 if direction == "bearish" else 0.0
# ---------------------------------------------------------------------------
# Macro signal decay projection
# ---------------------------------------------------------------------------
# Multiplier applied to an event's projected impact by severity label.
# project_macro_decay() falls back to 0.25 for labels not listed here.
_SEVERITY_WEIGHT: dict[str, float] = {
    "critical": 1.0,
    "high": 0.75,
    "moderate": 0.5,
    "low": 0.25,
}
def project_macro_decay(
    events: list[MacroEventInfo],
    horizon_days: float,
) -> tuple[float, str]:
    """Project the aggregate macro signal after decay over the horizon.

    For each event the impact remaining at the end of the horizon is
    ``macro_impact_score * 2 ** (-(age_days + horizon_days) / half_life)
    * severity_weight``, where the half-life comes from
    DECAY_HALF_LIFE_DAYS by estimated_duration:
    - short_term: half-life = 1 day
    - medium_term: half-life = 7 days (also the fallback)
    - long_term: half-life = 30 days

    Impacts are summed per side: "positive" events feed the bullish side,
    "negative" events the bearish side, anything else is split evenly.

    Args:
        events: Active macro events.
        horizon_days: Length of the projection horizon in days.

    Returns:
        (projected_macro_strength, projected_macro_direction)
        where strength is the combined weight capped at 1.0 and rounded to
        6 decimals, and direction is bullish|bearish|mixed|neutral.  A side
        must exceed the other by more than 20% to win; otherwise the result
        is "mixed" when both sides are positive.
    """
    if not events:
        return 0.0, "neutral"

    positive_weight = 0.0
    negative_weight = 0.0
    for ev in events:
        half_life = DECAY_HALF_LIFE_DAYS.get(ev.estimated_duration, 7.0)
        # Age of the event at the end of the horizon, in days.
        future_age_days = ev.event_age_hours / 24.0 + horizon_days
        # Fraction of the impact left at that age; a non-positive half-life
        # is treated as fully decayed.
        if half_life > 0:
            future_factor = math.pow(2.0, -future_age_days / half_life)
        else:
            future_factor = 0.0

        severity_w = _SEVERITY_WEIGHT.get(ev.severity, 0.25)
        projected_impact = ev.macro_impact_score * future_factor * severity_w

        if ev.impact_direction == "positive":
            positive_weight += projected_impact
        elif ev.impact_direction == "negative":
            negative_weight += projected_impact
        else:
            # mixed/neutral: split evenly
            positive_weight += projected_impact * 0.5
            negative_weight += projected_impact * 0.5

    total = positive_weight + negative_weight
    if total == 0.0:
        return 0.0, "neutral"

    strength = min(total, 1.0)
    if positive_weight > negative_weight * 1.2:
        direction = "bullish"
    elif negative_weight > positive_weight * 1.2:
        direction = "bearish"
    elif positive_weight > 0 and negative_weight > 0:
        direction = "mixed"
    else:
        direction = "neutral"
    return round(strength, 6), direction
# ---------------------------------------------------------------------------
# Horizon days mapping
# ---------------------------------------------------------------------------
# Length of each projection horizon in days.  compute_projection() passes
# the value to project_macro_decay() and falls back to 7.0 for labels that
# are not listed here.
_HORIZON_DAYS: dict[str, float] = {
    "1d": 1.0,
    "7d": 7.0,
    "30d": 30.0,
}
# ---------------------------------------------------------------------------
# Core projection computation
# ---------------------------------------------------------------------------
def compute_projection(
    summary: TrendSummary,
    macro_events: list[MacroEventInfo] | None = None,
    macro_enabled: bool = True,
    confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
    previous_strength: float | None = None,
    previous_direction: str | None = None,
    upcoming_catalysts: list[str] | None = None,
) -> TrendProjection:
    """Compute a forward-looking trend projection.

    Combines:
    1. Trend momentum (rate of change in strength)
    2. Macro signal decay projection
    3. Upcoming catalyst outlook
    4. Current trend baseline

    Args:
        summary: The current trend summary.  Its trend_direction, window,
            trend_strength and confidence are read.
        macro_events: Active macro events with their info.
        macro_enabled: Whether the macro layer is enabled.
        confidence_threshold: Below this, mark as low_confidence.
        previous_strength: Previous window's trend strength (optional).
        previous_direction: Previous window's trend direction (optional).
        upcoming_catalysts: Known upcoming catalysts from doc intelligence.

    Returns:
        A TrendProjection with projected direction, strength, and confidence.
        driving_factors always holds at least one entry, and a divergence
        note is appended when the projected direction differs from the
        current one.
    """
    now = datetime.now(timezone.utc)
    current_dir = summary.trend_direction.value
    current_strength = summary.trend_strength
    current_confidence = summary.confidence
    # Unknown windows / horizons fall back to a 7-day projection.
    horizon = _WINDOW_TO_HORIZON.get(summary.window.value, "7d")
    horizon_days = _HORIZON_DAYS.get(horizon, 7.0)
    driving_factors: list[str] = []

    # 1. Compute trend momentum
    momentum = compute_trend_momentum(
        current_strength, current_dir,
        previous_strength, previous_direction,
    )
    # Only momentum larger than 0.05 in magnitude is reported as a factor.
    if abs(momentum) > 0.05:
        if momentum > 0:
            driving_factors.append(f"Positive momentum ({momentum:+.3f}) in recent trend strength")
        else:
            driving_factors.append(f"Negative momentum ({momentum:+.3f}) in recent trend strength")

    # 2. Project macro signal decay
    macro_strength = 0.0
    macro_direction = "neutral"
    macro_contribution = 0.0
    if macro_enabled and macro_events:
        macro_strength, macro_direction = project_macro_decay(macro_events, horizon_days)
        if macro_strength > 0:
            driving_factors.append(
                f"Macro signals project {macro_direction} impact "
                f"(strength {macro_strength:.3f}) over {horizon}"
            )

    # 3. Factor in upcoming catalysts
    catalysts = upcoming_catalysts or []
    for catalyst in catalysts[:3]:  # limit to top 3
        driving_factors.append(f"Upcoming catalyst: {catalyst}")
    # The boost counts every catalyst (not just the three listed above):
    # 0.02 each, capped at 0.1.
    catalyst_boost = min(len(catalysts) * 0.02, 0.1)  # small boost per catalyst

    # 4. Combine into projected direction/strength/confidence
    # Momentum-based projection of company-specific trend: the current signed
    # strength plus half of the momentum.
    momentum_projected_signed = _direction_sign(current_dir) * current_strength + momentum * 0.5
    momentum_projected_strength = min(abs(momentum_projected_signed), 1.0)
    if macro_enabled and macro_strength > 0:
        # Blend company momentum with macro trajectory; the macro share
        # scales with macro strength and never exceeds 0.4.
        macro_weight = min(macro_strength * 0.4, 0.4)
        company_weight = 1.0 - macro_weight
        macro_signed = _direction_sign(macro_direction) * macro_strength
        blended_signed = (
            company_weight * momentum_projected_signed
            + macro_weight * macro_signed
        )
        projected_strength = round(min(abs(blended_signed) + catalyst_boost, 1.0), 6)
        macro_contribution = round(macro_weight, 6)
        # Determine projected direction from blended signal
        projected_direction = _signed_to_direction(blended_signed)
    else:
        # Company-only projection
        projected_strength = round(min(momentum_projected_strength + catalyst_boost, 1.0), 6)
        projected_direction = _signed_to_direction(momentum_projected_signed)

    # Compute projected confidence
    base_confidence = current_confidence * 0.8  # projection inherently less certain
    if macro_enabled and macro_strength > 0:
        # Macro data adds information → slight confidence boost (at most 0.1)
        macro_conf_boost = min(macro_strength * 0.15, 0.1)
        projected_confidence = round(min(base_confidence + macro_conf_boost, 1.0), 6)
    else:
        # Without macro data, reduce confidence further — but only when the
        # macro layer is disabled; enabled-with-no-signal keeps the base.
        if not macro_enabled:
            projected_confidence = round(base_confidence * 0.85, 6)
        else:
            projected_confidence = round(base_confidence, 6)

    # Ensure driving_factors is never empty
    if not driving_factors:
        driving_factors.append(f"Baseline trend continuation: {current_dir} at strength {current_strength:.3f}")

    # 5. Flag divergence (plain label comparison against the current trend)
    diverges = projected_direction != current_dir
    if diverges:
        driving_factors.append(
            f"DIVERGENCE: Current trend is {current_dir}, "
            f"projection is {projected_direction}"
        )

    # Mark low confidence
    is_low_confidence = projected_confidence < confidence_threshold

    return TrendProjection(
        projected_direction=projected_direction,
        projected_strength=projected_strength,
        projected_confidence=projected_confidence,
        projection_horizon=horizon,
        driving_factors=driving_factors,
        macro_contribution_pct=macro_contribution,
        diverges_from_current=diverges,
        computed_at=now,
        low_confidence=is_low_confidence,
    )
def _signed_to_direction(signed_value: float) -> str:
    """Map a signed strength to a direction label.

    Values above 0.1 are "bullish" and below -0.1 "bearish".  Inside that
    band, magnitudes above 0.02 are "mixed" and the rest "neutral".
    """
    if signed_value > 0.1:
        return "bullish"
    if signed_value < -0.1:
        return "bearish"
    return "mixed" if abs(signed_value) > 0.02 else "neutral"
# ---------------------------------------------------------------------------
# PostgreSQL persistence
# ---------------------------------------------------------------------------
# Inserts one row into trend_projections and returns its id.
# Parameters, in the order bound by persist_trend_projection():
#   $1 trend_window_id (cast to uuid)   $2 projected_direction
#   $3 projected_strength               $4 projected_confidence
#   $5 projection_horizon               $6 driving_factors (JSON text, cast to jsonb)
#   $7 macro_contribution_pct           $8 diverges_from_current
#   $9 computed_at
_INSERT_PROJECTION = """
INSERT INTO trend_projections (
trend_window_id, projected_direction, projected_strength,
projected_confidence, projection_horizon, driving_factors,
macro_contribution_pct, diverges_from_current, computed_at
) VALUES (
$1::uuid, $2, $3, $4, $5, $6::jsonb, $7, $8, $9
)
RETURNING id
"""
async def persist_trend_projection(
    pool: asyncpg.Pool,
    trend_window_id: str,
    projection: TrendProjection,
) -> str:
    """Insert *projection* into trend_projections for the given trend window.

    The driving factors are serialised to JSON text before binding.

    Returns:
        The id of the inserted row, converted to ``str``.
    """
    # Positional arguments $1..$9 of _INSERT_PROJECTION.
    insert_args = (
        trend_window_id,
        projection.projected_direction,
        projection.projected_strength,
        projection.projected_confidence,
        projection.projection_horizon,
        json.dumps(projection.driving_factors),
        projection.macro_contribution_pct,
        projection.diverges_from_current,
        projection.computed_at,
    )
    inserted_id = await pool.fetchval(_INSERT_PROJECTION, *insert_args)

    logger.info(
        "Persisted trend projection for window=%s: direction=%s strength=%.3f confidence=%.3f diverges=%s",
        trend_window_id,
        projection.projected_direction,
        projection.projected_strength,
        projection.projected_confidence,
        projection.diverges_from_current,
    )
    return str(inserted_id)
+226 -13
View File
@@ -4,13 +4,13 @@ Aggregates company-level trend summaries into sector and market-level
summaries, enabling top-down views of sentiment and risk across the
portfolio.
Requirements: 6.3, 6.4, 6.5
Requirements: 6.1, 6.2, 6.3, 6.4, 6.5
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
import asyncpg
@@ -42,6 +42,126 @@ class CompanyTrendRow:
top_opposing_evidence: list[str]
@dataclass
class SectorMacroImpact:
    """Aggregated macro impact data for a single sector.

    Used to incorporate macro signals into sector and market rollups.

    Requirements: 6.1, 6.2, 6.3
    """

    sector: str
    total_impact: float  # sum of macro_impact_score across impact records in sector
    avg_impact: float  # average macro_impact_score per impact record
    company_count: int  # number of distinct companies affected
    net_direction: float  # mean of +1 (positive) / -1 (negative) records; 0 if none
    event_ids: list[str] = field(default_factory=list)  # contributing event IDs


# Threshold for disproportionate sector impact (Requirement 6.3)
SECTOR_CONCENTRATION_THRESHOLD = 0.60


# ---------------------------------------------------------------------------
# Fetch sector-level macro impact aggregates
# ---------------------------------------------------------------------------
# One row per macro impact record of an active company whose computed_at lies
# in [$1, $2].  company_id is selected so that companies hit by several events
# can be counted once.
_SECTOR_MACRO_IMPACT_QUERY = """
SELECT
c.sector,
mir.company_id,
mir.event_id,
mir.macro_impact_score,
mir.impact_direction
FROM macro_impact_records mir
JOIN companies c ON c.id = mir.company_id AND c.active = TRUE
WHERE mir.computed_at >= $1
AND mir.computed_at <= $2
ORDER BY c.sector, mir.macro_impact_score DESC
"""


async def fetch_sector_macro_impacts(
    pool: asyncpg.Pool,
    window_start: datetime,
    window_end: datetime,
) -> dict[str, SectorMacroImpact]:
    """Fetch macro impact records aggregated by sector for a time range.

    Args:
        pool: Object exposing an awaitable ``fetch(query, *args)``.
        window_start: Inclusive lower bound on ``computed_at``.
        window_end: Inclusive upper bound on ``computed_at``.

    Returns:
        A mapping of sector name to SectorMacroImpact.  Rows without a sector
        are grouped under "Unknown".  ``company_count`` is the number of
        distinct companies in the sector, so a company affected by several
        events is counted once; ``avg_impact`` stays a per-record average.
    """
    rows = await pool.fetch(_SECTOR_MACRO_IMPACT_QUERY, window_start, window_end)

    direction_map = {"positive": 1.0, "negative": -1.0, "mixed": 0.0, "neutral": 0.0}

    # Accumulate per-sector
    sector_data: dict[str, dict] = {}
    for row in rows:
        sector = str(row["sector"]) if row["sector"] else "Unknown"
        score = float(row["macro_impact_score"] or 0.0)
        direction = row["impact_direction"] or "neutral"

        if sector not in sector_data:
            sector_data[sector] = {
                "total": 0.0,
                "records": 0,
                "companies": set(),
                "dir_sum": 0.0,
                "dir_count": 0,
                "event_ids": set(),
            }
        acc = sector_data[sector]
        acc["total"] += score
        acc["records"] += 1
        acc["companies"].add(str(row["company_id"]))
        # Only clearly positive/negative records take part in the net
        # direction; mixed/neutral ones are left out of the average.
        dir_val = direction_map.get(direction, 0.0)
        if dir_val != 0.0:
            acc["dir_sum"] += dir_val
            acc["dir_count"] += 1
        acc["event_ids"].add(str(row["event_id"]))

    result: dict[str, SectorMacroImpact] = {}
    for sector, acc in sector_data.items():
        records = acc["records"]
        avg = acc["total"] / records if records > 0 else 0.0
        net_dir = acc["dir_sum"] / acc["dir_count"] if acc["dir_count"] > 0 else 0.0
        result[sector] = SectorMacroImpact(
            sector=sector,
            total_impact=acc["total"],
            avg_impact=avg,
            company_count=len(acc["companies"]),
            net_direction=net_dir,
            event_ids=sorted(acc["event_ids"]),
        )
    return result
# ---------------------------------------------------------------------------
# Sector macro concentration helper (Requirement 6.3)
# ---------------------------------------------------------------------------
def compute_sector_macro_concentration(
    sector_impacts: dict[str, SectorMacroImpact],
) -> list[tuple[str, float]]:
    """Return each sector's share of the summed macro impact.

    The result is a list of (sector, share) pairs ordered from the largest
    share to the smallest; an empty list is returned when the summed impact
    is zero or negative.  A share above SECTOR_CONCENTRATION_THRESHOLD marks
    a sector as disproportionately affected.
    """
    overall = sum(impact.total_impact for impact in sector_impacts.values())
    if overall <= 0.0:
        return []

    shares = [
        (name, impact.total_impact / overall)
        for name, impact in sector_impacts.items()
    ]
    return sorted(shares, key=lambda pair: pair[1], reverse=True)
# ---------------------------------------------------------------------------
# Fetch latest company trends for a given window
# ---------------------------------------------------------------------------
@@ -141,11 +261,22 @@ def rollup_trends(
entity_id: str,
window: str,
reference_time: datetime,
macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> TrendSummary:
"""Aggregate a list of company-level trends into a single rollup summary.
Each company trend is weighted by its confidence to produce a
confidence-weighted average of direction, strength, and contradiction.
When macro_impacts is provided:
- For sector rollups: incorporates the sector's macro signal into
strength and confidence, weighted by constituent company exposure.
- For market rollups: aggregates macro signals across all sectors and
surfaces disproportionately affected sectors (>60% concentration)
in material_risks or dominant_catalysts.
When macro_impacts is None or empty, produces identical output to
the original company-only rollup.
"""
if not trends:
return TrendSummary(
@@ -204,16 +335,70 @@ def rollup_trends(
avg_contradiction = weighted_contradiction / total_weight
avg_confidence = total_weight / len(trends)
# --- Incorporate macro impact signals when available ---
macro_strength_adj = 0.0
macro_confidence_adj = 0.0
macro_catalysts: list[str] = []
macro_risks: list[str] = []
if macro_impacts:
if entity_type == "sector":
# Sector rollup: incorporate this sector's macro signal
sector_macro = macro_impacts.get(entity_id)
if sector_macro and sector_macro.total_impact > 0:
# Weight macro contribution by avg impact and company breadth
breadth = min(sector_macro.company_count / max(len(trends), 1), 1.0)
macro_strength_adj = sector_macro.avg_impact * breadth * 0.3
macro_confidence_adj = sector_macro.avg_impact * breadth * 0.1
# Nudge direction based on macro net direction
avg_direction += sector_macro.net_direction * macro_strength_adj * 0.5
elif entity_type == "market":
# Market rollup: aggregate macro signals across all sectors
total_macro = sum(si.total_impact for si in macro_impacts.values())
if total_macro > 0:
total_companies = sum(si.company_count for si in macro_impacts.values())
breadth = min(total_companies / max(len(trends), 1), 1.0)
avg_macro = total_macro / max(len(macro_impacts), 1)
macro_strength_adj = avg_macro * breadth * 0.3
macro_confidence_adj = avg_macro * breadth * 0.1
# Aggregate net direction across sectors
dir_sum = sum(
si.net_direction * si.total_impact
for si in macro_impacts.values()
)
net_dir = dir_sum / total_macro if total_macro > 0 else 0.0
avg_direction += net_dir * macro_strength_adj * 0.5
# Surface disproportionately affected sectors (Requirement 6.3)
concentration = compute_sector_macro_concentration(macro_impacts)
for sector, fraction in concentration:
if fraction > SECTOR_CONCENTRATION_THRESHOLD:
si = macro_impacts[sector]
label = f"Macro: {sector} ({fraction:.0%} of macro impact)"
if si.net_direction < 0:
macro_risks.append(label)
else:
macro_catalysts.append(label)
# Apply macro adjustments to strength and confidence
adj_strength = avg_strength + macro_strength_adj
adj_confidence = avg_confidence + macro_confidence_adj
# Derive direction
direction = _derive_rollup_direction(avg_direction, avg_contradiction)
# Top catalysts
# Top catalysts (macro catalysts prepended when present)
sorted_catalysts = sorted(catalyst_weights.items(), key=lambda x: x[1], reverse=True)
catalysts = [c for c, _ in sorted_catalysts[:5]]
catalysts = macro_catalysts + [c for c, _ in sorted_catalysts[:5]]
catalysts = catalysts[:5]
# Top risks (deduplicated, by weight)
# Top risks (macro risks prepended when present, deduplicated)
sorted_risks = sorted(risk_set.items(), key=lambda x: x[1], reverse=True)
risks = [r for r, _ in sorted_risks[:5]]
base_risks = [r for r, _ in sorted_risks[:5]]
risks = macro_risks + base_risks
risks = risks[:5]
# Disagreement details
disagreement = _build_rollup_disagreement(trends, entity_id)
@@ -223,8 +408,8 @@ def rollup_trends(
entity_id=entity_id,
window=TrendWindow(window),
trend_direction=direction,
trend_strength=round(min(abs(avg_strength), 1.0), 4),
confidence=round(max(0.0, min(avg_confidence, 1.0)), 4),
trend_strength=round(min(abs(adj_strength), 1.0), 4),
confidence=round(max(0.0, min(adj_confidence, 1.0)), 4),
top_supporting_evidence=list(dict.fromkeys(all_supporting))[:10],
top_opposing_evidence=list(dict.fromkeys(all_opposing))[:10],
dominant_catalysts=catalysts,
@@ -341,11 +526,14 @@ async def aggregate_sector(
window: str,
reference_time: datetime | None = None,
since: datetime | None = None,
macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> TrendSummary:
"""Compute and persist a sector-level rollup for one window.
Fetches the latest company trends, filters to the given sector,
and rolls them up into a single sector summary.
and rolls them up into a single sector summary. When macro_impacts
is provided, incorporates macro signals weighted by constituent
company exposure.
"""
if reference_time is None:
reference_time = datetime.now(timezone.utc)
@@ -355,7 +543,14 @@ async def aggregate_sector(
all_trends = await fetch_latest_company_trends(pool, window, since)
sector_trends = [t for t in all_trends if t.sector == sector]
summary = rollup_trends(sector_trends, "sector", sector, window, reference_time)
# Fetch macro impacts if not provided
if macro_impacts is None:
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
summary = rollup_trends(
sector_trends, "sector", sector, window, reference_time,
macro_impacts=macro_impacts,
)
if sector_trends:
rollup_id = await persist_rollup(pool, summary)
@@ -373,10 +568,13 @@ async def aggregate_market(
window: str,
reference_time: datetime | None = None,
since: datetime | None = None,
macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> TrendSummary:
"""Compute and persist a market-wide rollup for one window.
Aggregates all company trends regardless of sector.
Aggregates all company trends regardless of sector. When macro_impacts
is provided, aggregates macro signals across all sectors and surfaces
disproportionately affected sectors in material_risks or dominant_catalysts.
"""
if reference_time is None:
reference_time = datetime.now(timezone.utc)
@@ -385,7 +583,14 @@ async def aggregate_market(
all_trends = await fetch_latest_company_trends(pool, window, since)
summary = rollup_trends(all_trends, "market", "all", window, reference_time)
# Fetch macro impacts if not provided
if macro_impacts is None:
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
summary = rollup_trends(
all_trends, "market", "all", window, reference_time,
macro_impacts=macro_impacts,
)
if all_trends:
rollup_id = await persist_rollup(pool, summary)
@@ -403,6 +608,7 @@ async def aggregate_all_sectors(
window: str,
reference_time: datetime | None = None,
since: datetime | None = None,
macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> list[TrendSummary]:
"""Compute sector rollups for every sector that has company trends."""
if reference_time is None:
@@ -412,6 +618,10 @@ async def aggregate_all_sectors(
all_trends = await fetch_latest_company_trends(pool, window, since)
# Fetch macro impacts once for all sectors if not provided
if macro_impacts is None:
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
# Group by sector
sectors: dict[str, list[CompanyTrendRow]] = {}
for t in all_trends:
@@ -419,7 +629,10 @@ async def aggregate_all_sectors(
summaries: list[TrendSummary] = []
for sector, trends in sectors.items():
summary = rollup_trends(trends, "sector", sector, window, reference_time)
summary = rollup_trends(
trends, "sector", sector, window, reference_time,
macro_impacts=macro_impacts,
)
if trends:
_id = await persist_rollup(pool, summary)
summaries.append(summary)
+306
View File
@@ -0,0 +1,306 @@
"""Competitive signal propagation engine.
Evaluates incoming document intelligence, identifies competitors via
the competitor_relationships table, queries historical cross-company
patterns, and produces weighted competitive signals persisted to
competitive_signal_records.
Also converts pattern and competitive signals into WeightedSignal
objects for the aggregation engine.
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 9.1
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Optional
import asyncpg
from services.aggregation.pattern_matcher import (
HistoricalPattern,
find_cross_company_patterns,
)
from services.aggregation.scoring import (
ScoringConfig,
WeightedSignal,
compute_signal_weight,
)
from services.shared.config import CompetitiveConfig
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
@dataclass
class CompetitiveSignalRecord:
    """A competitive signal produced by propagating a source event to a
    competitor based on historical cross-company patterns.

    The fields correspond, in order, to the columns written by
    _INSERT_SIGNAL_QUERY.
    """

    source_document_id: str  # document that triggered the propagation
    source_ticker: str  # company the source document is about
    target_ticker: str  # competitor receiving the signal
    catalyst_type: str
    pattern_confidence: float
    signal_direction: str  # bullish | bearish
    signal_strength: float  # [0, 1]
    relationship_strength: float  # strength of the competitor relationship
    computed_at: datetime
# ---------------------------------------------------------------------------
# SQL queries
# ---------------------------------------------------------------------------
# Returns every active competitor relationship that has $1 on either side.
# NOTE(review): propagate_signals() binds a ticker to $1 while the columns
# are named company_a_id / company_b_id — confirm these columns hold tickers.
_COMPETITOR_LOOKUP_QUERY = """
SELECT company_a_id, company_b_id, strength
FROM competitor_relationships
WHERE (company_a_id = $1 OR company_b_id = $1)
AND active = TRUE
"""
# Inserts one competitive signal row; $1..$9 follow the column list order.
_INSERT_SIGNAL_QUERY = """
INSERT INTO competitive_signal_records
(source_document_id, source_ticker, target_ticker, catalyst_type,
pattern_confidence, signal_direction, signal_strength,
relationship_strength, computed_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
"""
# ---------------------------------------------------------------------------
# propagate_signals
# ---------------------------------------------------------------------------
async def propagate_signals(
    pool: asyncpg.Pool,
    ticker: str,
    catalyst_type: str,
    impact_score: float,
    document_id: str,
    config: Optional[CompetitiveConfig] = None,
) -> list[CompetitiveSignalRecord]:
    """Propagate a source-company catalyst to its competitors.

    Looks up the active competitor relationships of ``ticker``, queries the
    historical cross-company patterns for every competitor that passes the
    relationship-strength gate, converts each sufficiently confident pattern
    into a ``CompetitiveSignalRecord`` and bulk-inserts the records.

    Database failures are logged and swallowed (best effort):

    * a failed competitor lookup returns an empty list;
    * a failed pattern query skips that competitor only;
    * a failed insert still returns the computed records, so the returned
      records are not guaranteed to have been persisted.

    Args:
        pool: asyncpg connection pool.
        ticker: Source company ticker that received the catalyst.
        catalyst_type: The catalyst type from document intelligence.
        impact_score: The source document's impact score.
        document_id: The source document ID.
        config: Optional competitive config overrides.

    Returns:
        List of CompetitiveSignalRecord objects that were produced.
    """
    cfg = config or CompetitiveConfig()
    # One timestamp for the whole run so all records share computed_at.
    now = datetime.now(timezone.utc)
    records: list[CompetitiveSignalRecord] = []

    # Step 1: Look up active competitors
    try:
        async with pool.acquire() as conn:
            rows = await conn.fetch(_COMPETITOR_LOOKUP_QUERY, ticker)
    except Exception:
        logger.exception("Failed to look up competitors for %s", ticker)
        return records

    if not rows:
        logger.debug("No active competitors found for %s", ticker)
        return records

    # Step 2: For each competitor, query cross-company patterns
    for row in rows:
        company_a = str(row["company_a_id"])
        company_b = str(row["company_b_id"])
        rel_strength = float(row["strength"])

        # Determine the competitor ticker (the other side of the relationship)
        competitor_ticker = company_b if company_a == ticker else company_a

        # A row listing the source company on both sides would otherwise
        # propagate the event back onto the company it originated from.
        if competitor_ticker == ticker:
            logger.debug(
                "Skipping self-referential competitor relationship for %s",
                ticker,
            )
            continue

        # Threshold gating (Req 4.5)
        if rel_strength < cfg.propagation_strength_threshold:
            logger.info(
                "Skipping propagation %s -> %s: relationship strength %.3f "
                "below threshold %.3f",
                ticker, competitor_ticker, rel_strength,
                cfg.propagation_strength_threshold,
            )
            continue

        # Query cross-company patterns
        try:
            patterns = await find_cross_company_patterns(
                pool, ticker, competitor_ticker, catalyst_type, config=cfg,
            )
        except Exception:
            logger.exception(
                "Failed to query cross-company patterns for %s -> %s/%s",
                ticker, competitor_ticker, catalyst_type,
            )
            continue

        for pattern in patterns:
            # Confidence threshold gating (Req 9.1)
            if pattern.pattern_confidence < cfg.pattern_confidence_threshold:
                logger.info(
                    "Excluding pattern %s -> %s/%s/%s: confidence %.3f "
                    "below threshold %.3f",
                    ticker, competitor_ticker, catalyst_type,
                    pattern.time_horizon, pattern.pattern_confidence,
                    cfg.pattern_confidence_threshold,
                )
                continue

            # Compute signal strength (Req 4.3), clamped into [0, 1]
            raw_strength = (
                pattern.avg_strength
                * rel_strength
                * pattern.pattern_confidence
                * impact_score
            )
            signal_strength = min(max(raw_strength, 0.0), 1.0)

            # Determine direction; a tie between bullish and bearish
            # percentages resolves to "bearish".
            direction = (
                "bullish" if pattern.bullish_pct > pattern.bearish_pct
                else "bearish"
            )

            records.append(
                CompetitiveSignalRecord(
                    source_document_id=document_id,
                    source_ticker=ticker,
                    target_ticker=competitor_ticker,
                    catalyst_type=catalyst_type,
                    pattern_confidence=pattern.pattern_confidence,
                    signal_direction=direction,
                    signal_strength=signal_strength,
                    relationship_strength=rel_strength,
                    computed_at=now,
                )
            )

    # Step 3: Persist all records (best effort, failure is logged only)
    if records:
        try:
            async with pool.acquire() as conn:
                await conn.executemany(
                    _INSERT_SIGNAL_QUERY,
                    [
                        (
                            r.source_document_id,
                            r.source_ticker,
                            r.target_ticker,
                            r.catalyst_type,
                            r.pattern_confidence,
                            r.signal_direction,
                            r.signal_strength,
                            r.relationship_strength,
                            r.computed_at,
                        )
                        for r in records
                    ],
                )
        except Exception:
            logger.exception(
                "Failed to persist %d competitive signal records", len(records),
            )

    return records
# ---------------------------------------------------------------------------
# build_pattern_weighted_signals
# ---------------------------------------------------------------------------
def build_pattern_weighted_signals(
    patterns: list[HistoricalPattern],
    competitive_signals: list[CompetitiveSignalRecord],
    reference_time: datetime,
    window: str,
    config: Optional[CompetitiveConfig] = None,
) -> list[WeightedSignal]:
    """Turn patterns and competitive signals into aggregation inputs.

    Both input kinds are mapped onto ``WeightedSignal`` with the same recipe;
    only the source of each ingredient differs:

    ====================  =========================  =======================
    ingredient            HistoricalPattern          CompetitiveSignalRecord
    ====================  =========================  =======================
    bullish test          bullish_pct > bearish_pct  direction == "bullish"
    raw strength          avg_strength               signal_strength
    published_at          data_end                   computed_at
    extraction confidence pattern_confidence         pattern_confidence
    ====================  =========================  =======================

    ``sentiment_value`` is +1.0 when the bullish test holds and -1.0
    otherwise; ``impact_score`` is the raw strength multiplied by
    ``competitive_signal_weight``.

    Args:
        patterns: Self-company historical patterns.
        competitive_signals: Competitive signal records from propagation.
        reference_time: Aggregation anchor time for recency decay.
        window: Trend window identifier (e.g. "7d").
        config: Optional competitive config overrides.

    Returns:
        WeightedSignal objects, pattern-derived ones first, then the
        competitive ones, each group in input order.
    """
    competitive_cfg = config or CompetitiveConfig()
    default_scoring = ScoringConfig()

    def _make_signal(
        doc_id: str,
        is_bullish: bool,
        raw_strength: float,
        published_at: datetime,
        confidence: float,
    ) -> WeightedSignal:
        # Shared conversion: fixed credibility/novelty, no market context.
        signal_weight = compute_signal_weight(
            published_at=published_at,
            reference_time=reference_time,
            window=window,
            source_credibility=1.0,  # patterns are derived from validated data
            novelty_score=0.5,
            extraction_confidence=confidence,
            market_ctx=None,
            config=default_scoring,
        )
        return WeightedSignal(
            document_id=doc_id,
            weight=signal_weight,
            sentiment_value=1.0 if is_bullish else -1.0,
            impact_score=raw_strength * competitive_cfg.competitive_signal_weight,
        )

    # Historical patterns get a synthetic, self-describing document id.
    from_patterns = [
        _make_signal(
            f"pattern:{hp.source_ticker}:{hp.catalyst_type}:{hp.time_horizon}",
            hp.bullish_pct > hp.bearish_pct,
            hp.avg_strength,
            hp.data_end,
            hp.pattern_confidence,
        )
        for hp in patterns
    ]

    # Competitive signals keep the id of the document that triggered them.
    from_competitors = [
        _make_signal(
            record.source_document_id,
            record.signal_direction == "bullish",
            record.signal_strength,
            record.computed_at,
            record.pattern_confidence,
        )
        for record in competitive_signals
    ]

    return from_patterns + from_competitors
+410 -5
View File
@@ -40,6 +40,17 @@ from services.shared.metrics import (
AGGREGATION_SIGNALS_PROCESSED,
AGGREGATION_WINDOWS_COMPUTED,
)
from services.aggregation.pattern_matcher import find_self_patterns
from services.aggregation.projection import (
MacroEventInfo,
TrendProjection,
compute_projection,
persist_trend_projection,
)
from services.aggregation.signal_propagation import (
CompetitiveSignalRecord,
build_pattern_weighted_signals,
)
from services.shared.schemas import TrendDirection, TrendSummary, TrendWindow
logger = logging.getLogger(__name__)
@@ -64,6 +75,10 @@ class AggregationConfig:
windows: list[str] | None = None # None = all windows
scoring: ScoringConfig | None = None
max_evidence: int = MAX_EVIDENCE_REFS
macro_signal_weight: float = 0.3 # relative weight of macro vs company signals
macro_enabled: bool = True # runtime toggle state
competitive_signal_weight: float = 0.2 # relative weight of pattern signals
competitive_enabled: bool = True # runtime toggle state
def effective_windows(self) -> list[str]:
if self.windows:
@@ -154,6 +169,236 @@ async def fetch_impact_records(
# ---------------------------------------------------------------------------
# Fetch macro toggle state from risk_configs
#
# MACRO LAYER TOGGLE BEHAVIOR (Requirements 11.2, 11.3, 11.4):
# - The toggle state is read fresh from PostgreSQL at the start of each
# aggregation cycle (no caching), so changes take effect immediately on
# the next cycle.
# - When disabled: ingestion and classification continue normally (historical
# data is preserved), but interpolation and aggregation integration are
# skipped — the aggregation engine produces trends using only company-
# specific signals.
# - When re-enabled: the engine resumes computing macro impact scores using
# the most recent GlobalEvent classifications, including any events that
# were ingested and classified while the layer was disabled.
# ---------------------------------------------------------------------------
# Reads the 'macro_enabled' key (as text) from the config JSON of the most
# recently updated active risk_configs row.
_MACRO_TOGGLE_QUERY = """
SELECT config->>'macro_enabled' AS macro_enabled
FROM risk_configs
WHERE active = TRUE
ORDER BY updated_at DESC
LIMIT 1
"""
async def fetch_macro_enabled(pool: asyncpg.Pool) -> bool | None:
    """Read the macro-layer toggle from the risk_configs table.

    Returns:
        ``True`` when the stored text equals "true" (case-insensitive),
        ``False`` for any other stored text, and ``None`` when there is no
        matching row or the key is unset, in which case the caller should
        fall back to the AggregationConfig default.
    """
    record = await pool.fetchrow(_MACRO_TOGGLE_QUERY)
    stored_flag = None if record is None else record["macro_enabled"]
    if stored_flag is None:
        return None
    return stored_flag.lower() == "true"
# ---------------------------------------------------------------------------
# Fetch competitive toggle state from risk_configs
# ---------------------------------------------------------------------------
# Reads the 'competitive_enabled' key (as text) from the config JSON of the
# most recently updated active risk_configs row.
_COMPETITIVE_TOGGLE_QUERY = """
SELECT config->>'competitive_enabled' AS competitive_enabled
FROM risk_configs
WHERE active = TRUE
ORDER BY updated_at DESC
LIMIT 1
"""
async def fetch_competitive_enabled(pool: asyncpg.Pool) -> bool | None:
    """Read the competitive-layer toggle from the risk_configs table.

    Returns:
        ``True`` when the stored text equals "true" (case-insensitive),
        ``False`` for any other stored text, and ``None`` when there is no
        matching row or the key is unset, in which case the caller should
        fall back to the AggregationConfig default.
    """
    record = await pool.fetchrow(_COMPETITIVE_TOGGLE_QUERY)
    stored_flag = None if record is None else record["competitive_enabled"]
    if stored_flag is None:
        return None
    return stored_flag.lower() == "true"
# ---------------------------------------------------------------------------
# Fetch competitive signals targeting a ticker within a time window
# ---------------------------------------------------------------------------
# Selects competitive signals whose target is ticker $1 and whose computed_at
# lies in the inclusive range [$2, $3], newest first.
_COMPETITIVE_SIGNALS_QUERY = """
SELECT source_document_id, source_ticker, target_ticker, catalyst_type,
pattern_confidence, signal_direction, signal_strength,
relationship_strength, computed_at
FROM competitive_signal_records
WHERE target_ticker = $1
AND computed_at >= $2
AND computed_at <= $3
ORDER BY computed_at DESC
"""
async def fetch_competitive_signals(
    pool: asyncpg.Pool,
    ticker: str,
    window_start: datetime,
    window_end: datetime,
) -> list[CompetitiveSignalRecord]:
    """Load the competitive signals aimed at ``ticker`` within a time range.

    Args:
        pool: asyncpg connection pool.
        ticker: Target ticker to filter on.
        window_start: Inclusive lower bound on ``computed_at``.
        window_end: Inclusive upper bound on ``computed_at``.

    Returns:
        One CompetitiveSignalRecord per row, in the order returned by the
        query (newest ``computed_at`` first).
    """
    fetched = await pool.fetch(
        _COMPETITIVE_SIGNALS_QUERY, ticker, window_start, window_end,
    )
    results: list[CompetitiveSignalRecord] = []
    for db_row in fetched:
        # Numeric columns are coerced to float and the document id to str.
        record = CompetitiveSignalRecord(
            source_document_id=str(db_row["source_document_id"]),
            source_ticker=db_row["source_ticker"],
            target_ticker=db_row["target_ticker"],
            catalyst_type=db_row["catalyst_type"],
            pattern_confidence=float(db_row["pattern_confidence"]),
            signal_direction=db_row["signal_direction"],
            signal_strength=float(db_row["signal_strength"]),
            relationship_strength=float(db_row["relationship_strength"]),
            computed_at=db_row["computed_at"],
        )
        results.append(record)
    return results
# ---------------------------------------------------------------------------
# Fetch macro impact records for a ticker within a time window
# ---------------------------------------------------------------------------
# Selects macro impact records for ticker $1 whose computed_at lies in the
# inclusive range [$2, $3], newest first. Each record is joined to its global
# event and to that event's source document to expose the document id and
# its published_at timestamp (aliased event_published_at).
_MACRO_IMPACT_QUERY = """
SELECT
mir.event_id,
mir.company_id,
mir.ticker,
mir.macro_impact_score,
mir.impact_direction,
mir.contributing_factors,
mir.confidence,
mir.computed_at,
ge.source_document_id,
d.published_at AS event_published_at
FROM macro_impact_records mir
JOIN global_events ge ON ge.id = mir.event_id
JOIN documents d ON d.id = ge.source_document_id
WHERE mir.ticker = $1
AND mir.computed_at >= $2
AND mir.computed_at <= $3
ORDER BY mir.computed_at DESC
"""
@dataclass
class MacroImpactRow:
"""Parsed row from the macro impact query."""
event_id: str
company_id: str
ticker: str
macro_impact_score: float
impact_direction: str
contributing_factors: list[str]
confidence: float
computed_at: datetime
source_document_id: str
event_published_at: datetime
def _parse_macro_impact_row(row: Any) -> MacroImpactRow:
"""Convert an asyncpg Record to a MacroImpactRow."""
factors = row["contributing_factors"]
if isinstance(factors, str):
factors = json.loads(factors)
return MacroImpactRow(
event_id=str(row["event_id"]),
company_id=str(row["company_id"]),
ticker=row["ticker"],
macro_impact_score=float(row["macro_impact_score"] or 0.0),
impact_direction=row["impact_direction"] or "neutral",
contributing_factors=factors if isinstance(factors, list) else [],
confidence=float(row["confidence"] or 0.5),
computed_at=row["computed_at"],
source_document_id=str(row["source_document_id"]),
event_published_at=row["event_published_at"],
)
async def fetch_macro_impact_records(
    pool: asyncpg.Pool,
    ticker: str,
    window_start: datetime,
    window_end: datetime,
) -> list[MacroImpactRow]:
    """Load and parse the macro impact records of ``ticker`` in a time range.

    Args:
        pool: asyncpg connection pool.
        ticker: Ticker to filter on.
        window_start: Inclusive lower bound on ``computed_at``.
        window_end: Inclusive upper bound on ``computed_at``.

    Returns:
        Parsed rows in query order (newest ``computed_at`` first).
    """
    fetched = await pool.fetch(
        _MACRO_IMPACT_QUERY, ticker, window_start, window_end,
    )
    parsed: list[MacroImpactRow] = []
    for db_row in fetched:
        parsed.append(_parse_macro_impact_row(db_row))
    return parsed
# ---------------------------------------------------------------------------
# Convert macro impact records to WeightedSignals
# ---------------------------------------------------------------------------
# Maps a macro record's impact_direction label to a signed sentiment value.
# "mixed" and "neutral" both carry no directional sentiment (0.0).
_DIRECTION_TO_SENTIMENT: dict[str, float] = {
    "positive": 1.0,
    "negative": -1.0,
    "mixed": 0.0,
    "neutral": 0.0,
}
def build_macro_weighted_signals(
    macro_impacts: list[MacroImpactRow],
    reference_time: datetime,
    window: str,
    macro_signal_weight: float = 0.3,
    config: ScoringConfig | None = None,
) -> list[WeightedSignal]:
    """Map macro impact records onto WeightedSignal objects.

    Each record becomes one signal:

    - ``document_id`` is the record's source document id (evidence tracing).
    - ``weight`` comes from ``compute_signal_weight``, using the global
      event's publication time for recency and the record's confidence as
      both source credibility and extraction confidence.
    - ``sentiment_value`` is looked up from ``impact_direction``; unknown
      labels map to 0.0.
    - ``impact_score`` is ``macro_impact_score * macro_signal_weight``.

    Args:
        macro_impacts: Parsed macro impact rows.
        reference_time: Aggregation anchor time for recency decay.
        window: Trend window identifier.
        macro_signal_weight: Multiplier applied to each macro impact score.
        config: Scoring configuration; a default ScoringConfig when omitted.

    Returns:
        One WeightedSignal per input record, in input order.
    """
    scoring = config or ScoringConfig()

    def _to_signal(record: MacroImpactRow) -> WeightedSignal:
        signal_weight = compute_signal_weight(
            published_at=record.event_published_at,
            reference_time=reference_time,
            window=window,
            source_credibility=record.confidence,
            novelty_score=0.5,
            extraction_confidence=record.confidence,
            config=scoring,
        )
        return WeightedSignal(
            document_id=record.source_document_id,
            weight=signal_weight,
            sentiment_value=_DIRECTION_TO_SENTIMENT.get(
                record.impact_direction, 0.0,
            ),
            impact_score=record.macro_impact_score * macro_signal_weight,
        )

    return [_to_signal(record) for record in macro_impacts]
# ---------------------------------------------------------------------------
# Build weighted signals from impact records
# ---------------------------------------------------------------------------
@@ -544,6 +789,61 @@ async def persist_trend_evidence(
return len(rows)
# ---------------------------------------------------------------------------
# Build MacroEventInfo objects for projection computation
# ---------------------------------------------------------------------------
# Selects, for ticker $1 and computed_at in the inclusive range [$2, $3],
# each macro impact record together with its global event's estimated
# duration and severity and the source document's published_at timestamp
# (aliased event_published_at), newest first.
_MACRO_EVENT_INFO_QUERY = """
SELECT
mir.event_id,
mir.macro_impact_score,
mir.impact_direction,
mir.confidence,
ge.estimated_duration,
ge.severity,
d.published_at AS event_published_at
FROM macro_impact_records mir
JOIN global_events ge ON ge.id = mir.event_id
JOIN documents d ON d.id = ge.source_document_id
WHERE mir.ticker = $1
AND mir.computed_at >= $2
AND mir.computed_at <= $3
ORDER BY mir.computed_at DESC
"""
async def _build_macro_event_infos(
    pool: asyncpg.Pool,
    ticker: str,
    window_start: datetime,
    reference_time: datetime,
) -> list[MacroEventInfo]:
    """Fetch macro impact records and build MacroEventInfo objects for projection.

    Args:
        pool: asyncpg connection pool.
        ticker: Ticker whose macro impact records are loaded.
        window_start: Inclusive lower bound on ``computed_at``.
        reference_time: Inclusive upper bound on ``computed_at`` and the
            anchor used to compute each event's age.

    Returns:
        One MacroEventInfo per row, in query order. NULL columns fall back
        to: score 0.0, direction "neutral", confidence 0.5, duration
        "short_term", severity "low". The event age is 0.0 hours when the
        publication time is missing or lies after ``reference_time``.
    """
    rows = await pool.fetch(
        _MACRO_EVENT_INFO_QUERY, ticker, window_start, reference_time,
    )

    infos: list[MacroEventInfo] = []
    for row in rows:
        published_at = row["event_published_at"]
        age_hours = 0.0
        if published_at is not None:
            # Clamp at zero so future-dated documents never get a negative age.
            age_hours = max(
                (reference_time - published_at).total_seconds() / 3600.0, 0.0,
            )

        raw_score = row["macro_impact_score"]
        raw_confidence = row["confidence"]

        infos.append(
            MacroEventInfo(
                event_id=str(row["event_id"]),
                macro_impact_score=0.0 if raw_score is None else float(raw_score),
                impact_direction=row["impact_direction"] or "neutral",
                # Explicit None check: ``0.0 or 0.5`` would silently promote a
                # genuine zero-confidence record to 0.5.
                confidence=(
                    0.5 if raw_confidence is None else float(raw_confidence)
                ),
                estimated_duration=row["estimated_duration"] or "short_term",
                severity=row["severity"] or "low",
                event_age_hours=age_hours,
            )
        )
    return infos
# ---------------------------------------------------------------------------
# Main aggregation entry point for a single ticker + window
# ---------------------------------------------------------------------------
@@ -563,8 +863,10 @@ async def aggregate_company_window(
2. Fetch document impact records from PostgreSQL.
3. Fetch market context for the ticker.
4. Build weighted signals using the scoring module.
5. Assemble the TrendSummary.
6. Persist to trend_windows table.
5. Check macro toggle and fetch/merge macro signals if enabled.
6. Check competitive toggle and fetch/merge pattern/competitive signals if enabled.
7. Assemble the TrendSummary.
8. Persist to trend_windows table.
Returns the assembled TrendSummary.
"""
@@ -589,7 +891,83 @@ async def aggregate_company_window(
impacts, reference_time, window, market_ctx, scoring_cfg,
)
# 4. Assemble trend summary with evidence details
# 4. Check macro toggle and merge macro signals
# (Requirement 11.2, 11.3, 11.4): Toggle state is read from the DB on
# every aggregation cycle. When disabled, macro signals are skipped but
# ingestion/classification continue independently — so when re-enabled,
# the most recent classifications (including those ingested while disabled)
# are immediately available for impact computation.
macro_enabled = cfg.macro_enabled
db_toggle = await fetch_macro_enabled(pool)
if db_toggle is not None:
macro_enabled = db_toggle
if macro_enabled:
macro_impacts = await fetch_macro_impact_records(
pool, ticker, window_start, reference_time,
)
if macro_impacts:
macro_signals = build_macro_weighted_signals(
macro_impacts,
reference_time,
window,
macro_signal_weight=cfg.macro_signal_weight,
config=scoring_cfg,
)
signals = signals + macro_signals
logger.info(
"Merged %d macro signals for %s/%s",
len(macro_signals), ticker, window,
)
# 5. Check competitive toggle and merge pattern/competitive signals
# (Requirements 5.1-5.6): Same toggle pattern as macro layer. When
# disabled, pattern mining remains queryable but aggregation skips
# competitive signals — no degradation of existing behavior.
competitive_enabled = cfg.competitive_enabled
db_competitive_toggle = await fetch_competitive_enabled(pool)
if db_competitive_toggle is not None:
competitive_enabled = db_competitive_toggle
if competitive_enabled:
try:
# Get unique catalyst types from the impact records
catalyst_types = {imp.catalyst_type for imp in impacts}
# Query self-company historical patterns for each catalyst type
all_patterns = []
for cat_type in catalyst_types:
patterns = await find_self_patterns(pool, ticker, cat_type)
all_patterns.extend(patterns)
# Fetch competitive signals targeting this ticker
comp_signals = await fetch_competitive_signals(
pool, ticker, window_start, reference_time,
)
# Convert to WeightedSignal objects
if all_patterns or comp_signals:
pattern_weighted = build_pattern_weighted_signals(
patterns=all_patterns,
competitive_signals=comp_signals,
reference_time=reference_time,
window=window,
)
signals = signals + pattern_weighted
logger.info(
"Merged %d pattern/competitive signals for %s/%s "
"(patterns=%d, competitive=%d)",
len(pattern_weighted), ticker, window,
len(all_patterns), len(comp_signals),
)
except Exception:
logger.exception(
"Failed to fetch pattern/competitive signals for %s/%s"
"continuing with company+macro signals only",
ticker, window,
)
# 6. Assemble trend summary with evidence details
assembled = assemble_trend_with_evidence(
ticker=ticker,
window=window,
@@ -601,10 +979,10 @@ async def aggregate_company_window(
)
summary = assembled.summary
# 5. Persist trend window
# 7. Persist trend window
trend_id = await persist_trend_summary(pool, summary)
# 6. Persist evidence mappings
# 8. Persist evidence mappings
evidence_count = await persist_trend_evidence(
pool, trend_id,
assembled.supporting_evidence,
@@ -617,6 +995,33 @@ async def aggregate_company_window(
summary.trend_strength, summary.confidence, len(signals), evidence_count,
)
# 9. Compute and persist trend projection
try:
macro_event_infos: list[MacroEventInfo] = []
if macro_enabled:
macro_event_infos = await _build_macro_event_infos(
pool, ticker, window_start, reference_time,
)
projection = compute_projection(
summary=summary,
macro_events=macro_event_infos if macro_event_infos else None,
macro_enabled=macro_enabled,
upcoming_catalysts=summary.dominant_catalysts[:3] if summary.dominant_catalysts else None,
)
await persist_trend_projection(pool, trend_id, projection)
logger.info(
"Persisted projection for %s/%s: direction=%s strength=%.3f confidence=%.3f diverges=%s",
ticker, window, projection.projected_direction,
projection.projected_strength, projection.projected_confidence,
projection.diverges_from_current,
)
except Exception:
logger.exception(
"Failed to compute/persist projection for trend %s (%s/%s) — continuing",
trend_id, ticker, window,
)
# Prometheus metrics
AGGREGATION_WINDOWS_COMPUTED.labels(window=window).inc()
AGGREGATION_SIGNALS_PROCESSED.labels(window=window).inc(len(signals))