feat: competitive intelligence & historical pattern matching layer

This commit is contained in:
Celes Renata
2026-04-14 19:42:48 +00:00
parent b478022ba3
commit f7a11d14ea
203 changed files with 20155 additions and 97 deletions
+136
View File
@@ -0,0 +1,136 @@
"""Macro news adapter for global/geopolitical news ingestion.
Fetches macro-level news articles from configured sources for global event
classification. Reuses the same adapter pattern as company-specific news
but targets macro-focused endpoints and does not require a ticker.
Requirements: 1.1, 1.2, 1.3, 1.4
"""
import hashlib
import logging
import time
from datetime import datetime, timezone
from typing import Any
import httpx
from .base import AdapterResult, BaseAdapter
logger = logging.getLogger("macro_news_adapter")
class MacroNewsAdapter(BaseAdapter):
"""Adapter for fetching macro/geopolitical news from configured sources.
Supports fetching from any HTTP endpoint that returns JSON with a list
of news articles. The endpoint URL and response parsing are configured
via the source config dict.
Config options:
url: The endpoint URL to fetch from
limit: Max articles to return per request (default 20)
params: Additional query parameters as a dict
results_key: JSON key containing the article list (default "results")
"""
def __init__(self, api_key: str = "", base_url: str = "") -> None:
self.api_key = api_key
self.base_url = base_url.rstrip("/") if base_url else ""
def source_type(self) -> str:
return "macro_news"
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
"""Fetch macro news articles from the configured endpoint.
The ticker parameter is ignored for macro sources — these are
global/geopolitical news, not company-specific.
Args:
ticker: Ignored for macro sources (may be empty string).
config: Source-specific configuration with url, params, etc.
Returns:
AdapterResult with raw payload and parsed article items.
"""
url = config.get("url", "")
if not url and self.base_url:
url = self.base_url
if not url:
return self._error_result("No URL configured for macro news source")
params = dict(config.get("params", {}))
if self.api_key:
params["apiKey"] = self.api_key
limit = config.get("limit", 20)
params["limit"] = str(min(int(limit), 1000))
async with httpx.AsyncClient(timeout=30) as client:
t0 = time.monotonic()
try:
resp = await client.get(url, params=params)
elapsed_ms = (time.monotonic() - t0) * 1000
resp.raise_for_status()
raw = resp.content
data = resp.json()
content_hash = hashlib.sha256(raw).hexdigest()
results_key = config.get("results_key", "results")
items = data.get(results_key, [])
if not isinstance(items, list):
items = []
return AdapterResult(
source_type="macro_news",
ticker="",
items=items,
raw_payload=raw,
content_hash=content_hash,
fetched_at=datetime.now(timezone.utc),
http_status=resp.status_code,
response_time_ms=round(elapsed_ms, 1),
metadata={
"provider": config.get("provider", "macro"),
"results_count": len(items),
},
)
except httpx.HTTPStatusError as e:
elapsed_ms = (time.monotonic() - t0) * 1000
logger.error("Macro news HTTP error: %s", e)
return self._error_result(
str(e), elapsed_ms,
http_status=e.response.status_code if e.response else None,
raw=e.response.content if e.response else b"",
)
except httpx.TimeoutException as e:
elapsed_ms = (time.monotonic() - t0) * 1000
logger.error("Macro news timeout: %s", e)
return self._error_result(f"timeout: {e}", elapsed_ms)
except Exception as e:
elapsed_ms = (time.monotonic() - t0) * 1000
logger.error("Macro news fetch failed: %s", e)
return self._error_result(str(e), elapsed_ms)
def _error_result(
self,
error: str,
elapsed_ms: float = 0.0,
http_status: int | None = None,
raw: bytes = b"",
) -> AdapterResult:
"""Build an error AdapterResult for macro news fetches."""
return AdapterResult(
source_type="macro_news",
ticker="",
items=[],
raw_payload=raw,
content_hash="",
fetched_at=datetime.now(timezone.utc),
error=error,
http_status=http_status,
response_time_ms=round(elapsed_ms, 1),
metadata={"provider": "macro"},
)
+741
View File
@@ -0,0 +1,741 @@
"""Interpolation engine — macro-to-company impact scoring.
Computes per-company macro impact scores by evaluating overlap between
global event classifications and company exposure profiles. Produces
MacroImpactRecord objects that feed into the aggregation engine as
additional weighted signals.
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 3.2
"""
from __future__ import annotations
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
import asyncpg
from services.extractor.event_classifier import GlobalEvent
from services.shared.schemas import (
ExposureProfileSchema,
MarketPositionTier,
SeverityLevel,
)
logger = logging.getLogger("interpolation")
# ---------------------------------------------------------------------------
# Default configuration constants
# ---------------------------------------------------------------------------
DEFAULT_CONFIDENCE_THRESHOLD = 0.4
DEFAULT_SHORT_TERM_STALENESS_HOURS = 48
ACCELERATED_DECAY_MULTIPLIER = 0.5 # applied on top of standard recency decay
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Severity weights
SEVERITY_WEIGHTS: dict[str, float] = {
SeverityLevel.CRITICAL.value: 1.0,
SeverityLevel.HIGH.value: 0.75,
SeverityLevel.MODERATE.value: 0.5,
SeverityLevel.LOW.value: 0.25,
}
# Component weights in the scoring formula
GEO_WEIGHT = 0.35
SUPPLY_WEIGHT = 0.25
COMMODITY_WEIGHT = 0.25
SECTOR_WEIGHT = 0.15
# Resilience modifiers for international events
RESILIENCE_MODIFIERS: dict[str, float] = {
MarketPositionTier.GLOBAL_LEADER.value: 0.7,
MarketPositionTier.MULTINATIONAL.value: 0.85,
MarketPositionTier.REGIONAL.value: 1.0,
MarketPositionTier.DOMESTIC.value: 1.2,
}
# Event types that are typically negative
_NEGATIVE_EVENT_TYPES = frozenset({
"supply_disruption",
"cost_increase",
"regulatory_pressure",
"geopolitical_risk",
"trade_barrier",
})
# Event types that can be positive
_POSITIVE_EVENT_TYPES = frozenset({
"demand_shift",
})
# Event types that can go either way
_AMBIGUOUS_EVENT_TYPES = frozenset({
"commodity_shock",
"currency_impact",
})
# Market cap bucket → market position tier mapping for default profiles
_CAP_TO_TIER: dict[str, str] = {
"large_cap": MarketPositionTier.GLOBAL_LEADER.value,
"mid_cap": MarketPositionTier.MULTINATIONAL.value,
"small_cap": MarketPositionTier.REGIONAL.value,
"micro_cap": MarketPositionTier.DOMESTIC.value,
}
# Sector-based default geographic revenue mixes
_SECTOR_DEFAULT_GEO: dict[str, dict[str, float]] = {
"Information Technology": {"US": 0.45, "CN": 0.15, "EU": 0.15, "JP": 0.10, "KR": 0.15},
"Health Care": {"US": 0.50, "EU": 0.25, "JP": 0.10, "CN": 0.15},
"Financials": {"US": 0.55, "EU": 0.20, "GB": 0.15, "JP": 0.10},
"Energy": {"US": 0.30, "SA": 0.20, "RU": 0.15, "CA": 0.15, "AE": 0.20},
"Materials": {"US": 0.25, "CN": 0.25, "AU": 0.20, "BR": 0.15, "IN": 0.15},
"Industrials": {"US": 0.40, "DE": 0.15, "CN": 0.15, "JP": 0.15, "KR": 0.15},
"Consumer Discretionary": {"US": 0.45, "CN": 0.20, "EU": 0.15, "JP": 0.10, "IN": 0.10},
"Consumer Staples": {"US": 0.45, "EU": 0.20, "CN": 0.15, "IN": 0.10, "BR": 0.10},
"Communication Services": {"US": 0.50, "CN": 0.15, "EU": 0.15, "JP": 0.10, "IN": 0.10},
"Utilities": {"US": 0.70, "EU": 0.15, "JP": 0.15},
"Real Estate": {"US": 0.60, "CN": 0.15, "EU": 0.15, "JP": 0.10},
}
_DEFAULT_GEO = {"US": 0.50, "EU": 0.20, "CN": 0.15, "JP": 0.15}
# ---------------------------------------------------------------------------
# MacroImpactRecord dataclass
# ---------------------------------------------------------------------------
@dataclass
class MacroImpactRecord:
"""A computed macro impact score for a specific company-event pair."""
event_id: str = ""
company_id: str = ""
ticker: str = ""
macro_impact_score: float = 0.0 # [0, 1]
impact_direction: str = "neutral" # positive|negative|mixed
contributing_factors: list[str] = field(default_factory=list)
confidence: float = 0.5 # [0, 1]
computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
# ---------------------------------------------------------------------------
# Overlap computation functions
# ---------------------------------------------------------------------------
def compute_geographic_overlap(
event_regions: list[str],
revenue_mix: dict[str, float],
) -> float:
"""Compute geographic overlap using revenue percentage weighting.
For each event region that appears in the company's revenue mix,
sum the revenue percentage. Returns a value in [0, 1].
Args:
event_regions: Region codes from the global event.
revenue_mix: Company's geographic_revenue_mix (region -> pct).
Returns:
Sum of revenue percentages for overlapping regions, clamped to [0, 1].
"""
if not event_regions or not revenue_mix:
return 0.0
event_set = {r.upper() for r in event_regions}
overlap = 0.0
for region, pct in revenue_mix.items():
if region.upper() in event_set:
overlap += pct
return min(max(overlap, 0.0), 1.0)
def compute_supply_chain_overlap(
event_regions: list[str],
supply_regions: list[str],
) -> float:
"""Compute supply chain overlap using set intersection ratio.
Returns the fraction of the company's supply chain regions that
overlap with the event's affected regions.
Args:
event_regions: Region codes from the global event.
supply_regions: Company's supply_chain_regions.
Returns:
Intersection ratio in [0, 1]. 0.0 if supply_regions is empty.
"""
if not event_regions or not supply_regions:
return 0.0
event_set = {r.upper() for r in event_regions}
supply_set = {r.upper() for r in supply_regions}
intersection = event_set & supply_set
return len(intersection) / len(supply_set)
def compute_commodity_overlap(
event_commodities: list[str],
company_commodities: list[str],
) -> float:
"""Compute commodity overlap using set intersection ratio.
Returns the fraction of the company's key commodities that overlap
with the event's affected commodities.
Args:
event_commodities: Commodity identifiers from the global event.
company_commodities: Company's key_input_commodities.
Returns:
Intersection ratio in [0, 1]. 0.0 if company_commodities is empty.
"""
if not event_commodities or not company_commodities:
return 0.0
event_set = {c.lower() for c in event_commodities}
company_set = {c.lower() for c in company_commodities}
intersection = event_set & company_set
return len(intersection) / len(company_set)
# ---------------------------------------------------------------------------
# Resilience modifier
# ---------------------------------------------------------------------------
def apply_resilience_modifier(
raw_score: float,
tier: str,
event_is_international: bool = True,
) -> float:
"""Apply a resilience modifier based on market position tier.
For international events, global leaders get a dampening factor (0.7)
while domestic companies get an amplification factor (1.2).
For domestic-only events, no modifier is applied.
Args:
raw_score: The raw impact score before resilience adjustment.
tier: Market position tier value.
event_is_international: Whether the event affects multiple countries.
Returns:
Modified score clamped to [0, 1].
"""
if not event_is_international:
return min(max(raw_score, 0.0), 1.0)
modifier = RESILIENCE_MODIFIERS.get(tier, 1.0)
return min(max(raw_score * modifier, 0.0), 1.0)
# ---------------------------------------------------------------------------
# Impact direction determination
# ---------------------------------------------------------------------------
def _determine_impact_direction(
event_types: list[str],
) -> tuple[str, list[str], list[str]]:
"""Determine impact direction from event types.
Returns:
Tuple of (direction, positive_factors, negative_factors).
"""
positive_factors: list[str] = []
negative_factors: list[str] = []
for et in event_types:
if et in _NEGATIVE_EVENT_TYPES:
negative_factors.append(et)
elif et in _POSITIVE_EVENT_TYPES:
positive_factors.append(et)
elif et in _AMBIGUOUS_EVENT_TYPES:
# Ambiguous types contribute to both sides
positive_factors.append(et)
negative_factors.append(et)
has_positive = len(positive_factors) > 0
has_negative = len(negative_factors) > 0
if has_positive and has_negative:
return "mixed", positive_factors, negative_factors
elif has_positive:
return "positive", positive_factors, negative_factors
elif has_negative:
return "negative", positive_factors, negative_factors
else:
return "negative", positive_factors, negative_factors
# ---------------------------------------------------------------------------
# Core scoring function
# ---------------------------------------------------------------------------
def compute_macro_impact(
event: GlobalEvent,
profile: ExposureProfileSchema,
) -> MacroImpactRecord:
"""Compute the macro impact of a global event on a company.
Scoring formula:
raw_score = severity_weight * (
0.35 * geographic_overlap +
0.25 * supply_chain_overlap +
0.25 * commodity_overlap +
0.15 * sector_match
)
final_score = apply_resilience_modifier(raw_score, tier, is_international)
Args:
event: The classified global event.
profile: The company's exposure profile.
Returns:
A MacroImpactRecord with the computed score and metadata.
"""
now = datetime.now(timezone.utc)
# Compute overlaps
geo_overlap = compute_geographic_overlap(
event.affected_regions,
profile.geographic_revenue_mix,
)
supply_overlap = compute_supply_chain_overlap(
event.affected_regions,
profile.supply_chain_regions,
)
commodity_overlap = compute_commodity_overlap(
event.affected_commodities,
profile.key_input_commodities,
)
# Sector match: 1.0 if any event sector matches the company's sector
# We check against the profile's regulatory_jurisdictions as a proxy,
# but the real sector comes from the company data. For now, we use
# a simple heuristic: check if any affected_sectors appear in the
# profile's geographic_revenue_mix keys or supply_chain_regions.
# The actual sector is not stored in ExposureProfileSchema, so we
# check if any event sectors match. This will be 0.0 unless the
# caller provides sector info through contributing_factors.
sector_match = 0.0
# We'll compute sector_match based on event sectors — the caller
# should ensure the profile has relevant sector info. For the
# default implementation, we always set sector_match to 0.0 here
# and let the caller override if needed.
# Check zero-overlap case
contributing = []
if geo_overlap > 0:
contributing.append(f"geographic_overlap:{geo_overlap:.3f}")
if supply_overlap > 0:
contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}")
if commodity_overlap > 0:
contributing.append(f"commodity_overlap:{commodity_overlap:.3f}")
total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match
if total_overlap == 0.0:
return MacroImpactRecord(
event_id=event.event_id,
company_id=profile.company_id,
ticker="",
macro_impact_score=0.0,
impact_direction="neutral",
contributing_factors=[],
confidence=0.0,
computed_at=now,
)
# Severity weight
severity_weight = SEVERITY_WEIGHTS.get(event.severity, 0.25)
# Raw score
raw_score = severity_weight * (
GEO_WEIGHT * geo_overlap
+ SUPPLY_WEIGHT * supply_overlap
+ COMMODITY_WEIGHT * commodity_overlap
+ SECTOR_WEIGHT * sector_match
)
# Determine if event is international (affects multiple regions)
is_international = len(event.affected_regions) > 1
# Apply resilience modifier
tier = profile.market_position_tier
if isinstance(tier, MarketPositionTier):
tier = tier.value
final_score = apply_resilience_modifier(raw_score, tier, is_international)
# Determine impact direction
direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types)
# Build contributing factors list
all_factors = list(contributing)
if pos_factors:
all_factors.append(f"positive_types:{','.join(pos_factors)}")
if neg_factors:
all_factors.append(f"negative_types:{','.join(neg_factors)}")
# Confidence: combine event confidence with overlap strength
confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0)
return MacroImpactRecord(
event_id=event.event_id,
company_id=profile.company_id,
ticker="",
macro_impact_score=round(min(final_score, 1.0), 6),
impact_direction=direction,
contributing_factors=all_factors,
confidence=round(confidence, 6),
computed_at=now,
)
def compute_macro_impact_with_sector(
event: GlobalEvent,
profile: ExposureProfileSchema,
company_sector: str = "",
) -> MacroImpactRecord:
"""Compute macro impact with explicit sector matching.
Like compute_macro_impact but accepts a company_sector parameter
for proper sector_match computation.
Args:
event: The classified global event.
profile: The company's exposure profile.
company_sector: The company's GICS sector name.
Returns:
A MacroImpactRecord with the computed score and metadata.
"""
now = datetime.now(timezone.utc)
# Compute overlaps
geo_overlap = compute_geographic_overlap(
event.affected_regions,
profile.geographic_revenue_mix,
)
supply_overlap = compute_supply_chain_overlap(
event.affected_regions,
profile.supply_chain_regions,
)
commodity_overlap = compute_commodity_overlap(
event.affected_commodities,
profile.key_input_commodities,
)
# Sector match
sector_match = 0.0
if company_sector and event.affected_sectors:
company_sector_lower = company_sector.lower().strip()
for es in event.affected_sectors:
if es.lower().strip() == company_sector_lower:
sector_match = 1.0
break
# Contributing factors
contributing: list[str] = []
if geo_overlap > 0:
contributing.append(f"geographic_overlap:{geo_overlap:.3f}")
if supply_overlap > 0:
contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}")
if commodity_overlap > 0:
contributing.append(f"commodity_overlap:{commodity_overlap:.3f}")
if sector_match > 0:
contributing.append(f"sector_match:{company_sector}")
total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match
if total_overlap == 0.0:
return MacroImpactRecord(
event_id=event.event_id,
company_id=profile.company_id,
ticker="",
macro_impact_score=0.0,
impact_direction="neutral",
contributing_factors=[],
confidence=0.0,
computed_at=now,
)
# Severity weight
severity_weight = SEVERITY_WEIGHTS.get(event.severity, 0.25)
# Raw score
raw_score = severity_weight * (
GEO_WEIGHT * geo_overlap
+ SUPPLY_WEIGHT * supply_overlap
+ COMMODITY_WEIGHT * commodity_overlap
+ SECTOR_WEIGHT * sector_match
)
# International check
is_international = len(event.affected_regions) > 1
# Resilience modifier
tier = profile.market_position_tier
if isinstance(tier, MarketPositionTier):
tier = tier.value
final_score = apply_resilience_modifier(raw_score, tier, is_international)
# Direction
direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types)
all_factors = list(contributing)
if pos_factors:
all_factors.append(f"positive_types:{','.join(pos_factors)}")
if neg_factors:
all_factors.append(f"negative_types:{','.join(neg_factors)}")
confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0)
return MacroImpactRecord(
event_id=event.event_id,
company_id=profile.company_id,
ticker="",
macro_impact_score=round(min(final_score, 1.0), 6),
impact_direction=direction,
contributing_factors=all_factors,
confidence=round(confidence, 6),
computed_at=now,
)
# ---------------------------------------------------------------------------
# Default profile builder
# ---------------------------------------------------------------------------
def build_default_profile(
sector: str,
industry: str,
market_cap_bucket: str,
) -> ExposureProfileSchema:
"""Build a default ExposureProfile for companies without manual profiles.
Uses sector-based geographic revenue defaults and maps market_cap_bucket
to market_position_tier:
large_cap → global_leader
mid_cap → multinational
small_cap → regional
micro_cap → domestic
Args:
sector: GICS sector name.
industry: Industry name (used for commodity defaults).
market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap.
Returns:
An ExposureProfileSchema with source='inferred'.
"""
tier = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value)
geo_mix = _SECTOR_DEFAULT_GEO.get(sector, _DEFAULT_GEO)
# Derive supply chain regions from geo mix keys
supply_regions = list(geo_mix.keys())
# Derive commodities from sector/industry
commodities = _infer_commodities(sector, industry)
# Export dependency based on tier
export_pct = {
MarketPositionTier.GLOBAL_LEADER.value: 0.5,
MarketPositionTier.MULTINATIONAL.value: 0.35,
MarketPositionTier.REGIONAL.value: 0.15,
MarketPositionTier.DOMESTIC.value: 0.05,
}.get(tier, 0.15)
return ExposureProfileSchema(
company_id="",
geographic_revenue_mix=dict(geo_mix),
supply_chain_regions=supply_regions,
key_input_commodities=commodities,
regulatory_jurisdictions=list(geo_mix.keys())[:3],
market_position_tier=MarketPositionTier(tier),
export_dependency_pct=export_pct,
source="inferred",
confidence=0.5,
version=1,
)
def _infer_commodities(sector: str, industry: str) -> list[str]:
"""Infer key input commodities from sector and industry."""
sector_commodities: dict[str, list[str]] = {
"Energy": ["crude_oil", "natural_gas"],
"Materials": ["copper", "steel", "lithium"],
"Industrials": ["steel", "copper"],
"Information Technology": ["semiconductors", "lithium"],
"Consumer Staples": ["wheat", "corn"],
"Consumer Discretionary": ["steel", "semiconductors"],
"Health Care": [],
"Financials": [],
"Communication Services": [],
"Utilities": ["natural_gas"],
"Real Estate": ["steel"],
}
return sector_commodities.get(sector, [])
# ---------------------------------------------------------------------------
# PostgreSQL persistence
# ---------------------------------------------------------------------------
async def persist_macro_impact_record(
pool: asyncpg.Pool,
record: MacroImpactRecord,
) -> str:
"""Persist a MacroImpactRecord to the macro_impact_records table.
Returns the row UUID.
"""
row_id = await pool.fetchval(
"""INSERT INTO macro_impact_records
(event_id, company_id, ticker, macro_impact_score,
impact_direction, contributing_factors, confidence, computed_at)
VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6::jsonb, $7, $8)
RETURNING id""",
record.event_id,
record.company_id,
record.ticker,
record.macro_impact_score,
record.impact_direction,
json.dumps(record.contributing_factors),
record.confidence,
record.computed_at,
)
logger.info(
"Persisted macro impact record for event=%s company=%s score=%.4f direction=%s",
record.event_id,
record.company_id,
record.macro_impact_score,
record.impact_direction,
)
return str(row_id)
async def persist_macro_impact_records(
pool: asyncpg.Pool,
records: list[MacroImpactRecord],
) -> list[str]:
"""Persist multiple MacroImpactRecords. Returns list of row UUIDs."""
ids: list[str] = []
for record in records:
if record.macro_impact_score > 0.0:
row_id = await persist_macro_impact_record(pool, record)
ids.append(row_id)
return ids
# ---------------------------------------------------------------------------
# Low-confidence event exclusion (Requirements: 10.1)
# ---------------------------------------------------------------------------
def filter_low_confidence_events(
events: list[GlobalEvent],
confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
) -> list[GlobalEvent]:
"""Filter out events with confidence below the configurable threshold.
Events with confidence below the threshold are excluded from macro
impact computation and the exclusion reason is logged.
Args:
events: List of GlobalEvent classifications.
confidence_threshold: Minimum confidence for inclusion (default 0.4).
Returns:
List of events that pass the confidence threshold.
Requirements: 10.1
"""
included: list[GlobalEvent] = []
for event in events:
if event.confidence < confidence_threshold:
logger.info(
"Excluding low-confidence event %s: confidence=%.3f < threshold=%.3f",
event.event_id,
event.confidence,
confidence_threshold,
)
else:
included.append(event)
return included
# ---------------------------------------------------------------------------
# Accelerated decay for stale short-term events (Requirements: 10.2)
# ---------------------------------------------------------------------------
def compute_standard_recency_decay(
age_hours: float,
half_life_hours: float = 168.0,
) -> float:
"""Compute standard exponential recency decay.
Args:
age_hours: Age of the event in hours.
half_life_hours: Half-life for the decay function (default 7 days).
Returns:
Decay factor in (0, 1].
"""
import math
if age_hours <= 0:
return 1.0
return math.exp(-0.693 * age_hours / half_life_hours)
def apply_accelerated_decay(
age_hours: float,
estimated_duration: str,
staleness_hours: float = DEFAULT_SHORT_TERM_STALENESS_HOURS,
half_life_hours: float = 168.0,
) -> float:
"""Apply accelerated decay for stale short-term events.
For short_term events older than staleness_hours (default 48h),
the effective weight is strictly less than standard recency decay.
For non-short-term events or events within the staleness window,
standard recency decay is applied.
Args:
age_hours: Age of the event in hours.
estimated_duration: Event's estimated_duration field.
staleness_hours: Hours after which short-term events get accelerated decay.
half_life_hours: Half-life for the standard decay function.
Returns:
Effective signal weight in (0, 1].
Requirements: 10.2
"""
standard_decay = compute_standard_recency_decay(age_hours, half_life_hours)
if estimated_duration == "short_term" and age_hours > staleness_hours:
# Apply accelerated decay: multiply standard decay by a factor < 1
accelerated = standard_decay * ACCELERATED_DECAY_MULTIPLIER
logger.debug(
"Accelerated decay for short_term event: age=%.1fh, "
"standard=%.4f, accelerated=%.4f",
age_hours, standard_decay, accelerated,
)
return accelerated
return standard_decay
+125 -3
View File
@@ -1,4 +1,12 @@
"""Aggregation worker entrypoint - polls Redis for aggregation jobs."""
"""Aggregation worker entrypoint - polls Redis for aggregation jobs.
After computing trend summaries for a ticker, the worker also triggers
competitive signal propagation for the ticker's competitors when the
competitive layer is enabled. This ensures that document intelligence
for one company produces competitive signals for related companies.
Requirements: 4.1, 5.1, 9.4
"""
from __future__ import annotations
import asyncio
@@ -8,8 +16,9 @@ import logging
import asyncpg
import redis.asyncio as aioredis
from services.aggregation.worker import aggregate_company
from services.shared.config import load_config
from services.aggregation.signal_propagation import propagate_signals
from services.aggregation.worker import aggregate_company, fetch_competitive_enabled
from services.shared.config import CompetitiveConfig, load_config
from services.shared.logging import inject_trace_context, setup_logging
from services.shared.redis_keys import (
QUEUE_AGGREGATION,
@@ -20,6 +29,92 @@ from services.shared.redis_keys import (
logger = logging.getLogger("aggregation_main")
# ---------------------------------------------------------------------------
# Query to fetch recent document intelligence records for a ticker.
# Used to trigger signal propagation after aggregation completes.
# ---------------------------------------------------------------------------
_RECENT_INTELLIGENCE_QUERY = """
SELECT
di.document_id,
dir.catalyst_type,
dir.impact_score
FROM document_impact_records dir
JOIN document_intelligence di ON di.id = dir.intelligence_id
JOIN documents d ON d.id = di.document_id
WHERE dir.ticker = $1
AND di.validation_status = 'valid'
AND d.status != 'rejected'
ORDER BY d.published_at DESC
LIMIT 50
"""
# Track consecutive propagation failures for alerting (Requirement 9.4)
_propagation_consecutive_failures = 0
async def _trigger_signal_propagation(
pool: asyncpg.Pool,
ticker: str,
competitive_config: CompetitiveConfig,
) -> int:
"""Trigger competitive signal propagation for a ticker's recent documents.
Fetches recent document intelligence records for the ticker and calls
propagate_signals for each, producing competitive signals for the
ticker's competitors.
Returns the total number of competitive signals produced.
"""
global _propagation_consecutive_failures
rows = await pool.fetch(_RECENT_INTELLIGENCE_QUERY, ticker)
if not rows:
return 0
total_signals = 0
for row in rows:
document_id = str(row["document_id"])
catalyst_type = row["catalyst_type"] or "other"
impact_score = float(row["impact_score"] or 0.0)
if impact_score <= 0.0:
continue
try:
records = await propagate_signals(
pool=pool,
ticker=ticker,
catalyst_type=catalyst_type,
impact_score=impact_score,
document_id=document_id,
config=competitive_config,
)
total_signals += len(records)
# Reset failure counter on success
_propagation_consecutive_failures = 0
except Exception:
_propagation_consecutive_failures += 1
logger.exception(
"Signal propagation failed for %s doc %s/%s",
ticker, document_id, catalyst_type,
)
if _propagation_consecutive_failures >= competitive_config.propagation_failure_threshold:
logger.critical(
"ALERT: Sustained signal propagation failures (%d consecutive). "
"Continuing with company-specific + macro signals only. "
"Operator action required.",
_propagation_consecutive_failures,
)
# Stop trying propagation for this ticker after threshold
break
return total_signals
async def main() -> None:
config = load_config()
setup_logging("aggregation", level=config.log_level, json_output=config.json_logs)
@@ -28,6 +123,7 @@ async def main() -> None:
redis_client = aioredis.from_url(config.redis.url)
queue = queue_key(QUEUE_AGGREGATION)
rec_queue = queue_key(QUEUE_RECOMMENDATION)
competitive_config = config.competitive
logger.info("Aggregation worker started, polling %s", queue)
try:
@@ -49,6 +145,32 @@ async def main() -> None:
ticker, len(summaries),
)
# Trigger competitive signal propagation after aggregation
# (Requirement 4.1): When new document intelligence is
# produced for a company, propagate signals to competitors.
# Check toggle state from DB (same pattern as macro toggle).
competitive_enabled = competitive_config.competitive_enabled
db_toggle = await fetch_competitive_enabled(pool)
if db_toggle is not None:
competitive_enabled = db_toggle
if competitive_enabled:
try:
sig_count = await _trigger_signal_propagation(
pool, ticker, competitive_config,
)
if sig_count > 0:
logger.info(
"Propagated %d competitive signals for %s",
sig_count, ticker,
)
except Exception:
logger.exception(
"Signal propagation failed for %s"
"continuing with company+macro signals only",
ticker,
)
# Enqueue recommendation job for each window that produced a trend
for summary in summaries:
if summary.trend_strength > 0:
+414
View File
@@ -0,0 +1,414 @@
"""Historical pattern mining for competitive intelligence.
Queries document_impact_records joined with trend_windows to find how
similar catalyst types resolved historically for a company or its
competitors. Produces HistoricalPattern objects consumed by the signal
propagation engine and the aggregation worker.
Requirements: 3.1, 3.2, 3.3, 3.4, 3.5, 11.1, 11.2, 11.3, 11.5
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional
import asyncpg
from services.shared.config import CompetitiveConfig
from services.shared.schemas import MAJOR_DECISION_CATALYSTS
logger = logging.getLogger(__name__)
DEFAULT_HORIZONS = ["1d", "7d", "30d"]
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
@dataclass
class HistoricalPattern:
"""Statistical summary of how a catalyst type resolved historically."""
source_ticker: str
target_ticker: str
catalyst_type: str
time_horizon: str # 1d | 7d | 30d
sample_count: int
bullish_pct: float # [0, 1]
bearish_pct: float # [0, 1]
avg_strength: float # [0, 1]
avg_time_to_resolution: float # days
pattern_confidence: float # [0, 1]
data_start: datetime
data_end: datetime
tier: str # major_corporate_decision | routine_signal
insufficient_data: bool
# ---------------------------------------------------------------------------
# Catalyst tier classification (Req 11.1)
# ---------------------------------------------------------------------------
def classify_catalyst_tier(catalyst_type: str) -> str:
"""Deterministic mapping of catalyst_type to tier.
Returns ``"major_corporate_decision"`` for catalyst types in
MAJOR_DECISION_CATALYSTS, otherwise ``"routine_signal"``.
"""
if catalyst_type in MAJOR_DECISION_CATALYSTS:
return "major_corporate_decision"
return "routine_signal"
# ---------------------------------------------------------------------------
# Pattern confidence (Req 3.3, 11.2)
# ---------------------------------------------------------------------------
def compute_pattern_confidence(
sample_count: int,
outcome_consistency: float,
data_recency_days: float,
tier: str,
config: Optional[CompetitiveConfig] = None,
) -> float:
"""Compute pattern confidence score in [0, 1].
Formula:
sample_factor * 0.4 + consistency * 0.4 + recency_factor * 0.2
With a 1.3× multiplier for ``major_corporate_decision`` tier,
insufficient-data cap, and staleness decay.
"""
cfg = config or CompetitiveConfig()
# --- component factors ---
sample_factor = min(sample_count / 20.0, 1.0)
consistency = outcome_consistency # already max(bullish_pct, bearish_pct)
if data_recency_days <= cfg.staleness_recent_days:
recency_factor = 1.0
elif data_recency_days <= cfg.staleness_window_days:
recency_factor = 0.7
else:
recency_factor = 0.4
confidence = sample_factor * 0.4 + consistency * 0.4 + recency_factor * 0.2
# Major-decision multiplier (Req 11.2)
if tier == "major_corporate_decision":
confidence *= cfg.major_decision_weight_multiplier
# Clamp to [0, 1]
confidence = min(max(confidence, 0.0), 1.0)
# Insufficient data cap (Req 3.4)
if sample_count < cfg.min_pattern_samples:
confidence = min(confidence, 0.25)
# Staleness decay (Req 9.2)
if data_recency_days > cfg.staleness_window_days:
confidence *= cfg.staleness_decay_penalty
return confidence
# ---------------------------------------------------------------------------
# Lookback helper
# ---------------------------------------------------------------------------
def _lookback_days(tier: str, config: Optional[CompetitiveConfig] = None) -> int:
"""Return the lookback window in days for the given tier."""
cfg = config or CompetitiveConfig()
if tier == "major_corporate_decision":
return cfg.major_decision_lookback_days
return cfg.routine_lookback_days
# ---------------------------------------------------------------------------
# SQL: self-company pattern query
# ---------------------------------------------------------------------------
_SELF_PATTERN_QUERY = """
WITH matched_docs AS (
SELECT
dir.id AS dir_id,
d.published_at,
dir.sentiment
FROM document_impact_records dir
JOIN document_intelligence di ON di.id = dir.intelligence_id
JOIN documents d ON d.id = di.document_id
WHERE dir.ticker = $1
AND dir.catalyst_type = $2
AND di.validation_status = 'valid'
AND d.status != 'rejected'
AND d.published_at >= $3
AND d.published_at <= $4
)
SELECT
md.dir_id,
md.published_at,
md.sentiment,
tw.trend_direction,
tw.trend_strength,
tw.generated_at,
tw."window" AS tw_window
FROM matched_docs md
JOIN trend_windows tw
ON tw.entity_type = 'company'
AND tw.entity_id = $1
AND tw."window" = $5
AND tw.generated_at >= md.published_at
AND tw.generated_at <= md.published_at + $6::interval
ORDER BY md.published_at DESC
"""
# ---------------------------------------------------------------------------
# SQL: cross-company pattern query
# ---------------------------------------------------------------------------
_CROSS_PATTERN_QUERY = """
WITH matched_docs AS (
SELECT
dir.id AS dir_id,
d.published_at,
dir.sentiment
FROM document_impact_records dir
JOIN document_intelligence di ON di.id = dir.intelligence_id
JOIN documents d ON d.id = di.document_id
WHERE dir.ticker = $1
AND dir.catalyst_type = $2
AND di.validation_status = 'valid'
AND d.status != 'rejected'
AND d.published_at >= $3
AND d.published_at <= $4
)
SELECT
md.dir_id,
md.published_at,
md.sentiment,
tw.trend_direction,
tw.trend_strength,
tw.generated_at,
tw."window" AS tw_window
FROM matched_docs md
JOIN trend_windows tw
ON tw.entity_type = 'company'
AND tw.entity_id = $5
AND tw."window" = $6
AND tw.generated_at >= md.published_at
AND tw.generated_at <= md.published_at + $7::interval
ORDER BY md.published_at DESC
"""
# ---------------------------------------------------------------------------
# Horizon → interval mapping
# ---------------------------------------------------------------------------
_HORIZON_INTERVALS: dict[str, str] = {
"1d": "1 day",
"7d": "7 days",
"30d": "30 days",
}
# ---------------------------------------------------------------------------
# Build HistoricalPattern from query rows
# ---------------------------------------------------------------------------
def _build_pattern(
rows: list[asyncpg.Record],
source_ticker: str,
target_ticker: str,
catalyst_type: str,
horizon: str,
tier: str,
config: Optional[CompetitiveConfig] = None,
) -> Optional[HistoricalPattern]:
"""Aggregate query rows into a single HistoricalPattern."""
if not rows:
return None
# De-duplicate by dir_id — keep the first (closest) trend_window per doc
seen: set[str] = set()
unique_rows: list[asyncpg.Record] = []
for r in rows:
rid = str(r["dir_id"])
if rid not in seen:
seen.add(rid)
unique_rows.append(r)
sample_count = len(unique_rows)
bullish = sum(1 for r in unique_rows if r["trend_direction"] == "bullish")
bearish = sum(1 for r in unique_rows if r["trend_direction"] == "bearish")
bullish_pct = bullish / sample_count
bearish_pct = bearish / sample_count
strengths = [float(r["trend_strength"]) for r in unique_rows if r["trend_strength"] is not None]
avg_strength = sum(strengths) / len(strengths) if strengths else 0.0
# avg_time_to_resolution: average days between published_at and generated_at
resolutions: list[float] = []
for r in unique_rows:
pub = r["published_at"]
gen = r["generated_at"]
if pub and gen:
delta = (gen - pub).total_seconds() / 86400.0
resolutions.append(max(delta, 0.0))
avg_time_to_resolution = sum(resolutions) / len(resolutions) if resolutions else 0.0
# Date range
published_dates = [r["published_at"] for r in unique_rows if r["published_at"] is not None]
data_start = min(published_dates)
data_end = max(published_dates)
# Recency: days since the most recent data point
now = datetime.now(timezone.utc)
data_recency_days = (now - data_end).total_seconds() / 86400.0 if data_end else 999.0
outcome_consistency = max(bullish_pct, bearish_pct)
confidence = compute_pattern_confidence(
sample_count, outcome_consistency, data_recency_days, tier, config,
)
insufficient_data = sample_count < (config or CompetitiveConfig()).min_pattern_samples
return HistoricalPattern(
source_ticker=source_ticker,
target_ticker=target_ticker,
catalyst_type=catalyst_type,
time_horizon=horizon,
sample_count=sample_count,
bullish_pct=bullish_pct,
bearish_pct=bearish_pct,
avg_strength=min(max(avg_strength, 0.0), 1.0),
avg_time_to_resolution=avg_time_to_resolution,
pattern_confidence=confidence,
data_start=data_start,
data_end=data_end,
tier=tier,
insufficient_data=insufficient_data,
)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def find_self_patterns(
pool: asyncpg.Pool,
ticker: str,
catalyst_type: str,
horizons: Optional[list[str]] = None,
config: Optional[CompetitiveConfig] = None,
) -> list[HistoricalPattern]:
"""Find historical patterns for the same company-catalyst pair.
Queries document_impact_records joined with trend_windows for the
given ticker and catalyst_type across configurable time horizons.
Requirements: 3.1, 3.2, 3.5, 11.3
"""
cfg = config or CompetitiveConfig()
horizons = horizons or DEFAULT_HORIZONS
tier = classify_catalyst_tier(catalyst_type)
lookback = _lookback_days(tier, cfg)
now = datetime.now(timezone.utc)
cutoff = now - timedelta(days=lookback)
patterns: list[HistoricalPattern] = []
async with pool.acquire() as conn:
for horizon in horizons:
interval = _HORIZON_INTERVALS.get(horizon)
if interval is None:
logger.warning("Unknown horizon %s, skipping", horizon)
continue
try:
rows = await conn.fetch(
_SELF_PATTERN_QUERY,
ticker, # $1
catalyst_type, # $2
cutoff, # $3
now, # $4
horizon, # $5
interval, # $6
)
except Exception:
logger.exception(
"Error querying self-patterns for %s/%s/%s",
ticker, catalyst_type, horizon,
)
continue
pattern = _build_pattern(
rows, ticker, ticker, catalyst_type, horizon, tier, cfg,
)
if pattern is not None:
patterns.append(pattern)
return patterns
async def find_cross_company_patterns(
pool: asyncpg.Pool,
source_ticker: str,
target_ticker: str,
catalyst_type: str,
horizons: Optional[list[str]] = None,
config: Optional[CompetitiveConfig] = None,
) -> list[HistoricalPattern]:
"""Find cross-company historical patterns.
Queries documents about *source_ticker* with the given catalyst_type,
then looks at trend_windows for *target_ticker* within each horizon
after the document was published.
Requirements: 4.2, 11.5
"""
cfg = config or CompetitiveConfig()
horizons = horizons or DEFAULT_HORIZONS
tier = classify_catalyst_tier(catalyst_type)
lookback = _lookback_days(tier, cfg)
now = datetime.now(timezone.utc)
cutoff = now - timedelta(days=lookback)
patterns: list[HistoricalPattern] = []
async with pool.acquire() as conn:
for horizon in horizons:
interval = _HORIZON_INTERVALS.get(horizon)
if interval is None:
logger.warning("Unknown horizon %s, skipping", horizon)
continue
try:
rows = await conn.fetch(
_CROSS_PATTERN_QUERY,
source_ticker, # $1
catalyst_type, # $2
cutoff, # $3
now, # $4
target_ticker, # $5
horizon, # $6
interval, # $7
)
except Exception:
logger.exception(
"Error querying cross-patterns for %s%s/%s/%s",
source_ticker, target_ticker, catalyst_type, horizon,
)
continue
pattern = _build_pattern(
rows, source_ticker, target_ticker, catalyst_type,
horizon, tier, cfg,
)
if pattern is not None:
patterns.append(pattern)
return patterns
+416
View File
@@ -0,0 +1,416 @@
"""Trend projection module — forward-looking trend estimates.
Computes TrendProjection objects by combining current trend momentum,
macro signal decay trajectories, and upcoming catalyst outlook.
Projections are persisted alongside trend_window records.
Requirements: 12.1, 12.2, 12.3, 12.4, 12.5, 12.9
"""
from __future__ import annotations
import json
import logging
import math
from dataclasses import dataclass, field
from datetime import datetime, timezone
import asyncpg
from services.shared.schemas import TrendDirection, TrendSummary
logger = logging.getLogger("projection")
# ---------------------------------------------------------------------------
# TrendProjection dataclass
# ---------------------------------------------------------------------------
VALID_DIRECTIONS = {"bullish", "bearish", "mixed", "neutral"}
VALID_HORIZONS = {"1d", "7d", "30d"}
# Default low-confidence threshold
DEFAULT_CONFIDENCE_THRESHOLD = 0.3
# Macro signal decay half-lives (in days) by estimated_duration
DECAY_HALF_LIFE_DAYS: dict[str, float] = {
"short_term": 1.0, # halve impact per day
"medium_term": 7.0, # halve impact per week
"long_term": 30.0, # halve impact per month
}
@dataclass
class TrendProjection:
"""Forward-looking trend projection for a company."""
projected_direction: str = "neutral" # bullish|bearish|mixed|neutral
projected_strength: float = 0.5 # [0, 1]
projected_confidence: float = 0.5 # [0, 1]
projection_horizon: str = "7d" # 1d|7d|30d
driving_factors: list[str] = field(default_factory=list)
macro_contribution_pct: float = 0.0 # [0, 1]
diverges_from_current: bool = False
computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
low_confidence: bool = False
# ---------------------------------------------------------------------------
# Macro impact row type (lightweight, avoids circular import with worker)
# ---------------------------------------------------------------------------
@dataclass
class MacroEventInfo:
"""Minimal macro event info needed for projection computation."""
event_id: str = ""
macro_impact_score: float = 0.0
impact_direction: str = "neutral"
confidence: float = 0.5
estimated_duration: str = "short_term"
severity: str = "low"
event_age_hours: float = 0.0 # hours since event publication
# ---------------------------------------------------------------------------
# Projection horizon mapping from trend window
# ---------------------------------------------------------------------------
_WINDOW_TO_HORIZON: dict[str, str] = {
"intraday": "1d",
"1d": "1d",
"7d": "7d",
"30d": "30d",
"90d": "30d",
}
# ---------------------------------------------------------------------------
# Momentum computation
# ---------------------------------------------------------------------------
def compute_trend_momentum(
current_strength: float,
current_direction: str,
previous_strength: float | None = None,
previous_direction: str | None = None,
) -> float:
"""Compute trend momentum as rate of change in signed strength.
Returns a value in [-1, 1] representing the momentum:
- Positive = strengthening bullish or weakening bearish
- Negative = strengthening bearish or weakening bullish
- Zero = no change or no previous data
When no previous data is available, uses a simple heuristic based
on current strength and direction.
"""
dir_sign = _direction_sign(current_direction)
if previous_strength is None or previous_direction is None:
# Heuristic: assume momentum proportional to current signed strength
return round(dir_sign * current_strength * 0.5, 6)
prev_sign = _direction_sign(previous_direction)
current_signed = dir_sign * current_strength
previous_signed = prev_sign * previous_strength
momentum = current_signed - previous_signed
return round(max(-1.0, min(1.0, momentum)), 6)
def _direction_sign(direction: str) -> float:
"""Map direction to a sign multiplier."""
if direction == "bullish":
return 1.0
elif direction == "bearish":
return -1.0
return 0.0
# ---------------------------------------------------------------------------
# Macro signal decay projection
# ---------------------------------------------------------------------------
_SEVERITY_WEIGHT: dict[str, float] = {
"critical": 1.0,
"high": 0.75,
"moderate": 0.5,
"low": 0.25,
}
def project_macro_decay(
events: list[MacroEventInfo],
horizon_days: float,
) -> tuple[float, str]:
"""Project the aggregate macro signal after decay over the horizon.
For each active macro event, compute the projected remaining impact
using exponential decay based on estimated_duration:
- short_term: half-life = 1 day
- medium_term: half-life = 7 days
- long_term: half-life = 30 days
Returns:
(projected_macro_strength, projected_macro_direction)
where strength is in [0, 1] and direction is bullish|bearish|mixed|neutral.
"""
if not events:
return 0.0, "neutral"
positive_weight = 0.0
negative_weight = 0.0
for ev in events:
half_life = DECAY_HALF_LIFE_DAYS.get(ev.estimated_duration, 7.0)
# Current age in days
current_age_days = ev.event_age_hours / 24.0
# Projected age at end of horizon
future_age_days = current_age_days + horizon_days
# Decay factor: ratio of future impact to current impact
if half_life > 0:
current_factor = math.pow(2.0, -current_age_days / half_life)
future_factor = math.pow(2.0, -future_age_days / half_life)
else:
current_factor = 0.0
future_factor = 0.0
severity_w = _SEVERITY_WEIGHT.get(ev.severity, 0.25)
projected_impact = ev.macro_impact_score * future_factor * severity_w
if ev.impact_direction == "positive":
positive_weight += projected_impact
elif ev.impact_direction == "negative":
negative_weight += projected_impact
else:
# mixed/neutral: split evenly
positive_weight += projected_impact * 0.5
negative_weight += projected_impact * 0.5
total = positive_weight + negative_weight
if total == 0.0:
return 0.0, "neutral"
strength = min(total, 1.0)
if positive_weight > negative_weight * 1.2:
direction = "bullish"
elif negative_weight > positive_weight * 1.2:
direction = "bearish"
elif positive_weight > 0 and negative_weight > 0:
direction = "mixed"
else:
direction = "neutral"
return round(strength, 6), direction
# ---------------------------------------------------------------------------
# Horizon days mapping
# ---------------------------------------------------------------------------
_HORIZON_DAYS: dict[str, float] = {
"1d": 1.0,
"7d": 7.0,
"30d": 30.0,
}
# ---------------------------------------------------------------------------
# Core projection computation
# ---------------------------------------------------------------------------
def compute_projection(
summary: TrendSummary,
macro_events: list[MacroEventInfo] | None = None,
macro_enabled: bool = True,
confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
previous_strength: float | None = None,
previous_direction: str | None = None,
upcoming_catalysts: list[str] | None = None,
) -> TrendProjection:
"""Compute a forward-looking trend projection.
Combines:
1. Trend momentum (rate of change in strength)
2. Macro signal decay projection
3. Upcoming catalyst outlook
4. Current trend baseline
Args:
summary: The current trend summary.
macro_events: Active macro events with their info.
macro_enabled: Whether the macro layer is enabled.
confidence_threshold: Below this, mark as low_confidence.
previous_strength: Previous window's trend strength (optional).
previous_direction: Previous window's trend direction (optional).
upcoming_catalysts: Known upcoming catalysts from doc intelligence.
Returns:
A TrendProjection with projected direction, strength, and confidence.
"""
now = datetime.now(timezone.utc)
current_dir = summary.trend_direction.value
current_strength = summary.trend_strength
current_confidence = summary.confidence
horizon = _WINDOW_TO_HORIZON.get(summary.window.value, "7d")
horizon_days = _HORIZON_DAYS.get(horizon, 7.0)
driving_factors: list[str] = []
# 1. Compute trend momentum
momentum = compute_trend_momentum(
current_strength, current_dir,
previous_strength, previous_direction,
)
if abs(momentum) > 0.05:
if momentum > 0:
driving_factors.append(f"Positive momentum ({momentum:+.3f}) in recent trend strength")
else:
driving_factors.append(f"Negative momentum ({momentum:+.3f}) in recent trend strength")
# 2. Project macro signal decay
macro_strength = 0.0
macro_direction = "neutral"
macro_contribution = 0.0
if macro_enabled and macro_events:
macro_strength, macro_direction = project_macro_decay(macro_events, horizon_days)
if macro_strength > 0:
driving_factors.append(
f"Macro signals project {macro_direction} impact "
f"(strength {macro_strength:.3f}) over {horizon}"
)
# 3. Factor in upcoming catalysts
catalysts = upcoming_catalysts or []
for catalyst in catalysts[:3]: # limit to top 3
driving_factors.append(f"Upcoming catalyst: {catalyst}")
catalyst_boost = min(len(catalysts) * 0.02, 0.1) # small boost per catalyst
# 4. Combine into projected direction/strength/confidence
# Momentum-based projection of company-specific trend
momentum_projected_signed = _direction_sign(current_dir) * current_strength + momentum * 0.5
momentum_projected_strength = min(abs(momentum_projected_signed), 1.0)
if macro_enabled and macro_strength > 0:
# Blend company momentum with macro trajectory
macro_weight = min(macro_strength * 0.4, 0.4)
company_weight = 1.0 - macro_weight
macro_signed = _direction_sign(macro_direction) * macro_strength
blended_signed = (
company_weight * momentum_projected_signed
+ macro_weight * macro_signed
)
projected_strength = round(min(abs(blended_signed) + catalyst_boost, 1.0), 6)
macro_contribution = round(macro_weight, 6)
# Determine projected direction from blended signal
projected_direction = _signed_to_direction(blended_signed)
else:
# Company-only projection
projected_strength = round(min(momentum_projected_strength + catalyst_boost, 1.0), 6)
projected_direction = _signed_to_direction(momentum_projected_signed)
# Compute projected confidence
base_confidence = current_confidence * 0.8 # projection inherently less certain
if macro_enabled and macro_strength > 0:
# Macro data adds information → slight confidence boost
macro_conf_boost = min(macro_strength * 0.15, 0.1)
projected_confidence = round(min(base_confidence + macro_conf_boost, 1.0), 6)
else:
# Without macro data, reduce confidence further
if not macro_enabled:
projected_confidence = round(base_confidence * 0.85, 6)
else:
projected_confidence = round(base_confidence, 6)
# Ensure driving_factors is never empty
if not driving_factors:
driving_factors.append(f"Baseline trend continuation: {current_dir} at strength {current_strength:.3f}")
# 5. Flag divergence
diverges = projected_direction != current_dir
if diverges:
driving_factors.append(
f"DIVERGENCE: Current trend is {current_dir}, "
f"projection is {projected_direction}"
)
# Mark low confidence
is_low_confidence = projected_confidence < confidence_threshold
return TrendProjection(
projected_direction=projected_direction,
projected_strength=projected_strength,
projected_confidence=projected_confidence,
projection_horizon=horizon,
driving_factors=driving_factors,
macro_contribution_pct=macro_contribution,
diverges_from_current=diverges,
computed_at=now,
low_confidence=is_low_confidence,
)
def _signed_to_direction(signed_value: float) -> str:
"""Convert a signed strength value to a direction string."""
if signed_value > 0.1:
return "bullish"
elif signed_value < -0.1:
return "bearish"
elif abs(signed_value) > 0.02:
return "mixed"
return "neutral"
# ---------------------------------------------------------------------------
# PostgreSQL persistence
# ---------------------------------------------------------------------------
_INSERT_PROJECTION = """
INSERT INTO trend_projections (
trend_window_id, projected_direction, projected_strength,
projected_confidence, projection_horizon, driving_factors,
macro_contribution_pct, diverges_from_current, computed_at
) VALUES (
$1::uuid, $2, $3, $4, $5, $6::jsonb, $7, $8, $9
)
RETURNING id
"""
async def persist_trend_projection(
pool: asyncpg.Pool,
trend_window_id: str,
projection: TrendProjection,
) -> str:
"""Persist a TrendProjection to the trend_projections table.
Returns the row UUID.
"""
row_id = await pool.fetchval(
_INSERT_PROJECTION,
trend_window_id,
projection.projected_direction,
projection.projected_strength,
projection.projected_confidence,
projection.projection_horizon,
json.dumps(projection.driving_factors),
projection.macro_contribution_pct,
projection.diverges_from_current,
projection.computed_at,
)
logger.info(
"Persisted trend projection for window=%s: direction=%s strength=%.3f confidence=%.3f diverges=%s",
trend_window_id,
projection.projected_direction,
projection.projected_strength,
projection.projected_confidence,
projection.diverges_from_current,
)
return str(row_id)
+226 -13
View File
@@ -4,13 +4,13 @@ Aggregates company-level trend summaries into sector and market-level
summaries, enabling top-down views of sentiment and risk across the
portfolio.
Requirements: 6.3, 6.4, 6.5
Requirements: 6.1, 6.2, 6.3, 6.4, 6.5
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
import asyncpg
@@ -42,6 +42,126 @@ class CompanyTrendRow:
top_opposing_evidence: list[str]
@dataclass
class SectorMacroImpact:
"""Aggregated macro impact data for a single sector.
Used to incorporate macro signals into sector and market rollups.
Requirements: 6.1, 6.2, 6.3
"""
sector: str
total_impact: float # sum of macro_impact_score across companies in sector
avg_impact: float # average macro_impact_score
company_count: int # number of companies affected
net_direction: float # weighted direction: +1 positive, -1 negative, 0 mixed
event_ids: list[str] = field(default_factory=list) # contributing event IDs
# Threshold for disproportionate sector impact (Requirement 6.3)
SECTOR_CONCENTRATION_THRESHOLD = 0.60
# ---------------------------------------------------------------------------
# Fetch sector-level macro impact aggregates
# ---------------------------------------------------------------------------
_SECTOR_MACRO_IMPACT_QUERY = """
SELECT
c.sector,
mir.event_id,
mir.macro_impact_score,
mir.impact_direction
FROM macro_impact_records mir
JOIN companies c ON c.id = mir.company_id AND c.active = TRUE
WHERE mir.computed_at >= $1
AND mir.computed_at <= $2
ORDER BY c.sector, mir.macro_impact_score DESC
"""
async def fetch_sector_macro_impacts(
pool: asyncpg.Pool,
window_start: datetime,
window_end: datetime,
) -> dict[str, SectorMacroImpact]:
"""Fetch macro impact records aggregated by sector for a time range.
Returns a mapping of sector name to SectorMacroImpact.
"""
rows = await pool.fetch(_SECTOR_MACRO_IMPACT_QUERY, window_start, window_end)
# Accumulate per-sector
sector_data: dict[str, dict] = {}
direction_map = {"positive": 1.0, "negative": -1.0, "mixed": 0.0, "neutral": 0.0}
for row in rows:
sector = str(row["sector"]) if row["sector"] else "Unknown"
score = float(row["macro_impact_score"] or 0.0)
direction = row["impact_direction"] or "neutral"
event_id = str(row["event_id"])
if sector not in sector_data:
sector_data[sector] = {
"total": 0.0,
"count": 0,
"dir_sum": 0.0,
"dir_count": 0,
"event_ids": set(),
}
d = sector_data[sector]
d["total"] += score
d["count"] += 1
dir_val = direction_map.get(direction, 0.0)
if dir_val != 0.0:
d["dir_sum"] += dir_val
d["dir_count"] += 1
d["event_ids"].add(event_id)
result: dict[str, SectorMacroImpact] = {}
for sector, d in sector_data.items():
count = d["count"]
avg = d["total"] / count if count > 0 else 0.0
net_dir = d["dir_sum"] / d["dir_count"] if d["dir_count"] > 0 else 0.0
result[sector] = SectorMacroImpact(
sector=sector,
total_impact=d["total"],
avg_impact=avg,
company_count=count,
net_direction=net_dir,
event_ids=sorted(d["event_ids"]),
)
return result
# ---------------------------------------------------------------------------
# Sector macro concentration helper (Requirement 6.3)
# ---------------------------------------------------------------------------
def compute_sector_macro_concentration(
sector_impacts: dict[str, SectorMacroImpact],
) -> list[tuple[str, float]]:
"""Compute the fraction of total macro impact concentrated in each sector.
Returns a list of (sector, fraction) tuples sorted by fraction descending.
Sectors with fraction > SECTOR_CONCENTRATION_THRESHOLD are considered
disproportionately affected.
"""
total = sum(si.total_impact for si in sector_impacts.values())
if total <= 0.0:
return []
fractions = [
(sector, si.total_impact / total)
for sector, si in sector_impacts.items()
]
fractions.sort(key=lambda x: x[1], reverse=True)
return fractions
# ---------------------------------------------------------------------------
# Fetch latest company trends for a given window
# ---------------------------------------------------------------------------
@@ -141,11 +261,22 @@ def rollup_trends(
entity_id: str,
window: str,
reference_time: datetime,
macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> TrendSummary:
"""Aggregate a list of company-level trends into a single rollup summary.
Each company trend is weighted by its confidence to produce a
confidence-weighted average of direction, strength, and contradiction.
When macro_impacts is provided:
- For sector rollups: incorporates the sector's macro signal into
strength and confidence, weighted by constituent company exposure.
- For market rollups: aggregates macro signals across all sectors and
surfaces disproportionately affected sectors (>60% concentration)
in material_risks or dominant_catalysts.
When macro_impacts is None or empty, produces identical output to
the original company-only rollup.
"""
if not trends:
return TrendSummary(
@@ -204,16 +335,70 @@ def rollup_trends(
avg_contradiction = weighted_contradiction / total_weight
avg_confidence = total_weight / len(trends)
# --- Incorporate macro impact signals when available ---
macro_strength_adj = 0.0
macro_confidence_adj = 0.0
macro_catalysts: list[str] = []
macro_risks: list[str] = []
if macro_impacts:
if entity_type == "sector":
# Sector rollup: incorporate this sector's macro signal
sector_macro = macro_impacts.get(entity_id)
if sector_macro and sector_macro.total_impact > 0:
# Weight macro contribution by avg impact and company breadth
breadth = min(sector_macro.company_count / max(len(trends), 1), 1.0)
macro_strength_adj = sector_macro.avg_impact * breadth * 0.3
macro_confidence_adj = sector_macro.avg_impact * breadth * 0.1
# Nudge direction based on macro net direction
avg_direction += sector_macro.net_direction * macro_strength_adj * 0.5
elif entity_type == "market":
# Market rollup: aggregate macro signals across all sectors
total_macro = sum(si.total_impact for si in macro_impacts.values())
if total_macro > 0:
total_companies = sum(si.company_count for si in macro_impacts.values())
breadth = min(total_companies / max(len(trends), 1), 1.0)
avg_macro = total_macro / max(len(macro_impacts), 1)
macro_strength_adj = avg_macro * breadth * 0.3
macro_confidence_adj = avg_macro * breadth * 0.1
# Aggregate net direction across sectors
dir_sum = sum(
si.net_direction * si.total_impact
for si in macro_impacts.values()
)
net_dir = dir_sum / total_macro if total_macro > 0 else 0.0
avg_direction += net_dir * macro_strength_adj * 0.5
# Surface disproportionately affected sectors (Requirement 6.3)
concentration = compute_sector_macro_concentration(macro_impacts)
for sector, fraction in concentration:
if fraction > SECTOR_CONCENTRATION_THRESHOLD:
si = macro_impacts[sector]
label = f"Macro: {sector} ({fraction:.0%} of macro impact)"
if si.net_direction < 0:
macro_risks.append(label)
else:
macro_catalysts.append(label)
# Apply macro adjustments to strength and confidence
adj_strength = avg_strength + macro_strength_adj
adj_confidence = avg_confidence + macro_confidence_adj
# Derive direction
direction = _derive_rollup_direction(avg_direction, avg_contradiction)
# Top catalysts
# Top catalysts (macro catalysts prepended when present)
sorted_catalysts = sorted(catalyst_weights.items(), key=lambda x: x[1], reverse=True)
catalysts = [c for c, _ in sorted_catalysts[:5]]
catalysts = macro_catalysts + [c for c, _ in sorted_catalysts[:5]]
catalysts = catalysts[:5]
# Top risks (deduplicated, by weight)
# Top risks (macro risks prepended when present, deduplicated)
sorted_risks = sorted(risk_set.items(), key=lambda x: x[1], reverse=True)
risks = [r for r, _ in sorted_risks[:5]]
base_risks = [r for r, _ in sorted_risks[:5]]
risks = macro_risks + base_risks
risks = risks[:5]
# Disagreement details
disagreement = _build_rollup_disagreement(trends, entity_id)
@@ -223,8 +408,8 @@ def rollup_trends(
entity_id=entity_id,
window=TrendWindow(window),
trend_direction=direction,
trend_strength=round(min(abs(avg_strength), 1.0), 4),
confidence=round(max(0.0, min(avg_confidence, 1.0)), 4),
trend_strength=round(min(abs(adj_strength), 1.0), 4),
confidence=round(max(0.0, min(adj_confidence, 1.0)), 4),
top_supporting_evidence=list(dict.fromkeys(all_supporting))[:10],
top_opposing_evidence=list(dict.fromkeys(all_opposing))[:10],
dominant_catalysts=catalysts,
@@ -341,11 +526,14 @@ async def aggregate_sector(
window: str,
reference_time: datetime | None = None,
since: datetime | None = None,
macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> TrendSummary:
"""Compute and persist a sector-level rollup for one window.
Fetches the latest company trends, filters to the given sector,
and rolls them up into a single sector summary.
and rolls them up into a single sector summary. When macro_impacts
is provided, incorporates macro signals weighted by constituent
company exposure.
"""
if reference_time is None:
reference_time = datetime.now(timezone.utc)
@@ -355,7 +543,14 @@ async def aggregate_sector(
all_trends = await fetch_latest_company_trends(pool, window, since)
sector_trends = [t for t in all_trends if t.sector == sector]
summary = rollup_trends(sector_trends, "sector", sector, window, reference_time)
# Fetch macro impacts if not provided
if macro_impacts is None:
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
summary = rollup_trends(
sector_trends, "sector", sector, window, reference_time,
macro_impacts=macro_impacts,
)
if sector_trends:
rollup_id = await persist_rollup(pool, summary)
@@ -373,10 +568,13 @@ async def aggregate_market(
window: str,
reference_time: datetime | None = None,
since: datetime | None = None,
macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> TrendSummary:
"""Compute and persist a market-wide rollup for one window.
Aggregates all company trends regardless of sector.
Aggregates all company trends regardless of sector. When macro_impacts
is provided, aggregates macro signals across all sectors and surfaces
disproportionately affected sectors in material_risks or dominant_catalysts.
"""
if reference_time is None:
reference_time = datetime.now(timezone.utc)
@@ -385,7 +583,14 @@ async def aggregate_market(
all_trends = await fetch_latest_company_trends(pool, window, since)
summary = rollup_trends(all_trends, "market", "all", window, reference_time)
# Fetch macro impacts if not provided
if macro_impacts is None:
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
summary = rollup_trends(
all_trends, "market", "all", window, reference_time,
macro_impacts=macro_impacts,
)
if all_trends:
rollup_id = await persist_rollup(pool, summary)
@@ -403,6 +608,7 @@ async def aggregate_all_sectors(
window: str,
reference_time: datetime | None = None,
since: datetime | None = None,
macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> list[TrendSummary]:
"""Compute sector rollups for every sector that has company trends."""
if reference_time is None:
@@ -412,6 +618,10 @@ async def aggregate_all_sectors(
all_trends = await fetch_latest_company_trends(pool, window, since)
# Fetch macro impacts once for all sectors if not provided
if macro_impacts is None:
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
# Group by sector
sectors: dict[str, list[CompanyTrendRow]] = {}
for t in all_trends:
@@ -419,7 +629,10 @@ async def aggregate_all_sectors(
summaries: list[TrendSummary] = []
for sector, trends in sectors.items():
summary = rollup_trends(trends, "sector", sector, window, reference_time)
summary = rollup_trends(
trends, "sector", sector, window, reference_time,
macro_impacts=macro_impacts,
)
if trends:
_id = await persist_rollup(pool, summary)
summaries.append(summary)
+306
View File
@@ -0,0 +1,306 @@
"""Competitive signal propagation engine.
Evaluates incoming document intelligence, identifies competitors via
the competitor_relationships table, queries historical cross-company
patterns, and produces weighted competitive signals persisted to
competitive_signal_records.
Also converts pattern and competitive signals into WeightedSignal
objects for the aggregation engine.
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 9.1
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Optional
import asyncpg
from services.aggregation.pattern_matcher import (
HistoricalPattern,
find_cross_company_patterns,
)
from services.aggregation.scoring import (
ScoringConfig,
WeightedSignal,
compute_signal_weight,
)
from services.shared.config import CompetitiveConfig
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Data classes
# ---------------------------------------------------------------------------
@dataclass
class CompetitiveSignalRecord:
"""A competitive signal produced by propagating a source event to a
competitor based on historical cross-company patterns."""
source_document_id: str
source_ticker: str
target_ticker: str
catalyst_type: str
pattern_confidence: float
signal_direction: str # bullish | bearish
signal_strength: float # [0, 1]
relationship_strength: float
computed_at: datetime
# ---------------------------------------------------------------------------
# SQL queries
# ---------------------------------------------------------------------------
_COMPETITOR_LOOKUP_QUERY = """
SELECT company_a_id, company_b_id, strength
FROM competitor_relationships
WHERE (company_a_id = $1 OR company_b_id = $1)
AND active = TRUE
"""
_INSERT_SIGNAL_QUERY = """
INSERT INTO competitive_signal_records
(source_document_id, source_ticker, target_ticker, catalyst_type,
pattern_confidence, signal_direction, signal_strength,
relationship_strength, computed_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
"""
# ---------------------------------------------------------------------------
# propagate_signals
# ---------------------------------------------------------------------------
async def propagate_signals(
pool: asyncpg.Pool,
ticker: str,
catalyst_type: str,
impact_score: float,
document_id: str,
config: Optional[CompetitiveConfig] = None,
) -> list[CompetitiveSignalRecord]:
"""Look up competitors, query cross-company patterns, produce weighted
competitive signals, and persist them.
Args:
pool: asyncpg connection pool.
ticker: Source company ticker that received the catalyst.
catalyst_type: The catalyst type from document intelligence.
impact_score: The source document's impact score.
document_id: The source document ID.
config: Optional competitive config overrides.
Returns:
List of CompetitiveSignalRecord objects produced and persisted.
"""
cfg = config or CompetitiveConfig()
now = datetime.now(timezone.utc)
records: list[CompetitiveSignalRecord] = []
# Step 1: Look up active competitors
try:
async with pool.acquire() as conn:
rows = await conn.fetch(_COMPETITOR_LOOKUP_QUERY, ticker)
except Exception:
logger.exception("Failed to look up competitors for %s", ticker)
return records
if not rows:
logger.debug("No active competitors found for %s", ticker)
return records
# Step 2: For each competitor, query cross-company patterns
for row in rows:
company_a = str(row["company_a_id"])
company_b = str(row["company_b_id"])
rel_strength = float(row["strength"])
# Determine the competitor ticker (the other side of the relationship)
competitor_ticker = company_b if company_a == ticker else company_a
# Threshold gating (Req 4.5)
if rel_strength < cfg.propagation_strength_threshold:
logger.info(
"Skipping propagation %s%s: relationship strength %.3f "
"below threshold %.3f",
ticker, competitor_ticker, rel_strength,
cfg.propagation_strength_threshold,
)
continue
# Query cross-company patterns
try:
patterns = await find_cross_company_patterns(
pool, ticker, competitor_ticker, catalyst_type, config=cfg,
)
except Exception:
logger.exception(
"Failed to query cross-company patterns for %s%s/%s",
ticker, competitor_ticker, catalyst_type,
)
continue
for pattern in patterns:
# Confidence threshold gating (Req 9.1)
if pattern.pattern_confidence < cfg.pattern_confidence_threshold:
logger.info(
"Excluding pattern %s%s/%s/%s: confidence %.3f "
"below threshold %.3f",
ticker, competitor_ticker, catalyst_type,
pattern.time_horizon, pattern.pattern_confidence,
cfg.pattern_confidence_threshold,
)
continue
# Compute signal strength (Req 4.3)
raw_strength = (
pattern.avg_strength
* rel_strength
* pattern.pattern_confidence
* impact_score
)
signal_strength = min(max(raw_strength, 0.0), 1.0)
# Determine direction
direction = (
"bullish" if pattern.bullish_pct > pattern.bearish_pct
else "bearish"
)
record = CompetitiveSignalRecord(
source_document_id=document_id,
source_ticker=ticker,
target_ticker=competitor_ticker,
catalyst_type=catalyst_type,
pattern_confidence=pattern.pattern_confidence,
signal_direction=direction,
signal_strength=signal_strength,
relationship_strength=rel_strength,
computed_at=now,
)
records.append(record)
# Step 3: Persist all records
if records:
try:
async with pool.acquire() as conn:
await conn.executemany(
_INSERT_SIGNAL_QUERY,
[
(
r.source_document_id,
r.source_ticker,
r.target_ticker,
r.catalyst_type,
r.pattern_confidence,
r.signal_direction,
r.signal_strength,
r.relationship_strength,
r.computed_at,
)
for r in records
],
)
except Exception:
logger.exception(
"Failed to persist %d competitive signal records", len(records),
)
return records
# ---------------------------------------------------------------------------
# build_pattern_weighted_signals
# ---------------------------------------------------------------------------
def build_pattern_weighted_signals(
patterns: list[HistoricalPattern],
competitive_signals: list[CompetitiveSignalRecord],
reference_time: datetime,
window: str,
config: Optional[CompetitiveConfig] = None,
) -> list[WeightedSignal]:
"""Convert pattern and competitive signal objects to WeightedSignal
objects for the aggregation engine.
For HistoricalPattern objects:
- sentiment_value = +1.0 if bullish_pct > bearish_pct else -1.0
- impact_score = avg_strength * competitive_signal_weight
- published_at = data_end (most recent data point for recency decay)
- extraction_confidence = pattern_confidence
For CompetitiveSignalRecord objects:
- sentiment_value = +1.0 if direction == "bullish" else -1.0
- impact_score = signal_strength * competitive_signal_weight
- published_at = computed_at (for recency decay)
- extraction_confidence = pattern_confidence
Args:
patterns: Self-company historical patterns.
competitive_signals: Competitive signal records from propagation.
reference_time: Aggregation anchor time for recency decay.
window: Trend window identifier (e.g. "7d").
config: Optional competitive config overrides.
Returns:
List of WeightedSignal objects ready for aggregation.
"""
cfg = config or CompetitiveConfig()
scoring_cfg = ScoringConfig()
signals: list[WeightedSignal] = []
# Convert HistoricalPattern objects
for pattern in patterns:
sentiment_value = (
1.0 if pattern.bullish_pct > pattern.bearish_pct else -1.0
)
impact = pattern.avg_strength * cfg.competitive_signal_weight
weight = compute_signal_weight(
published_at=pattern.data_end,
reference_time=reference_time,
window=window,
source_credibility=1.0, # patterns are derived from validated data
novelty_score=0.5,
extraction_confidence=pattern.pattern_confidence,
market_ctx=None,
config=scoring_cfg,
)
signals.append(WeightedSignal(
document_id=f"pattern:{pattern.source_ticker}:{pattern.catalyst_type}:{pattern.time_horizon}",
weight=weight,
sentiment_value=sentiment_value,
impact_score=impact,
))
# Convert CompetitiveSignalRecord objects
for sig in competitive_signals:
sentiment_value = 1.0 if sig.signal_direction == "bullish" else -1.0
impact = sig.signal_strength * cfg.competitive_signal_weight
weight = compute_signal_weight(
published_at=sig.computed_at,
reference_time=reference_time,
window=window,
source_credibility=1.0,
novelty_score=0.5,
extraction_confidence=sig.pattern_confidence,
market_ctx=None,
config=scoring_cfg,
)
signals.append(WeightedSignal(
document_id=sig.source_document_id,
weight=weight,
sentiment_value=sentiment_value,
impact_score=impact,
))
return signals
+410 -5
View File
@@ -40,6 +40,17 @@ from services.shared.metrics import (
AGGREGATION_SIGNALS_PROCESSED,
AGGREGATION_WINDOWS_COMPUTED,
)
from services.aggregation.pattern_matcher import find_self_patterns
from services.aggregation.projection import (
MacroEventInfo,
TrendProjection,
compute_projection,
persist_trend_projection,
)
from services.aggregation.signal_propagation import (
CompetitiveSignalRecord,
build_pattern_weighted_signals,
)
from services.shared.schemas import TrendDirection, TrendSummary, TrendWindow
logger = logging.getLogger(__name__)
@@ -64,6 +75,10 @@ class AggregationConfig:
windows: list[str] | None = None # None = all windows
scoring: ScoringConfig | None = None
max_evidence: int = MAX_EVIDENCE_REFS
macro_signal_weight: float = 0.3 # relative weight of macro vs company signals
macro_enabled: bool = True # runtime toggle state
competitive_signal_weight: float = 0.2 # relative weight of pattern signals
competitive_enabled: bool = True # runtime toggle state
def effective_windows(self) -> list[str]:
if self.windows:
@@ -154,6 +169,236 @@ async def fetch_impact_records(
# ---------------------------------------------------------------------------
# Fetch macro toggle state from risk_configs
#
# MACRO LAYER TOGGLE BEHAVIOR (Requirements 11.2, 11.3, 11.4):
# - The toggle state is read fresh from PostgreSQL at the start of each
# aggregation cycle (no caching), so changes take effect immediately on
# the next cycle.
# - When disabled: ingestion and classification continue normally (historical
# data is preserved), but interpolation and aggregation integration are
# skipped — the aggregation engine produces trends using only company-
# specific signals.
# - When re-enabled: the engine resumes computing macro impact scores using
# the most recent GlobalEvent classifications, including any events that
# were ingested and classified while the layer was disabled.
# ---------------------------------------------------------------------------
_MACRO_TOGGLE_QUERY = """
SELECT config->>'macro_enabled' AS macro_enabled
FROM risk_configs
WHERE active = TRUE
ORDER BY updated_at DESC
LIMIT 1
"""
async def fetch_macro_enabled(pool: asyncpg.Pool) -> bool | None:
"""Check macro toggle state from risk_configs table.
Returns True/False if explicitly set, or None if no config exists
(caller should fall back to AggregationConfig default).
"""
row = await pool.fetchrow(_MACRO_TOGGLE_QUERY)
if row is None or row["macro_enabled"] is None:
return None
return row["macro_enabled"].lower() == "true"
# ---------------------------------------------------------------------------
# Fetch competitive toggle state from risk_configs
# ---------------------------------------------------------------------------
_COMPETITIVE_TOGGLE_QUERY = """
SELECT config->>'competitive_enabled' AS competitive_enabled
FROM risk_configs
WHERE active = TRUE
ORDER BY updated_at DESC
LIMIT 1
"""
async def fetch_competitive_enabled(pool: asyncpg.Pool) -> bool | None:
"""Check competitive toggle state from risk_configs table.
Returns True/False if explicitly set, or None if no config exists
(caller should fall back to AggregationConfig default).
"""
row = await pool.fetchrow(_COMPETITIVE_TOGGLE_QUERY)
if row is None or row["competitive_enabled"] is None:
return None
return row["competitive_enabled"].lower() == "true"
# ---------------------------------------------------------------------------
# Fetch competitive signals targeting a ticker within a time window
# ---------------------------------------------------------------------------
_COMPETITIVE_SIGNALS_QUERY = """
SELECT source_document_id, source_ticker, target_ticker, catalyst_type,
pattern_confidence, signal_direction, signal_strength,
relationship_strength, computed_at
FROM competitive_signal_records
WHERE target_ticker = $1
AND computed_at >= $2
AND computed_at <= $3
ORDER BY computed_at DESC
"""
async def fetch_competitive_signals(
pool: asyncpg.Pool,
ticker: str,
window_start: datetime,
window_end: datetime,
) -> list[CompetitiveSignalRecord]:
"""Fetch competitive signal records targeting a ticker in a time range."""
rows = await pool.fetch(
_COMPETITIVE_SIGNALS_QUERY, ticker, window_start, window_end,
)
return [
CompetitiveSignalRecord(
source_document_id=str(row["source_document_id"]),
source_ticker=row["source_ticker"],
target_ticker=row["target_ticker"],
catalyst_type=row["catalyst_type"],
pattern_confidence=float(row["pattern_confidence"]),
signal_direction=row["signal_direction"],
signal_strength=float(row["signal_strength"]),
relationship_strength=float(row["relationship_strength"]),
computed_at=row["computed_at"],
)
for row in rows
]
# ---------------------------------------------------------------------------
# Fetch macro impact records for a ticker within a time window
# ---------------------------------------------------------------------------
_MACRO_IMPACT_QUERY = """
SELECT
mir.event_id,
mir.company_id,
mir.ticker,
mir.macro_impact_score,
mir.impact_direction,
mir.contributing_factors,
mir.confidence,
mir.computed_at,
ge.source_document_id,
d.published_at AS event_published_at
FROM macro_impact_records mir
JOIN global_events ge ON ge.id = mir.event_id
JOIN documents d ON d.id = ge.source_document_id
WHERE mir.ticker = $1
AND mir.computed_at >= $2
AND mir.computed_at <= $3
ORDER BY mir.computed_at DESC
"""
@dataclass
class MacroImpactRow:
"""Parsed row from the macro impact query."""
event_id: str
company_id: str
ticker: str
macro_impact_score: float
impact_direction: str
contributing_factors: list[str]
confidence: float
computed_at: datetime
source_document_id: str
event_published_at: datetime
def _parse_macro_impact_row(row: Any) -> MacroImpactRow:
"""Convert an asyncpg Record to a MacroImpactRow."""
factors = row["contributing_factors"]
if isinstance(factors, str):
factors = json.loads(factors)
return MacroImpactRow(
event_id=str(row["event_id"]),
company_id=str(row["company_id"]),
ticker=row["ticker"],
macro_impact_score=float(row["macro_impact_score"] or 0.0),
impact_direction=row["impact_direction"] or "neutral",
contributing_factors=factors if isinstance(factors, list) else [],
confidence=float(row["confidence"] or 0.5),
computed_at=row["computed_at"],
source_document_id=str(row["source_document_id"]),
event_published_at=row["event_published_at"],
)
async def fetch_macro_impact_records(
pool: asyncpg.Pool,
ticker: str,
window_start: datetime,
window_end: datetime,
) -> list[MacroImpactRow]:
"""Fetch macro impact records for a ticker in a time range."""
rows = await pool.fetch(_MACRO_IMPACT_QUERY, ticker, window_start, window_end)
return [_parse_macro_impact_row(r) for r in rows]
# ---------------------------------------------------------------------------
# Convert macro impact records to WeightedSignals
# ---------------------------------------------------------------------------
_DIRECTION_TO_SENTIMENT: dict[str, float] = {
"positive": 1.0,
"negative": -1.0,
"mixed": 0.0,
"neutral": 0.0,
}
def build_macro_weighted_signals(
macro_impacts: list[MacroImpactRow],
reference_time: datetime,
window: str,
macro_signal_weight: float = 0.3,
config: ScoringConfig | None = None,
) -> list[WeightedSignal]:
"""Convert macro impact records into WeightedSignal objects.
Uses the same scoring pipeline as company signals:
- document_id = source_document_id (for evidence tracing)
- sentiment_value mapped from impact_direction
- impact_score = macro_impact_score * macro_signal_weight
- recency decay from the global event's publication time
- confidence gating from the macro record's confidence
"""
cfg = config or ScoringConfig()
signals: list[WeightedSignal] = []
for mir in macro_impacts:
sw = compute_signal_weight(
published_at=mir.event_published_at,
reference_time=reference_time,
window=window,
source_credibility=mir.confidence,
novelty_score=0.5,
extraction_confidence=mir.confidence,
config=cfg,
)
sentiment = _DIRECTION_TO_SENTIMENT.get(mir.impact_direction, 0.0)
impact = mir.macro_impact_score * macro_signal_weight
signals.append(
WeightedSignal(
document_id=mir.source_document_id,
weight=sw,
sentiment_value=sentiment,
impact_score=impact,
)
)
return signals
# ---------------------------------------------------------------------------
# Build weighted signals from impact records
# ---------------------------------------------------------------------------
@@ -544,6 +789,61 @@ async def persist_trend_evidence(
return len(rows)
# ---------------------------------------------------------------------------
# Build MacroEventInfo objects for projection computation
# ---------------------------------------------------------------------------
_MACRO_EVENT_INFO_QUERY = """
SELECT
mir.event_id,
mir.macro_impact_score,
mir.impact_direction,
mir.confidence,
ge.estimated_duration,
ge.severity,
d.published_at AS event_published_at
FROM macro_impact_records mir
JOIN global_events ge ON ge.id = mir.event_id
JOIN documents d ON d.id = ge.source_document_id
WHERE mir.ticker = $1
AND mir.computed_at >= $2
AND mir.computed_at <= $3
ORDER BY mir.computed_at DESC
"""
async def _build_macro_event_infos(
pool: asyncpg.Pool,
ticker: str,
window_start: datetime,
reference_time: datetime,
) -> list[MacroEventInfo]:
"""Fetch macro impact records and build MacroEventInfo objects for projection."""
rows = await pool.fetch(
_MACRO_EVENT_INFO_QUERY, ticker, window_start, reference_time,
)
infos: list[MacroEventInfo] = []
for row in rows:
published_at = row["event_published_at"]
age_hours = 0.0
if published_at:
age_hours = max(
(reference_time - published_at).total_seconds() / 3600.0, 0.0,
)
infos.append(
MacroEventInfo(
event_id=str(row["event_id"]),
macro_impact_score=float(row["macro_impact_score"] or 0.0),
impact_direction=row["impact_direction"] or "neutral",
confidence=float(row["confidence"] or 0.5),
estimated_duration=row["estimated_duration"] or "short_term",
severity=row["severity"] or "low",
event_age_hours=age_hours,
)
)
return infos
# ---------------------------------------------------------------------------
# Main aggregation entry point for a single ticker + window
# ---------------------------------------------------------------------------
@@ -563,8 +863,10 @@ async def aggregate_company_window(
2. Fetch document impact records from PostgreSQL.
3. Fetch market context for the ticker.
4. Build weighted signals using the scoring module.
5. Assemble the TrendSummary.
6. Persist to trend_windows table.
5. Check macro toggle and fetch/merge macro signals if enabled.
6. Check competitive toggle and fetch/merge pattern/competitive signals if enabled.
7. Assemble the TrendSummary.
8. Persist to trend_windows table.
Returns the assembled TrendSummary.
"""
@@ -589,7 +891,83 @@ async def aggregate_company_window(
impacts, reference_time, window, market_ctx, scoring_cfg,
)
# 4. Assemble trend summary with evidence details
# 4. Check macro toggle and merge macro signals
# (Requirement 11.2, 11.3, 11.4): Toggle state is read from the DB on
# every aggregation cycle. When disabled, macro signals are skipped but
# ingestion/classification continue independently — so when re-enabled,
# the most recent classifications (including those ingested while disabled)
# are immediately available for impact computation.
macro_enabled = cfg.macro_enabled
db_toggle = await fetch_macro_enabled(pool)
if db_toggle is not None:
macro_enabled = db_toggle
if macro_enabled:
macro_impacts = await fetch_macro_impact_records(
pool, ticker, window_start, reference_time,
)
if macro_impacts:
macro_signals = build_macro_weighted_signals(
macro_impacts,
reference_time,
window,
macro_signal_weight=cfg.macro_signal_weight,
config=scoring_cfg,
)
signals = signals + macro_signals
logger.info(
"Merged %d macro signals for %s/%s",
len(macro_signals), ticker, window,
)
# 5. Check competitive toggle and merge pattern/competitive signals
# (Requirements 5.1-5.6): Same toggle pattern as macro layer. When
# disabled, pattern mining remains queryable but aggregation skips
# competitive signals — no degradation of existing behavior.
competitive_enabled = cfg.competitive_enabled
db_competitive_toggle = await fetch_competitive_enabled(pool)
if db_competitive_toggle is not None:
competitive_enabled = db_competitive_toggle
if competitive_enabled:
try:
# Get unique catalyst types from the impact records
catalyst_types = {imp.catalyst_type for imp in impacts}
# Query self-company historical patterns for each catalyst type
all_patterns = []
for cat_type in catalyst_types:
patterns = await find_self_patterns(pool, ticker, cat_type)
all_patterns.extend(patterns)
# Fetch competitive signals targeting this ticker
comp_signals = await fetch_competitive_signals(
pool, ticker, window_start, reference_time,
)
# Convert to WeightedSignal objects
if all_patterns or comp_signals:
pattern_weighted = build_pattern_weighted_signals(
patterns=all_patterns,
competitive_signals=comp_signals,
reference_time=reference_time,
window=window,
)
signals = signals + pattern_weighted
logger.info(
"Merged %d pattern/competitive signals for %s/%s "
"(patterns=%d, competitive=%d)",
len(pattern_weighted), ticker, window,
len(all_patterns), len(comp_signals),
)
except Exception:
logger.exception(
"Failed to fetch pattern/competitive signals for %s/%s"
"continuing with company+macro signals only",
ticker, window,
)
# 6. Assemble trend summary with evidence details
assembled = assemble_trend_with_evidence(
ticker=ticker,
window=window,
@@ -601,10 +979,10 @@ async def aggregate_company_window(
)
summary = assembled.summary
# 5. Persist trend window
# 7. Persist trend window
trend_id = await persist_trend_summary(pool, summary)
# 6. Persist evidence mappings
# 8. Persist evidence mappings
evidence_count = await persist_trend_evidence(
pool, trend_id,
assembled.supporting_evidence,
@@ -617,6 +995,33 @@ async def aggregate_company_window(
summary.trend_strength, summary.confidence, len(signals), evidence_count,
)
# 9. Compute and persist trend projection
try:
macro_event_infos: list[MacroEventInfo] = []
if macro_enabled:
macro_event_infos = await _build_macro_event_infos(
pool, ticker, window_start, reference_time,
)
projection = compute_projection(
summary=summary,
macro_events=macro_event_infos if macro_event_infos else None,
macro_enabled=macro_enabled,
upcoming_catalysts=summary.dominant_catalysts[:3] if summary.dominant_catalysts else None,
)
await persist_trend_projection(pool, trend_id, projection)
logger.info(
"Persisted projection for %s/%s: direction=%s strength=%.3f confidence=%.3f diverges=%s",
ticker, window, projection.projected_direction,
projection.projected_strength, projection.projected_confidence,
projection.diverges_from_current,
)
except Exception:
logger.exception(
"Failed to compute/persist projection for trend %s (%s/%s) — continuing",
trend_id, ticker, window,
)
# Prometheus metrics
AGGREGATION_WINDOWS_COMPUTED.labels(window=window).inc()
AGGREGATION_SIGNALS_PROCESSED.labels(window=window).inc(len(signals))
+597 -1
View File
@@ -28,7 +28,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from services.extractor.metrics import get_model_performance_summary
from services.shared.audit import get_entity_audit_trail, get_order_audit_trail
from services.shared.audit import get_entity_audit_trail, get_order_audit_trail, record_audit_event
from services.shared.config import load_config
from services.shared.db import get_pg_pool
from services.shared.logging import new_trace_id, set_trace_context, setup_logging
@@ -376,6 +376,24 @@ async def list_trends(
):
d[jsonb_field] = _parse_jsonb(d.get(jsonb_field))
results.append(d)
# Include projection data for each trend (Requirement 12.10)
if results:
trend_ids = [r["id"] for r in rows]
proj_rows = await pool.fetch(
"""SELECT DISTINCT ON (trend_window_id)
trend_window_id, projected_direction, projected_strength,
projected_confidence, projection_horizon,
macro_contribution_pct, diverges_from_current
FROM trend_projections
WHERE trend_window_id = ANY($1::uuid[])
ORDER BY trend_window_id, computed_at DESC""",
trend_ids,
)
proj_map = {str(p["trend_window_id"]): _row_to_dict(p) for p in proj_rows}
for d in results:
d["projection"] = proj_map.get(d["id"])
return results
@@ -1687,3 +1705,581 @@ async def delete_saved_query(query_id: str):
if result == "DELETE 0":
raise HTTPException(404, "Query not found")
return {"status": "deleted"}
# ---------------------------------------------------------------------------
# Admin: Macro Signal Layer Toggle (Requirement 11.1, 11.5, 11.7)
# ---------------------------------------------------------------------------
class MacroToggleBody(BaseModel):
enabled: bool
operator: str = "operator"
@app.get("/api/admin/macro/status")
async def get_macro_status():
"""Return the current macro signal layer enabled/disabled state.
Reads from the active risk_configs row's JSONB config field.
Requirements: 11.1, 11.5
"""
row = await pool.fetchrow(
"""SELECT config->>'macro_enabled' AS macro_enabled
FROM risk_configs
WHERE active = TRUE
ORDER BY updated_at DESC
LIMIT 1""",
)
if row is None or row["macro_enabled"] is None:
return {"macro_enabled": True, "source": "default"}
return {
"macro_enabled": row["macro_enabled"].lower() == "true",
"source": "risk_configs",
}
@app.put("/api/admin/macro/toggle")
async def toggle_macro_layer(body: MacroToggleBody):
"""Toggle the macro signal layer on or off.
Persists the new state into the active risk_configs row's JSONB config
and records an audit event with previous state, new state, and operator.
The toggle state is read from PostgreSQL at the start of each aggregation
cycle (no caching), so changes take effect on the next cycle.
Requirements: 11.1, 11.5, 11.7
"""
# Read current state
current_row = await pool.fetchrow(
"""SELECT id, config->>'macro_enabled' AS macro_enabled
FROM risk_configs
WHERE active = TRUE
ORDER BY updated_at DESC
LIMIT 1""",
)
if current_row is None:
# No active config exists — create one
new_config = json.dumps({"macro_enabled": str(body.enabled).lower()})
current_row = await pool.fetchrow(
"""INSERT INTO risk_configs (name, trading_mode, config, active)
VALUES ('default', 'paper', $1::jsonb, TRUE)
RETURNING id, config->>'macro_enabled' AS macro_enabled""",
new_config,
)
previous_enabled = True # default was enabled
else:
prev_val = current_row["macro_enabled"]
previous_enabled = prev_val.lower() == "true" if prev_val else True
config_id = str(current_row["id"])
# Update the config JSONB to set macro_enabled
await pool.execute(
"""UPDATE risk_configs
SET config = config || $2::jsonb, updated_at = NOW()
WHERE id = $1""",
current_row["id"],
json.dumps({"macro_enabled": str(body.enabled).lower()}),
)
# Record audit event (Requirement 11.7)
await record_audit_event(
pool,
event_type="macro.layer_toggled",
entity_type="risk_config",
entity_id=config_id,
data={
"previous_enabled": previous_enabled,
"new_enabled": body.enabled,
},
actor=body.operator,
)
return {
"macro_enabled": body.enabled,
"previous_enabled": previous_enabled,
"toggled_by": body.operator,
}
# ---------------------------------------------------------------------------
# Macro Events and Impacts (Requirement 8.1, 8.2, 12.10)
# ---------------------------------------------------------------------------
@app.get("/api/macro/events")
async def list_macro_events(
severity: Optional[str] = None,
region: Optional[str] = None,
sector: Optional[str] = None,
since: Optional[str] = None,
until: Optional[str] = None,
limit: int = Query(default=50, le=200),
offset: int = 0,
):
"""List recent global events with filtering by severity, region, sector, date range.
Requirements: 8.1
"""
conditions: list[str] = []
params: list[Any] = []
idx = 1
if severity:
conditions.append(f"ge.severity = ${idx}")
params.append(severity)
idx += 1
if region:
conditions.append(f"${idx} = ANY(ge.affected_regions)")
params.append(region)
idx += 1
if sector:
conditions.append(f"${idx} = ANY(ge.affected_sectors)")
params.append(sector)
idx += 1
if since:
conditions.append(f"ge.created_at >= ${idx}::timestamptz")
params.append(since)
idx += 1
if until:
conditions.append(f"ge.created_at <= ${idx}::timestamptz")
params.append(until)
idx += 1
where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
rows = await pool.fetch(
f"""SELECT ge.id, ge.event_types, ge.severity, ge.affected_regions,
ge.affected_sectors, ge.affected_commodities, ge.summary,
ge.key_facts, ge.estimated_duration, ge.confidence,
ge.source_document_id, ge.created_at
FROM global_events ge
{where}
ORDER BY ge.created_at DESC
LIMIT ${idx} OFFSET ${idx + 1}""",
*params, limit, offset,
)
results = []
for r in rows:
d = _row_to_dict(r)
d["key_facts"] = _parse_jsonb(d.get("key_facts"))
results.append(d)
return results
@app.get("/api/macro/events/{event_id}")
async def get_macro_event(event_id: str):
"""Event detail with list of affected companies and their macro impact scores.
Requirements: 8.2
"""
row = await pool.fetchrow(
"""SELECT id, event_types, severity, affected_regions, affected_sectors,
affected_commodities, summary, key_facts, estimated_duration,
confidence, source_document_id, model_provider, model_name,
prompt_version, schema_version, created_at
FROM global_events WHERE id = $1""",
event_id,
)
if not row:
raise HTTPException(404, "Global event not found")
result = _row_to_dict(row)
result["key_facts"] = _parse_jsonb(result.get("key_facts"))
# Affected companies with macro impact scores
impacts = await pool.fetch(
"""SELECT mir.id, mir.company_id, mir.ticker, mir.macro_impact_score,
mir.impact_direction, mir.contributing_factors, mir.confidence,
mir.computed_at, c.legal_name, c.sector
FROM macro_impact_records mir
JOIN companies c ON c.id = mir.company_id
WHERE mir.event_id = $1
ORDER BY mir.macro_impact_score DESC""",
event_id,
)
impact_list = []
for imp in impacts:
imp_dict = _row_to_dict(imp)
imp_dict["contributing_factors"] = _parse_jsonb(imp_dict.get("contributing_factors"))
impact_list.append(imp_dict)
result["affected_companies"] = impact_list
return result
@app.get("/api/macro/impacts/{ticker}")
async def get_macro_impacts_for_ticker(
ticker: str,
since: Optional[str] = None,
limit: int = Query(default=50, le=200),
offset: int = 0,
):
"""Macro impacts for a specific company.
Requirements: 8.2
"""
conditions = ["mir.ticker = $1"]
params: list[Any] = [ticker.upper()]
idx = 2
if since:
conditions.append(f"mir.computed_at >= ${idx}::timestamptz")
params.append(since)
idx += 1
where = " AND ".join(conditions)
rows = await pool.fetch(
f"""SELECT mir.id, mir.event_id, mir.company_id, mir.ticker,
mir.macro_impact_score, mir.impact_direction,
mir.contributing_factors, mir.confidence, mir.computed_at,
ge.summary AS event_summary, ge.severity AS event_severity,
ge.event_types AS event_types, ge.affected_regions
FROM macro_impact_records mir
JOIN global_events ge ON ge.id = mir.event_id
WHERE {where}
ORDER BY mir.computed_at DESC
LIMIT ${idx} OFFSET ${idx + 1}""",
*params, limit, offset,
)
results = []
for r in rows:
d = _row_to_dict(r)
d["contributing_factors"] = _parse_jsonb(d.get("contributing_factors"))
results.append(d)
return results
# ---------------------------------------------------------------------------
# Trend Projections (Requirement 12.10)
# ---------------------------------------------------------------------------
@app.get("/api/trends/{trend_id}/projection")
async def get_trend_projection(trend_id: str):
"""Trend projection for a specific trend window.
Requirements: 12.10
"""
# Verify trend exists
trend_row = await pool.fetchrow(
"SELECT id FROM trend_windows WHERE id = $1", trend_id,
)
if not trend_row:
raise HTTPException(404, "Trend not found")
row = await pool.fetchrow(
"""SELECT id, trend_window_id, projected_direction, projected_strength,
projected_confidence, projection_horizon, driving_factors,
macro_contribution_pct, diverges_from_current, computed_at
FROM trend_projections WHERE trend_window_id = $1
ORDER BY computed_at DESC LIMIT 1""",
trend_id,
)
if not row:
return {"trend_window_id": trend_id, "projection": None}
d = _row_to_dict(row)
d["driving_factors"] = _parse_jsonb(d.get("driving_factors"))
return d
# ---------------------------------------------------------------------------
# Competitive Layer Toggle (Requirements 6.1, 6.2, 6.3, 6.4, 6.5, 6.7)
# ---------------------------------------------------------------------------
class CompetitiveToggleBody(BaseModel):
enabled: bool
operator: str = "operator"
@app.get("/api/admin/competitive/status")
async def get_competitive_status():
"""Return the current competitive signal layer enabled/disabled state.
Reads from the active risk_configs row's JSONB config field.
Requirements: 6.1, 6.5
"""
row = await pool.fetchrow(
"""SELECT config->>'competitive_enabled' AS competitive_enabled
FROM risk_configs
WHERE active = TRUE
ORDER BY updated_at DESC
LIMIT 1""",
)
if row is None or row["competitive_enabled"] is None:
return {"competitive_enabled": True, "source": "default"}
return {
"competitive_enabled": row["competitive_enabled"].lower() == "true",
"source": "risk_configs",
}
@app.put("/api/admin/competitive/toggle")
async def toggle_competitive_layer(body: CompetitiveToggleBody):
"""Toggle the competitive signal layer on or off.
Persists the new state into the active risk_configs row's JSONB config
and records an audit event with previous state, new state, and operator.
Toggle state is read from PostgreSQL at the start of each aggregation
cycle (no caching), so changes take effect on the next cycle.
When disabled, pattern mining remains queryable via API but signal
propagation is skipped during aggregation. When re-enabled, the engine
resumes computing signals using latest historical data including
intelligence ingested while disabled.
Requirements: 6.1, 6.2, 6.3, 6.4, 6.5, 6.7
"""
# Read current state
current_row = await pool.fetchrow(
"""SELECT id, config->>'competitive_enabled' AS competitive_enabled
FROM risk_configs
WHERE active = TRUE
ORDER BY updated_at DESC
LIMIT 1""",
)
if current_row is None:
# No active config exists — create one
new_config = json.dumps({"competitive_enabled": str(body.enabled).lower()})
current_row = await pool.fetchrow(
"""INSERT INTO risk_configs (name, trading_mode, config, active)
VALUES ('default', 'paper', $1::jsonb, TRUE)
RETURNING id, config->>'competitive_enabled' AS competitive_enabled""",
new_config,
)
previous_enabled = True # default was enabled
else:
prev_val = current_row["competitive_enabled"]
previous_enabled = prev_val.lower() == "true" if prev_val else True
config_id = str(current_row["id"])
# Update the config JSONB to set competitive_enabled
await pool.execute(
"""UPDATE risk_configs
SET config = config || $2::jsonb, updated_at = NOW()
WHERE id = $1""",
current_row["id"],
json.dumps({"competitive_enabled": str(body.enabled).lower()}),
)
# Record audit event (Requirement 6.7)
await record_audit_event(
pool,
event_type="competitive.layer_toggled",
entity_type="risk_config",
entity_id=config_id,
data={
"previous_enabled": previous_enabled,
"new_enabled": body.enabled,
},
actor=body.operator,
)
return {
"competitive_enabled": body.enabled,
"previous_enabled": previous_enabled,
"toggled_by": body.operator,
}
# ---------------------------------------------------------------------------
# Historical Pattern & Competitive Signal Query Endpoints
# (Requirements 10.1, 10.2, 10.3, 10.4, 11.4, 11.6)
# ---------------------------------------------------------------------------
from dataclasses import asdict
from services.aggregation.pattern_matcher import (
find_cross_company_patterns,
find_self_patterns,
)
from services.shared.schemas import MAJOR_DECISION_CATALYSTS
def _pattern_to_dict(p) -> dict[str, Any]:
"""Convert a HistoricalPattern dataclass to a JSON-safe dict."""
d = asdict(p)
for key, val in d.items():
if isinstance(val, datetime):
d[key] = val.isoformat()
return d
@app.get("/api/patterns/{ticker}")
async def get_patterns_for_ticker(
ticker: str,
catalyst_type: Optional[str] = None,
time_horizon: Optional[str] = None,
):
"""Historical patterns for a company.
Filterable by catalyst_type and time_horizon.
Returns sample_count, outcome distribution, pattern_confidence,
and date range for each pattern.
Requirements: 10.1, 10.3
"""
horizons = [time_horizon] if time_horizon else None
if catalyst_type:
patterns = await find_self_patterns(pool, ticker, catalyst_type, horizons=horizons)
else:
# Query across all catalyst types present in the company's history
rows = await pool.fetch(
"""SELECT DISTINCT di.catalyst_type
FROM document_impact_records dir
JOIN document_intelligence di ON di.document_id = dir.document_id
JOIN documents d ON d.id = dir.document_id
WHERE dir.ticker = $1
AND di.validation_status = 'valid'
AND d.status != 'rejected'
AND di.catalyst_type IS NOT NULL""",
ticker,
)
patterns = []
for row in rows:
ct = row["catalyst_type"]
patterns.extend(await find_self_patterns(pool, ticker, ct, horizons=horizons))
return {
"ticker": ticker,
"patterns": [_pattern_to_dict(p) for p in patterns],
"count": len(patterns),
}
@app.get("/api/patterns/{ticker}/competitors")
async def get_competitor_patterns(
ticker: str,
catalyst_type: Optional[str] = None,
time_horizon: Optional[str] = None,
):
"""Cross-company patterns showing how this company's catalysts affected competitors.
Requirements: 10.2, 10.3
"""
horizons = [time_horizon] if time_horizon else None
# Find active competitors for this ticker
comp_rows = await pool.fetch(
"""SELECT DISTINCT
CASE WHEN ca.ticker = $1 THEN cb.ticker ELSE ca.ticker END AS competitor_ticker
FROM competitor_relationships cr
JOIN companies ca ON ca.id = cr.company_a_id
JOIN companies cb ON cb.id = cr.company_b_id
WHERE cr.active = TRUE
AND (ca.ticker = $1 OR cb.ticker = $1)""",
ticker,
)
# Determine catalyst types to query
if catalyst_type:
catalyst_types = [catalyst_type]
else:
ct_rows = await pool.fetch(
"""SELECT DISTINCT di.catalyst_type
FROM document_impact_records dir
JOIN document_intelligence di ON di.document_id = dir.document_id
JOIN documents d ON d.id = dir.document_id
WHERE dir.ticker = $1
AND di.validation_status = 'valid'
AND d.status != 'rejected'
AND di.catalyst_type IS NOT NULL""",
ticker,
)
catalyst_types = [r["catalyst_type"] for r in ct_rows]
patterns = []
for comp_row in comp_rows:
comp_ticker = comp_row["competitor_ticker"]
for ct in catalyst_types:
cross = await find_cross_company_patterns(
pool, ticker, comp_ticker, ct, horizons=horizons,
)
patterns.extend(cross)
return {
"ticker": ticker,
"cross_company_patterns": [_pattern_to_dict(p) for p in patterns],
"count": len(patterns),
}
@app.get("/api/patterns/{ticker}/competitive-signals")
async def get_competitive_signals(ticker: str):
"""Recent competitive signals targeting this company.
Requirements: 10.4
"""
rows = await pool.fetch(
"""SELECT id, source_document_id, source_ticker, target_ticker,
catalyst_type, pattern_confidence, signal_direction,
signal_strength, relationship_strength, computed_at
FROM competitive_signal_records
WHERE target_ticker = $1
ORDER BY computed_at DESC
LIMIT 100""",
ticker,
)
return {
"ticker": ticker,
"competitive_signals": [_row_to_dict(r) for r in rows],
"count": len(rows),
}
@app.get("/api/patterns/{ticker}/decisions")
async def get_decision_history(
ticker: str,
time_horizon: Optional[str] = None,
):
"""Major corporate decision history with trend outcomes and pattern statistics.
Queries document_impact_records filtered by MAJOR_DECISION_CATALYSTS,
joined with trend_windows for outcome data.
Requirements: 11.4, 11.6
"""
major_types = list(MAJOR_DECISION_CATALYSTS)
horizons = [time_horizon] if time_horizon else None
# Fetch major decision records for this ticker
rows = await pool.fetch(
"""SELECT dir.id, dir.document_id, dir.ticker,
di.catalyst_type, di.summary,
dir.impact_score, dir.created_at,
d.published_at
FROM document_impact_records dir
JOIN document_intelligence di ON di.document_id = dir.document_id
JOIN documents d ON d.id = dir.document_id
WHERE dir.ticker = $1
AND di.validation_status = 'valid'
AND d.status != 'rejected'
AND di.catalyst_type = ANY($2)
ORDER BY dir.created_at DESC
LIMIT 50""",
ticker,
major_types,
)
decisions = []
for row in rows:
decision = _row_to_dict(row)
# Fetch pattern statistics for this catalyst type
ct = row["catalyst_type"]
patterns = await find_self_patterns(pool, ticker, ct, horizons=horizons)
decision["pattern_statistics"] = [_pattern_to_dict(p) for p in patterns]
decisions.append(decision)
return {
"ticker": ticker,
"decisions": decisions,
"count": len(decisions),
}
+549
View File
@@ -0,0 +1,549 @@
"""Event classifier module for macro news articles.
Classifies global/geopolitical news articles into structured GlobalEvent
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
existing OllamaClient for inference and retry logic.
Persists classification prompts, raw outputs, and final events to MinIO
and PostgreSQL for audit and downstream interpolation.
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
import asyncpg
from minio import Minio
from services.shared.schemas import (
EstimatedDuration,
ImpactType,
ModelMetadata,
SeverityLevel,
)
from services.shared.storage import upload_artifact
logger = logging.getLogger("event_classifier")
PROMPT_VERSION = "event-classification-v1"
SCHEMA_VERSION = "1.0.0"
# Valid enum value sets for normalization
_VALID_IMPACT_TYPES = frozenset(e.value for e in ImpactType)
_VALID_SEVERITY_LEVELS = frozenset(e.value for e in SeverityLevel)
_VALID_DURATIONS = frozenset(e.value for e in EstimatedDuration)
# ---------------------------------------------------------------------------
# GlobalEvent dataclass
# ---------------------------------------------------------------------------
@dataclass
class GlobalEvent:
"""Structured classification of a macro news event.
Produced by the event classifier from Ollama structured output.
"""
event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
event_types: list[str] = field(default_factory=list)
severity: str = "low"
affected_regions: list[str] = field(default_factory=list)
affected_sectors: list[str] = field(default_factory=list)
affected_commodities: list[str] = field(default_factory=list)
summary: str = ""
key_facts: list[str] = field(default_factory=list)
estimated_duration: str = "short_term"
confidence: float = 0.5
source_document_id: str = ""
model_metadata: ModelMetadata = field(default_factory=ModelMetadata)
# ---------------------------------------------------------------------------
# JSON schema for Ollama structured output
# ---------------------------------------------------------------------------
class _EventClassificationResult:
"""Schema definition for the Ollama event classification response.
Not a Pydantic model — we build the JSON schema dict directly to keep
it self-contained and Ollama-friendly (no $refs).
"""
pass
def get_event_json_schema() -> dict[str, Any]:
"""Return the JSON schema for Ollama structured event classification output.
The schema forces the model to produce all required fields explicitly.
"""
return {
"type": "object",
"required": [
"event_types",
"severity",
"affected_regions",
"affected_sectors",
"affected_commodities",
"summary",
"key_facts",
"estimated_duration",
"confidence",
],
"properties": {
"event_types": {
"type": "array",
"items": {
"type": "string",
"enum": sorted(_VALID_IMPACT_TYPES),
},
"description": (
"One or more impact types this event represents. "
"Include ALL applicable types — do not collapse to a single category."
),
},
"severity": {
"type": "string",
"enum": sorted(_VALID_SEVERITY_LEVELS),
"description": "Overall severity of the event: low, moderate, high, or critical.",
},
"affected_regions": {
"type": "array",
"items": {"type": "string"},
"description": (
"ISO 3166-1 alpha-2 country codes or region names affected. "
"Use standard codes like US, CN, EU, GB, JP. "
"Only include regions explicitly mentioned or clearly implied."
),
},
"affected_sectors": {
"type": "array",
"items": {"type": "string"},
"description": (
"GICS sector identifiers or sector names affected. "
"Examples: Energy, Materials, Industrials, Consumer Discretionary, "
"Consumer Staples, Health Care, Financials, Information Technology, "
"Communication Services, Utilities, Real Estate."
),
},
"affected_commodities": {
"type": "array",
"items": {"type": "string"},
"description": (
"Commodity identifiers affected, if applicable. "
"Examples: crude_oil, natural_gas, gold, copper, wheat, lithium, "
"semiconductors. Empty list if no commodities are directly affected."
),
},
"summary": {
"type": "string",
"description": "A concise 1-3 sentence summary of the event and its market implications.",
},
"key_facts": {
"type": "array",
"items": {"type": "string"},
"description": (
"Key facts explicitly stated in the article. "
"Do NOT infer, speculate, or fabricate facts. "
"Each fact must be directly supported by the text."
),
},
"estimated_duration": {
"type": "string",
"enum": sorted(_VALID_DURATIONS),
"description": (
"Expected duration of market impact: "
"short_term (days to weeks), medium_term (weeks to months), "
"long_term (months to years)."
),
},
"confidence": {
"type": "number",
"minimum": 0.0,
"maximum": 1.0,
"description": (
"Your confidence in this classification. "
"Lower if the article is ambiguous, speculative, or lacks concrete details."
),
},
},
"additionalProperties": False,
}
# ---------------------------------------------------------------------------
# Prompt builder
# ---------------------------------------------------------------------------
_SYSTEM_PROMPT = """\
You classify global news articles into structured macro event intelligence. \
Read the article carefully and extract the event classification. \
Return ONLY valid JSON matching the schema. No commentary, no markdown, no explanation."""
_ANTI_HALLUCINATION_RULES = """\
CRITICAL RULES — read carefully:
1. Only extract information EXPLICITLY stated in the article text.
2. Do NOT infer, speculate, or fabricate facts, regions, sectors, or commodities.
3. If the article mentions multiple distinct impact types, include ALL of them in event_types.
4. For affected_regions, only include regions explicitly mentioned or clearly implied by the event.
5. For affected_sectors, only include sectors with a clear causal link to the event.
6. For affected_commodities, only include commodities directly referenced or obviously impacted.
7. For key_facts, each fact must be directly supported by a specific passage in the text.
8. If the article is vague or speculative, set confidence LOW (below 0.4).
9. Do NOT treat journalist speculation or opinion as confirmed fact.
10. Distinguish between announced policy and proposed/rumored policy."""
def build_event_classification_prompt(text: str) -> dict[str, str]:
"""Build system and user prompts for Ollama event classification.
Args:
text: Normalized text content of the macro news article.
Returns:
Dict with 'system' and 'user' prompt strings.
"""
user_prompt = f"""\
Classify this global news article as a macro event. Fill every field.
{_ANTI_HALLUCINATION_RULES}
Classify the event by:
- event_types: ALL applicable impact types (supply_disruption, demand_shift, cost_increase, \
regulatory_pressure, currency_impact, commodity_shock, trade_barrier, geopolitical_risk)
- severity: low, moderate, high, or critical
- affected_regions: ISO country codes or region names
- affected_sectors: GICS sector names
- affected_commodities: commodity identifiers (empty list if none)
- summary: 1-3 sentence summary of the event and market implications
- key_facts: facts explicitly stated in the text (NO fabrication)
- estimated_duration: short_term, medium_term, or long_term
- confidence: 0.0-1.0 your confidence in this classification
--- ARTICLE TEXT ---
{text}
--- END ARTICLE TEXT ---"""
return {
"system": _SYSTEM_PROMPT,
"user": user_prompt,
}
# ---------------------------------------------------------------------------
# Classification response parsing and normalization
# ---------------------------------------------------------------------------
def _normalize_event_types(raw: list[Any]) -> list[str]:
"""Normalize and filter event_types to valid ImpactType values."""
result = []
for item in raw:
val = str(item).lower().strip()
if val in _VALID_IMPACT_TYPES:
result.append(val)
return result if result else ["geopolitical_risk"]
def _normalize_severity(raw: str) -> str:
"""Normalize severity to a valid SeverityLevel value."""
val = str(raw).lower().strip()
return val if val in _VALID_SEVERITY_LEVELS else "low"
def _normalize_duration(raw: str) -> str:
"""Normalize estimated_duration to a valid EstimatedDuration value."""
val = str(raw).lower().strip()
return val if val in _VALID_DURATIONS else "short_term"
def _parse_classification_response(
raw_json: str,
document_id: str,
model_name: str,
) -> GlobalEvent:
"""Parse raw Ollama JSON output into a GlobalEvent.
Normalizes enum values and clamps numeric fields.
"""
data = json.loads(raw_json)
confidence = data.get("confidence", 0.5)
if isinstance(confidence, (int, float)):
confidence = max(0.0, min(1.0, float(confidence)))
else:
confidence = 0.5
return GlobalEvent(
event_id=str(uuid.uuid4()),
event_types=_normalize_event_types(data.get("event_types", [])),
severity=_normalize_severity(data.get("severity", "low")),
affected_regions=[str(r) for r in data.get("affected_regions", [])],
affected_sectors=[str(s) for s in data.get("affected_sectors", [])],
affected_commodities=[str(c) for c in data.get("affected_commodities", [])],
summary=str(data.get("summary", "")),
key_facts=[str(f) for f in data.get("key_facts", [])],
estimated_duration=_normalize_duration(data.get("estimated_duration", "short_term")),
confidence=confidence,
source_document_id=document_id,
model_metadata=ModelMetadata(
provider="ollama",
model_name=model_name,
prompt_version=PROMPT_VERSION,
schema_version=SCHEMA_VERSION,
),
)
# ---------------------------------------------------------------------------
# MinIO persistence helpers
# ---------------------------------------------------------------------------
def _upload_classification_prompt(
minio_client: Minio,
document_id: str,
prompt_data: dict[str, str],
model_name: str,
timestamp: datetime | None = None,
) -> str:
"""Upload classification prompt and metadata to stonks-llm-prompts."""
ts = timestamp or datetime.now(timezone.utc)
payload = json.dumps({
"prompt_version": PROMPT_VERSION,
"schema_version": SCHEMA_VERSION,
"model": model_name,
"system_prompt": prompt_data["system"],
"user_prompt": prompt_data["user"],
"json_schema": get_event_json_schema(),
}, indent=2).encode()
path = (
f"event_classification/macro/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
f"{document_id}/prompt.json"
)
return upload_artifact(
minio_client, "stonks-llm-prompts", path, payload,
content_type="application/json",
)
def _upload_classification_result(
minio_client: Minio,
document_id: str,
raw_output: str,
event: GlobalEvent | None,
success: bool,
error: str | None,
timestamp: datetime | None = None,
) -> str:
"""Upload raw classification output to stonks-llm-results."""
ts = timestamp or datetime.now(timezone.utc)
payload = json.dumps({
"document_id": document_id,
"success": success,
"error": error,
"raw_output": raw_output,
"parsed_event": {
"event_id": event.event_id,
"event_types": event.event_types,
"severity": event.severity,
"affected_regions": event.affected_regions,
"affected_sectors": event.affected_sectors,
"affected_commodities": event.affected_commodities,
"summary": event.summary,
"key_facts": event.key_facts,
"estimated_duration": event.estimated_duration,
"confidence": event.confidence,
} if event else None,
}, indent=2).encode()
path = (
f"event_classification/macro/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
f"{document_id}/result.json"
)
return upload_artifact(
minio_client, "stonks-llm-results", path, payload,
content_type="application/json",
)
# ---------------------------------------------------------------------------
# PostgreSQL persistence
# ---------------------------------------------------------------------------
async def persist_global_event(
pool: asyncpg.Pool,
event: GlobalEvent,
) -> str:
"""Persist a GlobalEvent record to the global_events PostgreSQL table.
Returns the event row UUID.
"""
row_id = await pool.fetchval(
"""INSERT INTO global_events
(id, event_types, severity, affected_regions, affected_sectors,
affected_commodities, summary, key_facts, estimated_duration,
confidence, source_document_id, model_provider, model_name,
prompt_version, schema_version)
VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10,
$11::uuid, $12, $13, $14, $15)
RETURNING id""",
event.event_id,
event.event_types,
event.severity,
event.affected_regions,
event.affected_sectors,
event.affected_commodities,
event.summary,
json.dumps(event.key_facts),
event.estimated_duration,
event.confidence,
event.source_document_id,
event.model_metadata.provider,
event.model_metadata.model_name,
event.model_metadata.prompt_version,
event.model_metadata.schema_version,
)
logger.info(
"Persisted global event %s for doc %s (severity=%s, types=%s)",
row_id, event.source_document_id, event.severity, event.event_types,
)
return str(row_id)
# ---------------------------------------------------------------------------
# Main classification function
# ---------------------------------------------------------------------------
async def classify_global_event(
normalized_text: str,
document_id: str,
ollama_client: Any,
*,
pool: asyncpg.Pool | None = None,
minio_client: Minio | None = None,
) -> GlobalEvent:
"""Classify a macro news article into a GlobalEvent using Ollama.
Uses the existing OllamaClient's streaming infrastructure with a
dedicated event classification prompt and JSON schema. Follows the
same retry policy as document extraction.
Persists prompt, raw output, and final event to MinIO and PostgreSQL
when the respective clients are provided.
Args:
normalized_text: Cleaned text content of the macro article.
document_id: UUID of the source document.
ollama_client: An OllamaClient instance (from services.extractor.client).
pool: Optional asyncpg pool for PostgreSQL persistence.
minio_client: Optional MinIO client for artifact persistence.
Returns:
A GlobalEvent with the classification result.
Raises:
ValueError: If classification fails after all retries.
"""
ts = datetime.now(timezone.utc)
prompts = build_event_classification_prompt(normalized_text)
json_schema = get_event_json_schema()
model_name = ollama_client._config.model
# Persist prompt to MinIO
prompt_ref = None
if minio_client:
try:
prompt_ref = _upload_classification_prompt(
minio_client, document_id, prompts, model_name, timestamp=ts,
)
except Exception:
logger.exception("Failed to upload classification prompt for doc %s", document_id)
# Call Ollama using the client's internal _call_ollama method
# We reuse the retry logic pattern from OllamaClient.extract()
max_retries = ollama_client._max_retries
last_error: str | None = None
raw_output = ""
for attempt_num in range(max_retries + 1):
attempt = await ollama_client._call_ollama(prompts, json_schema)
raw_output = attempt.raw_output
if attempt.error is None and raw_output:
# Try to parse the response
try:
event = _parse_classification_response(
raw_output, document_id, model_name,
)
# Persist result to MinIO
if minio_client:
try:
_upload_classification_result(
minio_client, document_id, raw_output,
event, success=True, error=None, timestamp=ts,
)
except Exception:
logger.exception(
"Failed to upload classification result for doc %s", document_id,
)
# Persist to PostgreSQL
if pool:
try:
await persist_global_event(pool, event)
except Exception:
logger.exception(
"Failed to persist global event for doc %s", document_id,
)
return event
except (json.JSONDecodeError, KeyError, TypeError) as exc:
last_error = f"parse_error: {exc}"
logger.warning(
"Classification parse error for doc %s attempt %d: %s",
document_id, attempt_num + 1, exc,
)
else:
last_error = attempt.error or "empty_response"
# Retry with backoff
if attempt_num < max_retries:
delay = ollama_client._base_delay * (
ollama_client._backoff_multiplier ** attempt_num
)
delay = min(delay, ollama_client._max_delay)
logger.warning(
"Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
attempt_num + 1, max_retries + 1, document_id, last_error, delay,
)
await asyncio.sleep(delay)
# All retries exhausted — persist failure and raise
if minio_client:
try:
_upload_classification_result(
minio_client, document_id, raw_output,
event=None, success=False, error=last_error, timestamp=ts,
)
except Exception:
logger.exception(
"Failed to upload failed classification result for doc %s", document_id,
)
raise ValueError(
f"Event classification failed for document {document_id} "
f"after {max_retries + 1} attempts: {last_error}"
)
+394
View File
@@ -0,0 +1,394 @@
"""Exposure profile auto-inference from filing extractions.
Infers baseline exposure profiles from company filing extractions when
no manual profile exists. Scans recent filing extractions for geographic
revenue breakdowns, supplier mentions, and commodity references.
Requirements: 9.1, 9.2, 9.3
"""
from __future__ import annotations
import logging
import re
from collections import defaultdict
from services.aggregation.interpolation import build_default_profile
from services.shared.schemas import (
DocumentIntelligence,
ExposureProfileSchema,
MarketPositionTier,
)
logger = logging.getLogger("exposure_inference")
# ---------------------------------------------------------------------------
# Known region patterns for geographic extraction
# ---------------------------------------------------------------------------
_REGION_KEYWORDS: dict[str, str] = {
"united states": "US",
"u.s.": "US",
"us": "US",
"america": "US",
"north america": "US",
"china": "CN",
"chinese": "CN",
"europe": "EU",
"european": "EU",
"eu": "EU",
"japan": "JP",
"japanese": "JP",
"germany": "DE",
"german": "DE",
"united kingdom": "GB",
"uk": "GB",
"britain": "GB",
"british": "GB",
"south korea": "KR",
"korea": "KR",
"india": "IN",
"indian": "IN",
"brazil": "BR",
"brazilian": "BR",
"australia": "AU",
"australian": "AU",
"canada": "CA",
"canadian": "CA",
"taiwan": "TW",
"saudi arabia": "SA",
"russia": "RU",
"russian": "RU",
"mexico": "MX",
"singapore": "SG",
"asia": "CN",
"asia pacific": "CN",
"latin america": "BR",
"middle east": "SA",
}
# ---------------------------------------------------------------------------
# Known commodity patterns
# ---------------------------------------------------------------------------
_COMMODITY_KEYWORDS: dict[str, str] = {
"crude oil": "crude_oil",
"oil": "crude_oil",
"petroleum": "crude_oil",
"natural gas": "natural_gas",
"gas": "natural_gas",
"copper": "copper",
"steel": "steel",
"lithium": "lithium",
"semiconductor": "semiconductors",
"semiconductors": "semiconductors",
"chip": "semiconductors",
"chips": "semiconductors",
"wheat": "wheat",
"corn": "corn",
"gold": "gold",
"aluminum": "aluminum",
"aluminium": "aluminum",
"nickel": "nickel",
"cobalt": "cobalt",
"rare earth": "rare_earth",
}
# Minimum number of filing documents to consider inference meaningful
_MIN_FILINGS_FOR_INFERENCE = 1
# Minimum total mentions to consider a region significant
_MIN_REGION_MENTIONS = 1
# Minimum total mentions to consider a commodity significant
_MIN_COMMODITY_MENTIONS = 1
# ---------------------------------------------------------------------------
# Text scanning helpers
# ---------------------------------------------------------------------------
def _extract_regions_from_text(text: str) -> dict[str, int]:
"""Extract region mentions from text, returning region_code -> count."""
text_lower = text.lower()
region_counts: dict[str, int] = defaultdict(int)
for keyword, code in _REGION_KEYWORDS.items():
# Use word boundary matching for short keywords
if len(keyword) <= 3:
pattern = rf"\b{re.escape(keyword)}\b"
matches = re.findall(pattern, text_lower)
else:
matches = re.findall(re.escape(keyword), text_lower)
if matches:
region_counts[code] += len(matches)
return dict(region_counts)
def _extract_commodities_from_text(text: str) -> dict[str, int]:
"""Extract commodity mentions from text, returning commodity_id -> count."""
text_lower = text.lower()
commodity_counts: dict[str, int] = defaultdict(int)
for keyword, commodity_id in _COMMODITY_KEYWORDS.items():
if len(keyword) <= 4:
pattern = rf"\b{re.escape(keyword)}\b"
matches = re.findall(pattern, text_lower)
else:
matches = re.findall(re.escape(keyword), text_lower)
if matches:
commodity_counts[commodity_id] += len(matches)
return dict(commodity_counts)
def _extract_supply_chain_regions(text: str) -> set[str]:
"""Extract supply chain region mentions from text."""
supply_keywords = [
"supplier", "supply chain", "sourcing", "manufacturing",
"factory", "plant", "warehouse", "distribution",
"import", "export", "procurement",
]
text_lower = text.lower()
regions: set[str] = set()
for keyword in supply_keywords:
if keyword in text_lower:
# Find regions mentioned near supply chain keywords
# Look within a window around each occurrence
for match in re.finditer(re.escape(keyword), text_lower):
start = max(0, match.start() - 200)
end = min(len(text_lower), match.end() + 200)
window = text_lower[start:end]
window_regions = _extract_regions_from_text(window)
regions.update(window_regions.keys())
return regions
# ---------------------------------------------------------------------------
# Revenue mix estimation
# ---------------------------------------------------------------------------
def _estimate_revenue_mix(region_counts: dict[str, int]) -> dict[str, float]:
"""Estimate geographic revenue mix from region mention counts.
Uses mention frequency as a proxy for revenue distribution.
Normalizes to sum to 1.0.
"""
if not region_counts:
return {}
total = sum(region_counts.values())
if total == 0:
return {}
mix = {
region: round(count / total, 4)
for region, count in region_counts.items()
if count >= _MIN_REGION_MENTIONS
}
# Re-normalize after filtering
mix_total = sum(mix.values())
if mix_total > 0 and abs(mix_total - 1.0) > 0.001:
mix = {r: round(v / mix_total, 4) for r, v in mix.items()}
return mix
# ---------------------------------------------------------------------------
# Confidence scoring
# ---------------------------------------------------------------------------
def _compute_inference_confidence(
num_filings: int,
num_regions: int,
num_commodities: int,
total_mentions: int,
) -> float:
"""Compute confidence score for the inferred profile.
Higher confidence when more filings are available and more
geographic/commodity data points are found.
"""
# Base confidence from number of filings (more filings = more reliable)
filing_factor = min(num_filings / 5.0, 1.0) # saturates at 5 filings
# Data richness factor
data_points = num_regions + num_commodities
richness_factor = min(data_points / 8.0, 1.0) # saturates at 8 data points
# Mention volume factor
volume_factor = min(total_mentions / 20.0, 1.0) # saturates at 20 mentions
confidence = 0.4 * filing_factor + 0.35 * richness_factor + 0.25 * volume_factor
return round(max(0.0, min(1.0, confidence)), 4)
# ---------------------------------------------------------------------------
# Main inference function
# ---------------------------------------------------------------------------
def infer_exposure_profile(
document_intelligences: list[DocumentIntelligence],
sector: str,
industry: str,
market_cap_bucket: str,
) -> ExposureProfileSchema:
"""Infer a baseline exposure profile from filing extractions.
Scans recent filing extractions for geographic revenue breakdowns,
supplier mentions, and commodity references. Produces an
ExposureProfile with source='inferred' and a confidence score
reflecting data quality.
Falls back to sector-based default profile when insufficient
filing data is available.
Args:
document_intelligences: List of DocumentIntelligence from recent filings.
sector: Company's GICS sector name.
industry: Company's industry name.
market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap.
Returns:
An ExposureProfileSchema with source='inferred'.
Requirements: 9.1, 9.2, 9.3
"""
# Filter to filing-type documents
filings = [
di for di in document_intelligences
if di.document_type.value in ("filing", "transcript")
]
if len(filings) < _MIN_FILINGS_FOR_INFERENCE:
logger.info(
"Insufficient filing data (%d filings) for inference, "
"falling back to sector-based default profile",
len(filings),
)
return build_default_profile(sector, industry, market_cap_bucket)
# Aggregate region and commodity mentions across all filings
all_region_counts: dict[str, int] = defaultdict(int)
all_commodity_counts: dict[str, int] = defaultdict(int)
all_supply_regions: set[str] = set()
for filing in filings:
# Scan summary text
if filing.summary:
regions = _extract_regions_from_text(filing.summary)
for r, c in regions.items():
all_region_counts[r] += c
commodities = _extract_commodities_from_text(filing.summary)
for com, c in commodities.items():
all_commodity_counts[com] += c
supply_regions = _extract_supply_chain_regions(filing.summary)
all_supply_regions.update(supply_regions)
# Scan company impacts for geographic and commodity mentions
for company in filing.companies:
# Key facts and evidence spans contain geographic details
for text in company.key_facts + company.evidence_spans:
regions = _extract_regions_from_text(text)
for r, c in regions.items():
all_region_counts[r] += c
commodities = _extract_commodities_from_text(text)
for com, c in commodities.items():
all_commodity_counts[com] += c
supply_regions = _extract_supply_chain_regions(text)
all_supply_regions.update(supply_regions)
# Scan macro themes for commodity/region hints
for theme in filing.macro_themes:
regions = _extract_regions_from_text(theme)
for r, c in regions.items():
all_region_counts[r] += c
commodities = _extract_commodities_from_text(theme)
for com, c in commodities.items():
all_commodity_counts[com] += c
# Check if we have enough data to infer
total_mentions = sum(all_region_counts.values()) + sum(all_commodity_counts.values())
has_regions = len(all_region_counts) > 0
has_commodities = len(all_commodity_counts) > 0
if not has_regions and not has_commodities:
logger.info(
"No geographic or commodity data found in %d filings, "
"falling back to sector-based default profile",
len(filings),
)
return build_default_profile(sector, industry, market_cap_bucket)
# Build the inferred profile
geographic_revenue_mix = _estimate_revenue_mix(dict(all_region_counts))
# Filter commodities by minimum mentions
key_commodities = [
com for com, count in all_commodity_counts.items()
if count >= _MIN_COMMODITY_MENTIONS
]
# Supply chain regions: combine extracted supply regions with geo regions
supply_chain_regions = list(all_supply_regions | set(geographic_revenue_mix.keys()))
# Market position tier from market cap bucket
from services.aggregation.interpolation import _CAP_TO_TIER
tier_value = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value)
# Regulatory jurisdictions: top regions by revenue
sorted_regions = sorted(
geographic_revenue_mix.items(), key=lambda x: x[1], reverse=True,
)
regulatory_jurisdictions = [r for r, _ in sorted_regions[:3]]
# Export dependency: fraction of revenue outside the top region
if geographic_revenue_mix:
top_region_pct = max(geographic_revenue_mix.values())
export_pct = round(1.0 - top_region_pct, 4)
else:
export_pct = 0.0
# Confidence score
confidence = _compute_inference_confidence(
num_filings=len(filings),
num_regions=len(all_region_counts),
num_commodities=len(all_commodity_counts),
total_mentions=total_mentions,
)
profile = ExposureProfileSchema(
company_id="",
geographic_revenue_mix=geographic_revenue_mix,
supply_chain_regions=supply_chain_regions,
key_input_commodities=key_commodities,
regulatory_jurisdictions=regulatory_jurisdictions,
market_position_tier=MarketPositionTier(tier_value),
export_dependency_pct=max(0.0, min(1.0, export_pct)),
source="inferred",
confidence=confidence,
version=1,
)
logger.info(
"Inferred exposure profile: regions=%d, commodities=%d, "
"supply_chain=%d, confidence=%.3f",
len(geographic_revenue_mix),
len(key_commodities),
len(supply_chain_regions),
confidence,
)
return profile
+234 -4
View File
@@ -9,13 +9,21 @@ import asyncpg
import redis.asyncio as aioredis
from minio import Minio
from services.aggregation.interpolation import (
build_default_profile,
compute_macro_impact_with_sector,
filter_low_confidence_events,
persist_macro_impact_records,
)
from services.extractor.client import OllamaClient
from services.extractor.event_classifier import classify_global_event
from services.extractor.worker import persist_extraction
from services.shared.config import load_config
from services.shared.logging import inject_trace_context, setup_logging
from services.shared.redis_keys import (
QUEUE_AGGREGATION,
QUEUE_EXTRACTION,
QUEUE_MACRO_CLASSIFICATION,
queue_key,
)
@@ -28,6 +36,198 @@ async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]:
return {row["ticker"]: str(row["id"]) for row in rows}
async def _fetch_document_type(pool: asyncpg.Pool, document_id: str) -> str | None:
"""Fetch the document_type for a document."""
row = await pool.fetchrow(
"SELECT document_type FROM documents WHERE id = $1::uuid",
document_id,
)
return row["document_type"] if row else None
async def _fetch_company_info(pool: asyncpg.Pool) -> list[dict]:
"""Fetch company info needed for exposure profile loading and interpolation."""
rows = await pool.fetch(
"""SELECT id, ticker, sector, industry, market_cap_bucket
FROM companies WHERE active = TRUE"""
)
return [dict(r) for r in rows]
async def _load_exposure_profile(pool: asyncpg.Pool, company_id: str, sector: str, industry: str, market_cap_bucket: str):
"""Load exposure profile for a company: manual > inferred > default.
Requirements: 4.1
"""
from services.shared.schemas import ExposureProfileSchema, MarketPositionTier
# Try manual or inferred profile from DB
row = await pool.fetchrow(
"""SELECT company_id, geographic_revenue_mix, supply_chain_regions,
key_input_commodities, regulatory_jurisdictions, market_position_tier,
export_dependency_pct, source, confidence, version
FROM exposure_profiles
WHERE company_id = $1 AND active = TRUE
ORDER BY version DESC LIMIT 1""",
company_id,
)
if row:
geo_mix = row["geographic_revenue_mix"]
if isinstance(geo_mix, str):
geo_mix = json.loads(geo_mix)
tier_val = row["market_position_tier"]
try:
tier = MarketPositionTier(tier_val)
except ValueError:
tier = MarketPositionTier.REGIONAL
return ExposureProfileSchema(
company_id=str(row["company_id"]),
geographic_revenue_mix=geo_mix or {},
supply_chain_regions=list(row["supply_chain_regions"] or []),
key_input_commodities=list(row["key_input_commodities"] or []),
regulatory_jurisdictions=list(row["regulatory_jurisdictions"] or []),
market_position_tier=tier,
export_dependency_pct=float(row["export_dependency_pct"] or 0.0),
source=row["source"] or "manual",
confidence=float(row["confidence"] or 1.0),
version=row["version"] or 1,
)
# Fall back to default profile
profile = build_default_profile(sector or "", industry or "", market_cap_bucket or "small_cap")
profile.company_id = str(company_id)
return profile
async def _compute_and_persist_macro_impacts(
pool: asyncpg.Pool,
event,
companies: list[dict],
confidence_threshold: float = 0.4,
) -> list[str]:
"""Compute MacroImpactRecords for all tracked companies and persist non-zero ones.
Requirements: 4.1, 4.5
"""
# Filter low-confidence events
filtered = filter_low_confidence_events([event], confidence_threshold)
if not filtered:
logger.info("Event %s excluded: confidence %.3f below threshold %.3f",
event.event_id, event.confidence, confidence_threshold)
return []
records = []
for company in companies:
company_id = str(company["id"])
ticker = company["ticker"]
sector = company.get("sector") or ""
industry = company.get("industry") or ""
market_cap_bucket = company.get("market_cap_bucket") or "small_cap"
profile = await _load_exposure_profile(pool, company_id, sector, industry, market_cap_bucket)
record = compute_macro_impact_with_sector(event, profile, company_sector=sector)
record.ticker = ticker
record.company_id = company_id
if record.macro_impact_score > 0.0:
records.append(record)
if records:
ids = await persist_macro_impact_records(pool, records)
logger.info(
"Persisted %d macro impact records for event %s",
len(ids), event.event_id,
)
return [r.ticker for r in records]
return []
# Track consecutive macro classification failures for alerting (Requirement 10.4)
_macro_consecutive_failures = 0
_MACRO_FAILURE_ALERT_THRESHOLD = 3
async def _process_macro_classification(
*,
pool: asyncpg.Pool,
minio_client: Minio,
ollama: OllamaClient,
redis_client: aioredis.Redis,
document_id: str,
text: str,
company_id_map: dict[str, str],
confidence_threshold: float = 0.4,
) -> None:
"""Route a macro_event document to event classification, compute interpolation,
and trigger aggregation for affected tickers.
Requirements: 2.1, 2.2, 2.3, 4.1, 4.5, 10.4
"""
global _macro_consecutive_failures
agg_queue = queue_key(QUEUE_AGGREGATION)
try:
event = await classify_global_event(
normalized_text=text,
document_id=document_id,
ollama_client=ollama,
pool=pool,
minio_client=minio_client,
)
logger.info(
"Classified macro event %s for doc %s: severity=%s types=%s",
event.event_id, document_id, event.severity, event.event_types,
)
# Reset failure counter on success
_macro_consecutive_failures = 0
# Load all tracked companies and compute macro impacts
companies = await _fetch_company_info(pool)
affected_tickers = await _compute_and_persist_macro_impacts(
pool, event, companies, confidence_threshold,
)
# Trigger aggregation for affected tickers (those with non-zero impact)
enqueued_tickers = set()
for ticker in affected_tickers:
if ticker not in enqueued_tickers:
await redis_client.rpush(
agg_queue,
json.dumps(inject_trace_context({
"ticker": ticker,
"macro_event_id": event.event_id,
})),
)
enqueued_tickers.add(ticker)
logger.info(
"Enqueued aggregation jobs for %d affected tickers after macro event %s",
len(enqueued_tickers), event.event_id,
)
except ValueError as e:
_macro_consecutive_failures += 1
logger.error("Macro event classification failed for doc %s: %s", document_id, e)
if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD:
logger.critical(
"ALERT: Sustained macro classification failures (%d consecutive). "
"Continuing with company-only signals. Operator action required.",
_macro_consecutive_failures,
)
except Exception:
_macro_consecutive_failures += 1
logger.exception("Unexpected error classifying macro event for doc %s", document_id)
if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD:
logger.critical(
"ALERT: Sustained macro classification failures (%d consecutive). "
"Continuing with company-only signals. Operator action required.",
_macro_consecutive_failures,
)
async def main() -> None:
config = load_config()
setup_logging("extractor", level=config.log_level, json_output=config.json_logs)
@@ -42,8 +242,10 @@ async def main() -> None:
ollama = OllamaClient(config.ollama)
redis_client = aioredis.from_url(config.redis.url)
queue = queue_key(QUEUE_EXTRACTION)
macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
agg_queue = queue_key(QUEUE_AGGREGATION)
logger.info("Extractor worker started, polling %s", queue)
confidence_threshold = config.macro.macro_confidence_threshold
logger.info("Extractor worker started, polling %s and %s", queue, macro_queue)
# Pre-load company ID map (refreshed periodically)
company_id_map = await _build_company_id_map(pool)
@@ -51,7 +253,13 @@ async def main() -> None:
try:
while True:
raw = await redis_client.lpop(queue)
# Check macro classification queue first (priority)
raw = await redis_client.lpop(macro_queue)
is_macro_job = raw is not None
if raw is None:
raw = await redis_client.lpop(queue)
if raw is None:
await asyncio.sleep(1)
continue
@@ -80,13 +288,35 @@ async def main() -> None:
except Exception as e:
logger.warning("Could not fetch normalized text for doc %s: %s", document_id, e)
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
# Refresh company map every 100 jobs
refresh_counter += 1
if refresh_counter % 100 == 0:
company_id_map = await _build_company_id_map(pool)
# Route macro_event documents to event classification (Requirement 2.1)
doc_type = None
if is_macro_job:
doc_type = "macro_event"
else:
doc_type = await _fetch_document_type(pool, document_id)
if doc_type == "macro_event":
logger.info("Routing macro_event doc %s to event classifier", document_id)
await _process_macro_classification(
pool=pool,
minio_client=minio_client,
ollama=ollama,
redis_client=redis_client,
document_id=document_id,
text=text,
company_id_map=company_id_map,
confidence_threshold=confidence_threshold,
)
continue
# Standard extraction pipeline for non-macro documents
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
try:
# Pass all tracked tickers so the model can identify any mentioned companies
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
+9 -2
View File
@@ -10,6 +10,7 @@ from minio import Minio
from services.adapters.base import AdapterResult
from services.adapters.broker_adapter import AlpacaBrokerAdapter, TradingMode
from services.adapters.filings_adapter import SECEdgarAdapter
from services.adapters.macro_news_adapter import MacroNewsAdapter
from services.adapters.market_adapter import PolygonMarketAdapter
from services.adapters.news_adapter import PolygonNewsAdapter
from services.adapters.web_scrape_adapter import WebScrapeAdapter
@@ -69,11 +70,14 @@ async def process_job(
logger.warning("No adapter for source_type=%s", source_type)
return
# Macro sources may not have a company_id
company_id = job.get("company_id")
# Record ingestion run
run_id = await pool.fetchval(
"""INSERT INTO ingestion_runs (source_id, company_id, source_type, status)
VALUES ($1, $2, $3, 'running') RETURNING id""",
source_id, job["company_id"], source_type,
source_id, company_id, source_type,
)
try:
@@ -159,7 +163,7 @@ async def process_job(
# Link duplicate documents to this company if not already linked
company_id = job.get("company_id")
if company_id and deduped_count:
if company_id and deduped_count and source_type not in ("macro_news",):
from services.shared.metadata import persist_document_company_mention
for dup in dup_items:
existing_id = dup.get("_dedupe_existing_id")
@@ -234,6 +238,9 @@ async def main():
mode=TradingMode.LIVE if cfg.broker.mode == "live" else TradingMode.PAPER,
base_url=cfg.broker.base_url,
),
"macro_news": MacroNewsAdapter(
api_key=cfg.market_data.api_key,
),
}
logger.info("Ingestion worker started")
+21
View File
@@ -124,6 +124,27 @@ TABLE_SCHEMAS: dict[str, pa.Schema] = {
"model_performance": MODEL_PERFORMANCE_SCHEMA,
}
# Lazily register schemas defined in worker.py to avoid circular imports.
# These are added after the initial dict definition.
def _register_worker_schemas() -> None:
from services.lake_publisher.worker import (
COMPETITOR_RELATIONSHIPS_SCHEMA,
COMPETITIVE_SIGNALS_SCHEMA,
GLOBAL_EVENTS_SCHEMA,
MACRO_IMPACTS_SCHEMA,
TREND_PROJECTIONS_SCHEMA,
)
TABLE_SCHEMAS["competitor_relationships"] = COMPETITOR_RELATIONSHIPS_SCHEMA
TABLE_SCHEMAS["competitive_signals"] = COMPETITIVE_SIGNALS_SCHEMA
TABLE_SCHEMAS["global_events"] = GLOBAL_EVENTS_SCHEMA
TABLE_SCHEMAS["macro_impacts"] = MACRO_IMPACTS_SCHEMA
TABLE_SCHEMAS["trend_projections"] = TREND_PROJECTIONS_SCHEMA
try:
_register_worker_schemas()
except ImportError:
pass # worker.py not available in minimal test environments
@dataclass(frozen=True)
class IcebergTableDef:
+240
View File
@@ -39,12 +39,17 @@ from services.lake_publisher.worker import (
publish_document_extractions_batch,
publish_document_fact,
publish_documents_batch,
publish_global_event_fact,
publish_macro_impact_fact,
publish_market_bar,
publish_market_quote,
publish_pnl_daily,
publish_positions_daily_batch,
publish_trade_fill,
publish_trade_order,
publish_trend_projection_fact,
publish_competitor_relationship_fact,
publish_competitive_signal_fact,
)
from services.shared.config import load_config
from services.shared.db import get_minio, get_pg_pool, get_redis
@@ -164,6 +169,57 @@ ORDER BY di.created_at
LIMIT 500
"""
_FETCH_GLOBAL_EVENT = """
SELECT
ge.id, ge.event_types, ge.severity, ge.affected_regions,
ge.affected_sectors, ge.affected_commodities, ge.summary,
ge.estimated_duration, ge.confidence, ge.source_document_id,
ge.created_at
FROM global_events ge
WHERE ge.id = $1::uuid
"""
_FETCH_MACRO_IMPACTS_FOR_EVENT = """
SELECT
mir.event_id, mir.company_id, mir.ticker,
mir.macro_impact_score, mir.impact_direction,
mir.contributing_factors, mir.confidence, mir.computed_at
FROM macro_impact_records mir
WHERE mir.event_id = $1::uuid
"""
_FETCH_TREND_PROJECTION = """
SELECT
tp.id, tp.trend_window_id, tp.projected_direction,
tp.projected_strength, tp.projected_confidence,
tp.projection_horizon, tp.driving_factors,
tp.macro_contribution_pct, tp.diverges_from_current,
tp.computed_at,
tw.ticker
FROM trend_projections tp
JOIN trend_windows tw ON tw.id = tp.trend_window_id
WHERE tp.trend_window_id = $1::uuid
"""
_FETCH_COMPETITOR_RELATIONSHIP = """
SELECT
cr.id, cr.company_a_id, cr.company_b_id,
cr.relationship_type, cr.strength, cr.bidirectional,
cr.source, cr.active, cr.created_at
FROM competitor_relationships cr
WHERE cr.id = $1::uuid
"""
_FETCH_COMPETITIVE_SIGNALS_FOR_DOCUMENT = """
SELECT
csr.id, csr.source_document_id, csr.source_ticker,
csr.target_ticker, csr.catalyst_type, csr.pattern_confidence,
csr.signal_direction, csr.signal_strength,
csr.relationship_strength, csr.computed_at
FROM competitive_signal_records csr
WHERE csr.source_document_id = $1::uuid
"""
# ---------------------------------------------------------------------------
# Job handlers — each transforms operational rows into lake facts
@@ -510,6 +566,165 @@ async def publish_bulk_extractions_job(
return [ref] if ref else []
async def publish_global_event_job(
pool: asyncpg.Pool,
minio_client: Minio,
entity_id: str,
) -> str:
"""Publish a global event fact from PostgreSQL to the lake."""
row = await pool.fetchrow(_FETCH_GLOBAL_EVENT, entity_id)
if row is None:
logger.warning("Global event %s not found, skipping lake publish", entity_id)
return ""
event_types = row["event_types"] or []
affected_regions = row["affected_regions"] or []
affected_sectors = row["affected_sectors"] or []
affected_commodities = row["affected_commodities"] or []
return publish_global_event_fact(
client=minio_client,
event_id=str(row["id"]),
event_types=list(event_types),
severity=row["severity"] or "low",
affected_regions=list(affected_regions),
affected_sectors=list(affected_sectors),
affected_commodities=list(affected_commodities),
summary=row["summary"] or "",
estimated_duration=row["estimated_duration"] or "short_term",
confidence=float(row["confidence"] or 0.0),
source_document_id=str(row["source_document_id"]) if row["source_document_id"] else "",
created_at=row["created_at"],
)
async def publish_macro_impacts_job(
pool: asyncpg.Pool,
minio_client: Minio,
entity_id: str,
) -> list[str]:
"""Publish macro impact facts for a global event from PostgreSQL to the lake."""
rows = await pool.fetch(_FETCH_MACRO_IMPACTS_FOR_EVENT, entity_id)
if not rows:
logger.info("No macro impact records for event %s", entity_id)
return []
refs: list[str] = []
for row in rows:
factors = row["contributing_factors"]
if isinstance(factors, str):
try:
factors = json.loads(factors)
except (json.JSONDecodeError, TypeError):
factors = [factors] if factors else []
elif factors is None:
factors = []
ref = publish_macro_impact_fact(
client=minio_client,
event_id=str(row["event_id"]),
company_id=str(row["company_id"]),
ticker=row["ticker"],
macro_impact_score=float(row["macro_impact_score"] or 0.0),
impact_direction=row["impact_direction"] or "neutral",
contributing_factors=list(factors),
confidence=float(row["confidence"] or 0.0),
computed_at=row["computed_at"],
)
refs.append(ref)
return refs
async def publish_trend_projection_job(
pool: asyncpg.Pool,
minio_client: Minio,
entity_id: str,
) -> str:
"""Publish a trend projection fact from PostgreSQL to the lake."""
row = await pool.fetchrow(_FETCH_TREND_PROJECTION, entity_id)
if row is None:
logger.warning("Trend projection for window %s not found", entity_id)
return ""
factors = row["driving_factors"]
if isinstance(factors, str):
try:
factors = json.loads(factors)
except (json.JSONDecodeError, TypeError):
factors = [factors] if factors else []
elif factors is None:
factors = []
return publish_trend_projection_fact(
client=minio_client,
trend_window_id=str(row["trend_window_id"]),
ticker=row["ticker"] or "",
projected_direction=row["projected_direction"] or "neutral",
projected_strength=float(row["projected_strength"] or 0.0),
projected_confidence=float(row["projected_confidence"] or 0.0),
projection_horizon=row["projection_horizon"] or "7d",
driving_factors=list(factors),
macro_contribution_pct=float(row["macro_contribution_pct"] or 0.0),
diverges_from_current=bool(row["diverges_from_current"]),
computed_at=row["computed_at"],
)
async def publish_competitor_relationship_job(
pool: asyncpg.Pool,
minio_client: Minio,
entity_id: str,
) -> str:
"""Publish a competitor relationship fact from PostgreSQL to the lake."""
row = await pool.fetchrow(_FETCH_COMPETITOR_RELATIONSHIP, entity_id)
if row is None:
logger.warning("Competitor relationship %s not found, skipping lake publish", entity_id)
return ""
return publish_competitor_relationship_fact(
client=minio_client,
relationship_id=str(row["id"]),
company_a_id=str(row["company_a_id"]),
company_b_id=str(row["company_b_id"]),
relationship_type=row["relationship_type"],
strength=float(row["strength"]),
bidirectional=bool(row["bidirectional"]),
source=row["source"],
active=bool(row["active"]),
created_at=row["created_at"],
)
async def publish_competitive_signals_job(
pool: asyncpg.Pool,
minio_client: Minio,
entity_id: str,
) -> list[str]:
"""Publish competitive signal facts for a document from PostgreSQL to the lake."""
rows = await pool.fetch(_FETCH_COMPETITIVE_SIGNALS_FOR_DOCUMENT, entity_id)
if not rows:
logger.info("No competitive signals for document %s", entity_id)
return []
refs: list[str] = []
for row in rows:
ref = publish_competitive_signal_fact(
client=minio_client,
signal_id=str(row["id"]),
source_document_id=str(row["source_document_id"]),
source_ticker=row["source_ticker"],
target_ticker=row["target_ticker"],
catalyst_type=row["catalyst_type"],
pattern_confidence=float(row["pattern_confidence"]),
signal_direction=row["signal_direction"],
signal_strength=float(row["signal_strength"]),
relationship_strength=float(row["relationship_strength"]),
computed_at=row["computed_at"],
)
refs.append(ref)
return refs
# ---------------------------------------------------------------------------
# Job dispatcher
# ---------------------------------------------------------------------------
@@ -525,6 +740,11 @@ JOB_TYPES = {
"company_event",
"bulk_documents",
"bulk_extractions",
"global_event",
"macro_impact",
"trend_projection",
"competitor_relationship",
"competitive_signal",
}
@@ -594,6 +814,26 @@ async def dispatch_job(
refs = await publish_bulk_extractions_job(pool, minio_client, since)
result["refs"] = refs
elif job_type == "global_event":
ref = await publish_global_event_job(pool, minio_client, entity_id)
result["refs"] = [ref] if ref else []
elif job_type == "macro_impact":
refs = await publish_macro_impacts_job(pool, minio_client, entity_id)
result["refs"] = refs
elif job_type == "trend_projection":
ref = await publish_trend_projection_job(pool, minio_client, entity_id)
result["refs"] = [ref] if ref else []
elif job_type == "competitor_relationship":
ref = await publish_competitor_relationship_job(pool, minio_client, entity_id)
result["refs"] = [ref] if ref else []
elif job_type == "competitive_signal":
refs = await publish_competitive_signals_job(pool, minio_client, entity_id)
result["refs"] = refs
else:
result["error"] = f"Unknown job_type: {job_type}"
logger.warning("Unknown lake publish job type: %s", job_type)
+5
View File
@@ -55,6 +55,11 @@ TABLE_PARTITIONS: dict[str, PartitionSpec] = {
"pnl_daily": PartitionSpec("pnl_daily"),
"prediction_vs_outcome": PartitionSpec("prediction_vs_outcome", extra_keys=("model_version",)),
"model_performance": PartitionSpec("model_performance", extra_keys=("model_version",)),
"global_events": PartitionSpec("global_events"),
"macro_impacts": PartitionSpec("macro_impacts", extra_keys=("ticker",)),
"trend_projections": PartitionSpec("trend_projections", extra_keys=("ticker",)),
"competitor_relationships": PartitionSpec("competitor_relationships"),
"competitive_signals": PartitionSpec("competitive_signals", extra_keys=("target_ticker",)),
}
+370
View File
@@ -1226,3 +1226,373 @@ def publish_prediction_vs_outcome_batch(
) -> str:
"""Publish a batch of prediction vs outcome rows as a single Parquet file."""
return _publish_batch(client, "prediction_vs_outcome", rows, PREDICTION_VS_OUTCOME_SCHEMA, dt)
# --- global_events fact table ---
GLOBAL_EVENTS_SCHEMA = pa.schema([
("event_id", pa.string()),
("event_types", pa.string()),
("severity", pa.string()),
("affected_regions", pa.string()),
("affected_sectors", pa.string()),
("affected_commodities", pa.string()),
("summary", pa.string()),
("estimated_duration", pa.string()),
("confidence", pa.float64()),
("source_document_id", pa.string()),
("created_at", pa.timestamp("us", tz="UTC")),
("dt", pa.date32()),
])
def publish_global_event_fact(
client: Minio,
event_id: str,
event_types: list[str],
severity: str,
affected_regions: list[str],
affected_sectors: list[str],
affected_commodities: list[str],
summary: str,
estimated_duration: str,
confidence: float,
source_document_id: str,
created_at: datetime,
) -> str:
"""Publish a single global event fact to MinIO.
Writes a Parquet file to:
s3://stonks-lakehouse/warehouse/global_events/dt={date}/part-{uuid}.parquet
Returns the s3:// URI of the written object.
Requirements: 7.3, 12.6
Design ref: Analytical Lake Datasets (lake.global_events)
"""
row: dict[str, object] = {
"event_id": event_id,
"event_types": ", ".join(event_types),
"severity": severity,
"affected_regions": ", ".join(affected_regions),
"affected_sectors": ", ".join(affected_sectors),
"affected_commodities": ", ".join(affected_commodities),
"summary": summary,
"estimated_duration": estimated_duration,
"confidence": confidence,
"source_document_id": source_document_id,
"created_at": created_at,
**partition_values(created_at),
}
table = pa.Table.from_pylist([row], schema=GLOBAL_EVENTS_SCHEMA)
parquet_bytes = _write_parquet_bytes(table)
path = _partition_path("global_events", created_at)
_put_lakehouse_object(client, "global_events", path, parquet_bytes)
ref = s3_uri(path)
logger.info("Published global_event fact %s: %s", event_id, ref)
return ref
# --- macro_impacts fact table ---
MACRO_IMPACTS_SCHEMA = pa.schema([
("event_id", pa.string()),
("company_id", pa.string()),
("ticker", pa.string()),
("macro_impact_score", pa.float64()),
("impact_direction", pa.string()),
("contributing_factors", pa.string()),
("confidence", pa.float64()),
("computed_at", pa.timestamp("us", tz="UTC")),
("dt", pa.date32()),
])
def publish_macro_impact_fact(
client: Minio,
event_id: str,
company_id: str,
ticker: str,
macro_impact_score: float,
impact_direction: str,
contributing_factors: list[str],
confidence: float,
computed_at: datetime,
) -> str:
"""Publish a single macro impact fact to MinIO.
Writes a Parquet file to:
s3://stonks-lakehouse/warehouse/macro_impacts/dt={date}/ticker={ticker}/part-{uuid}.parquet
Returns the s3:// URI of the written object.
Requirements: 7.3, 12.6
Design ref: Analytical Lake Datasets (lake.macro_impacts)
"""
extra = {"ticker": ticker}
row: dict[str, object] = {
"event_id": event_id,
"company_id": company_id,
"ticker": ticker,
"macro_impact_score": macro_impact_score,
"impact_direction": impact_direction,
"contributing_factors": ", ".join(contributing_factors),
"confidence": confidence,
"computed_at": computed_at,
**partition_values(computed_at, extra),
}
table = pa.Table.from_pylist([row], schema=MACRO_IMPACTS_SCHEMA)
parquet_bytes = _write_parquet_bytes(table)
path = _partition_path("macro_impacts", computed_at, extra_partitions=extra)
_put_lakehouse_object(client, "macro_impacts", path, parquet_bytes)
ref = s3_uri(path)
logger.info("Published macro_impact fact for %s/%s: %s", ticker, event_id, ref)
return ref
# --- trend_projections fact table ---
TREND_PROJECTIONS_SCHEMA = pa.schema([
("trend_window_id", pa.string()),
("ticker", pa.string()),
("projected_direction", pa.string()),
("projected_strength", pa.float64()),
("projected_confidence", pa.float64()),
("projection_horizon", pa.string()),
("driving_factors", pa.string()),
("macro_contribution_pct", pa.float64()),
("diverges_from_current", pa.bool_()),
("computed_at", pa.timestamp("us", tz="UTC")),
("dt", pa.date32()),
])
def publish_trend_projection_fact(
client: Minio,
trend_window_id: str,
ticker: str,
projected_direction: str,
projected_strength: float,
projected_confidence: float,
projection_horizon: str,
driving_factors: list[str],
macro_contribution_pct: float,
diverges_from_current: bool,
computed_at: datetime,
) -> str:
"""Publish a single trend projection fact to MinIO.
Writes a Parquet file to:
s3://stonks-lakehouse/warehouse/trend_projections/dt={date}/ticker={ticker}/part-{uuid}.parquet
Returns the s3:// URI of the written object.
Requirements: 7.3, 12.6
Design ref: Analytical Lake Datasets (lake.trend_projections)
"""
extra = {"ticker": ticker}
row: dict[str, object] = {
"trend_window_id": trend_window_id,
"ticker": ticker,
"projected_direction": projected_direction,
"projected_strength": projected_strength,
"projected_confidence": projected_confidence,
"projection_horizon": projection_horizon,
"driving_factors": ", ".join(driving_factors),
"macro_contribution_pct": macro_contribution_pct,
"diverges_from_current": diverges_from_current,
"computed_at": computed_at,
**partition_values(computed_at, extra),
}
table = pa.Table.from_pylist([row], schema=TREND_PROJECTIONS_SCHEMA)
parquet_bytes = _write_parquet_bytes(table)
path = _partition_path("trend_projections", computed_at, extra_partitions=extra)
_put_lakehouse_object(client, "trend_projections", path, parquet_bytes)
ref = s3_uri(path)
logger.info("Published trend_projection fact for %s: %s", ticker, ref)
return ref
# --- Batch publishers for macro fact tables ---
def publish_global_events_batch(
client: Minio,
rows: list[dict[str, object]],
dt: datetime,
) -> str:
"""Publish a batch of global event rows as a single Parquet file."""
return _publish_batch(client, "global_events", rows, GLOBAL_EVENTS_SCHEMA, dt)
def publish_macro_impacts_batch(
client: Minio,
rows: list[dict[str, object]],
dt: datetime,
ticker: str = "",
) -> str:
"""Publish a batch of macro impact rows as a single Parquet file."""
extra = {"ticker": ticker} if ticker else None
return _publish_batch(client, "macro_impacts", rows, MACRO_IMPACTS_SCHEMA, dt, extra)
def publish_trend_projections_batch(
client: Minio,
rows: list[dict[str, object]],
dt: datetime,
ticker: str = "",
) -> str:
"""Publish a batch of trend projection rows as a single Parquet file."""
extra = {"ticker": ticker} if ticker else None
return _publish_batch(client, "trend_projections", rows, TREND_PROJECTIONS_SCHEMA, dt, extra)
# --- competitor_relationships fact table ---
COMPETITOR_RELATIONSHIPS_SCHEMA = pa.schema([
("id", pa.string()),
("company_a_id", pa.string()),
("company_b_id", pa.string()),
("relationship_type", pa.string()),
("strength", pa.float64()),
("bidirectional", pa.bool_()),
("source", pa.string()),
("active", pa.bool_()),
("created_at", pa.timestamp("us", tz="UTC")),
("dt", pa.date32()),
])
def publish_competitor_relationship_fact(
client: Minio,
relationship_id: str,
company_a_id: str,
company_b_id: str,
relationship_type: str,
strength: float,
bidirectional: bool,
source: str,
active: bool,
created_at: datetime,
) -> str:
"""Publish a single competitor relationship fact to MinIO.
Writes a Parquet file to:
s3://stonks-lakehouse/warehouse/competitor_relationships/dt={date}/part-{uuid}.parquet
Returns the s3:// URI of the written object.
Requirements: 7.3
Design ref: Analytical Lake Datasets (lake.competitor_relationships)
"""
row: dict[str, object] = {
"id": relationship_id,
"company_a_id": company_a_id,
"company_b_id": company_b_id,
"relationship_type": relationship_type,
"strength": strength,
"bidirectional": bidirectional,
"source": source,
"active": active,
"created_at": created_at,
**partition_values(created_at),
}
table = pa.Table.from_pylist([row], schema=COMPETITOR_RELATIONSHIPS_SCHEMA)
parquet_bytes = _write_parquet_bytes(table)
path = _partition_path("competitor_relationships", created_at)
_put_lakehouse_object(client, "competitor_relationships", path, parquet_bytes)
ref = s3_uri(path)
logger.info("Published competitor_relationship fact %s: %s", relationship_id, ref)
return ref
def publish_competitor_relationships_batch(
client: Minio,
rows: list[dict[str, object]],
dt: datetime,
) -> str:
"""Publish a batch of competitor relationship rows as a single Parquet file."""
return _publish_batch(client, "competitor_relationships", rows, COMPETITOR_RELATIONSHIPS_SCHEMA, dt)
# --- competitive_signals fact table ---
COMPETITIVE_SIGNALS_SCHEMA = pa.schema([
("id", pa.string()),
("source_document_id", pa.string()),
("source_ticker", pa.string()),
("target_ticker", pa.string()),
("catalyst_type", pa.string()),
("pattern_confidence", pa.float64()),
("signal_direction", pa.string()),
("signal_strength", pa.float64()),
("relationship_strength", pa.float64()),
("computed_at", pa.timestamp("us", tz="UTC")),
("dt", pa.date32()),
])
def publish_competitive_signal_fact(
client: Minio,
signal_id: str,
source_document_id: str,
source_ticker: str,
target_ticker: str,
catalyst_type: str,
pattern_confidence: float,
signal_direction: str,
signal_strength: float,
relationship_strength: float,
computed_at: datetime,
) -> str:
"""Publish a single competitive signal fact to MinIO.
Writes a Parquet file to:
s3://stonks-lakehouse/warehouse/competitive_signals/dt={date}/target_ticker={ticker}/part-{uuid}.parquet
Returns the s3:// URI of the written object.
Requirements: 7.4
Design ref: Analytical Lake Datasets (lake.competitive_signals)
"""
extra = {"target_ticker": target_ticker}
row: dict[str, object] = {
"id": signal_id,
"source_document_id": source_document_id,
"source_ticker": source_ticker,
"target_ticker": target_ticker,
"catalyst_type": catalyst_type,
"pattern_confidence": pattern_confidence,
"signal_direction": signal_direction,
"signal_strength": signal_strength,
"relationship_strength": relationship_strength,
"computed_at": computed_at,
**partition_values(computed_at, extra),
}
table = pa.Table.from_pylist([row], schema=COMPETITIVE_SIGNALS_SCHEMA)
parquet_bytes = _write_parquet_bytes(table)
path = _partition_path("competitive_signals", computed_at, extra_partitions=extra)
_put_lakehouse_object(client, "competitive_signals", path, parquet_bytes)
ref = s3_uri(path)
logger.info("Published competitive_signal fact for %s%s: %s", source_ticker, target_ticker, ref)
return ref
def publish_competitive_signals_batch(
client: Minio,
rows: list[dict[str, object]],
dt: datetime,
target_ticker: str = "",
) -> str:
"""Publish a batch of competitive signal rows as a single Parquet file."""
extra = {"target_ticker": target_ticker} if target_ticker else None
return _publish_batch(client, "competitive_signals", rows, COMPETITIVE_SIGNALS_SCHEMA, dt, extra)
+14 -2
View File
@@ -35,7 +35,7 @@ from services.shared.metrics import (
PARSE_LOW_QUALITY_TOTAL,
PARSE_QUALITY_SCORE,
)
from services.shared.redis_keys import QUEUE_EXTRACTION, QUEUE_PARSING, queue_key
from services.shared.redis_keys import QUEUE_EXTRACTION, QUEUE_MACRO_CLASSIFICATION, QUEUE_PARSING, queue_key
from services.shared.storage import upload_normalized_text, upload_parser_output
logger = logging.getLogger("parser_worker")
@@ -210,7 +210,19 @@ async def process_job(
# Only enqueue for extraction if quality is acceptable
if parsed.confidence != "low":
await rds.rpush(queue_key(QUEUE_EXTRACTION), json.dumps(inject_trace_context({
# Route macro_event documents to the macro classification queue
# instead of the standard extraction queue (Requirement 2.1)
doc_type_row = await pool.fetchrow(
"SELECT document_type FROM documents WHERE id = $1::uuid", doc_id,
)
doc_type = doc_type_row["document_type"] if doc_type_row else None
if doc_type == "macro_event":
target_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
else:
target_queue = queue_key(QUEUE_EXTRACTION)
await rds.rpush(target_queue, json.dumps(inject_trace_context({
"document_id": doc_id,
"ticker": ticker,
"normalized_text": text[:32000],
+115
View File
@@ -32,6 +32,8 @@ class SuppressionReason(str, Enum):
LOW_SOURCE_DIVERSITY = "low_source_diversity"
HIGH_EXTRACTION_FAILURE_RATE = "high_extraction_failure_rate"
INSUFFICIENT_VALID_DOCUMENTS = "insufficient_valid_documents"
MACRO_ONLY_SIGNAL = "macro_only_signal"
PATTERN_ONLY_SIGNAL = "pattern_only_signal"
@dataclass(frozen=True)
@@ -240,3 +242,116 @@ def evaluate_suppression(
data_quality_score=quality_score,
context=ctx,
)
# ---------------------------------------------------------------------------
# Macro-only suppression (Requirements: 10.3)
# ---------------------------------------------------------------------------
MACRO_ONLY_CAVEAT = (
"[Macro-only signal] This trend direction is driven solely by macro/geopolitical "
"signals with no supporting company-specific evidence. Recommendation is "
"informational only and should not be used for automated trading decisions."
)
def evaluate_macro_only_suppression(
summary: TrendSummary,
macro_signal_count: int,
company_signal_count: int,
) -> bool:
"""Evaluate whether a recommendation should be suppressed due to macro-only signals.
When macro signals are the sole basis for a trend direction change
(no supporting company-specific signals), the recommendation should
be forced to informational mode with a macro-only caveat.
Args:
summary: The trend summary to evaluate.
macro_signal_count: Number of macro signals contributing to the trend.
company_signal_count: Number of company-specific signals contributing.
Returns:
True if the recommendation should be suppressed (macro-only), False otherwise.
Requirements: 10.3
"""
# No macro signals means no macro-only suppression
if macro_signal_count <= 0:
return False
# If there are company-specific signals, no suppression needed
if company_signal_count > 0:
return False
# Macro signals are the sole basis — suppress
logger.info(
"Macro-only suppression triggered for %s/%s: "
"macro_signals=%d, company_signals=%d, direction=%s",
summary.entity_id,
summary.window.value,
macro_signal_count,
company_signal_count,
summary.trend_direction.value,
)
return True
# ---------------------------------------------------------------------------
# Pattern-only suppression (Requirements: 9.3)
# ---------------------------------------------------------------------------
PATTERN_ONLY_CAVEAT = (
"[Pattern-only signal] This trend direction is driven solely by historical "
"pattern and competitive signals with no supporting company-specific or macro "
"evidence. Recommendation is informational only."
)
def evaluate_pattern_only_suppression(
summary: TrendSummary,
pattern_signal_count: int,
company_signal_count: int,
macro_signal_count: int,
) -> bool:
"""Evaluate whether a recommendation should be suppressed due to pattern-only signals.
When pattern-based signals are the sole basis for a trend direction change
(no supporting company-specific or macro signals), the recommendation should
be forced to informational mode with a pattern-only caveat.
Args:
summary: The trend summary to evaluate.
pattern_signal_count: Number of pattern/competitive signals contributing.
company_signal_count: Number of company-specific signals contributing.
macro_signal_count: Number of macro signals contributing.
Returns:
True if the recommendation should be suppressed (pattern-only), False otherwise.
Requirements: 9.3
"""
# No pattern signals means no pattern-only suppression
if pattern_signal_count <= 0:
return False
# If there are company-specific signals, no suppression needed
if company_signal_count > 0:
return False
# If there are macro signals, no suppression needed
if macro_signal_count > 0:
return False
# Pattern signals are the sole basis — suppress
logger.info(
"Pattern-only suppression triggered for %s/%s: "
"pattern_signals=%d, company_signals=%d, macro_signals=%d, direction=%s",
summary.entity_id,
summary.window.value,
pattern_signal_count,
company_signal_count,
macro_signal_count,
summary.trend_direction.value,
)
return True
+123 -14
View File
@@ -31,6 +31,7 @@ from services.recommendation.thesis_llm import (
THESIS_PROMPT_VERSION,
rewrite_thesis_with_llm,
)
from services.aggregation.projection import TrendProjection
from services.shared.config import OllamaConfig
from services.shared.metrics import (
RECOMMENDATION_CONFIDENCE,
@@ -178,6 +179,63 @@ async def fetch_latest_trend(
return _parse_trend_row(row)
# ---------------------------------------------------------------------------
# Fetch latest trend projection for a ticker + window
# ---------------------------------------------------------------------------
_LATEST_PROJECTION_QUERY = """
SELECT
tp.projected_direction, tp.projected_strength, tp.projected_confidence,
tp.projection_horizon, tp.driving_factors, tp.macro_contribution_pct,
tp.diverges_from_current, tp.computed_at
FROM trend_projections tp
JOIN trend_windows tw ON tw.id = tp.trend_window_id
WHERE tw.entity_id = $1 AND tw."window" = $2
ORDER BY tp.computed_at DESC
LIMIT 1
"""
async def fetch_latest_projection(
pool: asyncpg.Pool,
ticker: str,
window: str,
) -> TrendProjection | None:
"""Fetch the most recent trend projection for a ticker and window.
Returns None if no projection exists. Low-confidence projections
are returned with low_confidence=True so callers can decide whether
to use them (Requirement 12.9).
"""
try:
row = await pool.fetchrow(_LATEST_PROJECTION_QUERY, ticker, window)
if row is None:
return None
driving_factors = row["driving_factors"]
if isinstance(driving_factors, str):
driving_factors = json.loads(driving_factors)
proj = TrendProjection(
projected_direction=row["projected_direction"],
projected_strength=float(row["projected_strength"]),
projected_confidence=float(row["projected_confidence"]),
projection_horizon=row["projection_horizon"],
driving_factors=driving_factors or [],
macro_contribution_pct=float(row["macro_contribution_pct"] or 0.0),
diverges_from_current=bool(row["diverges_from_current"]),
computed_at=row["computed_at"],
low_confidence=float(row["projected_confidence"]) < 0.3,
)
return proj
except Exception:
logger.warning(
"Failed to fetch projection for %s/%s — continuing without projection",
ticker, window, exc_info=True,
)
return None
# ---------------------------------------------------------------------------
# Build thesis from trend summary (deterministic, no LLM)
# ---------------------------------------------------------------------------
@@ -186,11 +244,16 @@ async def fetch_latest_trend(
def build_thesis(
summary: TrendSummary,
result: EligibilityResult,
projection: TrendProjection | None = None,
) -> str:
"""Generate a deterministic thesis string from trend data.
This is the descriptive analysis portion (Requirement 7.2).
The LLM wording layer is a separate optional task.
When a TrendProjection is provided and is not low-confidence,
the thesis incorporates the projected direction and key driving
factors (Requirement 12.8).
"""
direction = summary.trend_direction.value
ticker = summary.entity_id
@@ -218,6 +281,27 @@ def build_thesis(
+ f"(contradiction score: {summary.contradiction_score:.2f})."
)
# Trend projection (Requirement 12.8)
if projection is not None and not projection.low_confidence:
proj_dir = projection.projected_direction
proj_str = projection.projected_strength
parts.append(
f"Forward projection ({projection.projection_horizon}): "
f"{proj_dir} at strength {proj_str:.2f}."
)
# Include top driving factors
non_divergence_factors = [
f for f in projection.driving_factors
if not f.startswith("DIVERGENCE:")
]
if non_divergence_factors:
factors_str = "; ".join(non_divergence_factors[:2])
parts.append(f"Key drivers: {factors_str}.")
if projection.diverges_from_current:
parts.append(
f"Note: projection diverges from current {direction} trend."
)
# Risks
if summary.material_risks:
risk_str = "; ".join(summary.material_risks[:2])
@@ -290,6 +374,7 @@ def build_recommendation(
reference_time: datetime | None = None,
llm_thesis: str | None = None,
suppression_result: SuppressionResult | None = None,
projection: TrendProjection | None = None,
) -> Recommendation:
"""Assemble a Recommendation object from a trend summary and eligibility result.
@@ -302,6 +387,10 @@ def build_recommendation(
If ``suppression_result`` indicates suppression, a suppression note
is appended to the thesis for audit visibility (Requirement 7.4).
If ``projection`` is provided and is not low-confidence, the thesis
incorporates projected direction and driving factors (Requirement 12.8).
The time_horizon may be refined based on the projection horizon.
"""
if reference_time is None:
reference_time = datetime.now(timezone.utc)
@@ -309,7 +398,7 @@ def build_recommendation(
# Combine evidence refs — supporting first, then opposing
evidence_refs = list(summary.top_supporting_evidence) + list(summary.top_opposing_evidence)
deterministic_thesis = build_thesis(summary, result)
deterministic_thesis = build_thesis(summary, result, projection=projection)
risk_class = classify_risk(summary, result)
# Use LLM-rewritten thesis if available, otherwise deterministic
@@ -324,6 +413,13 @@ def build_recommendation(
f"reasons={', '.join(reason_strs)})]"
)
# Determine time_horizon — refine with projection horizon if available
# (Requirement 12.8)
time_horizon = result.time_horizon
if projection is not None and not projection.low_confidence:
# Append projection horizon context to time_horizon
time_horizon = f"{result.time_horizon} (proj:{projection.projection_horizon})"
# Track whether the thesis was LLM-generated for audit
if llm_thesis:
provider = "ollama"
@@ -339,7 +435,7 @@ def build_recommendation(
action=result.action,
mode=result.mode,
confidence=summary.confidence,
time_horizon=result.time_horizon,
time_horizon=time_horizon,
thesis=f"[risk:{risk_class}] {thesis_body}",
invalidation_conditions=result.invalidation_conditions,
position_sizing=PositionSizing(
@@ -574,12 +670,13 @@ async def generate_recommendation(
Steps:
1. Fetch the latest trend summary for the ticker + window.
2. Evaluate data quality suppression (Requirement 7.4).
3. Evaluate eligibility using deterministic rules.
4. Build a Recommendation object with thesis and evidence.
2. Fetch the latest trend projection (Requirement 12.8, 12.9).
3. Evaluate data quality suppression (Requirement 7.4).
4. Evaluate eligibility using deterministic rules.
5. Build a Recommendation object with thesis and evidence.
- If ``ollama_config`` is provided, the deterministic thesis is
rewritten into analyst-quality prose via the LLM wording layer.
5. Persist the recommendation and evidence citations.
6. Persist the recommendation and evidence citations.
Returns the Recommendation, or None if no trend data exists.
"""
@@ -595,13 +692,23 @@ async def generate_recommendation(
logger.info("No trend data for %s/%s — skipping recommendation", ticker, window)
return None
# 2. Evaluate data quality suppression (Requirement 7.4)
# 2. Fetch latest trend projection (Requirement 12.8, 12.9)
projection = await fetch_latest_projection(pool, ticker, window)
# Exclude low-confidence projections from influencing recommendation
# eligibility (Requirement 12.9). The projection is still passed to
# build_recommendation for informational display, but marked as
# low_confidence so it won't affect thesis or time_horizon.
effective_projection = projection
if projection is not None and projection.low_confidence:
effective_projection = projection # still passed, but build_thesis checks low_confidence
# 3. Evaluate data quality suppression (Requirement 7.4)
quality_ctx = await fetch_data_quality_context(pool, ticker, window)
suppression = evaluate_suppression(
summary, quality_ctx=quality_ctx, config=sup_cfg, reference_time=reference_time,
)
# 3. Evaluate eligibility
# 4. Evaluate eligibility
result = evaluate_eligibility(summary, cfg)
# Apply suppression: force mode to informational if suppressed
@@ -616,10 +723,10 @@ async def generate_recommendation(
invalidation_conditions=result.invalidation_conditions,
)
# 4. Optional LLM thesis rewrite
# 5. Optional LLM thesis rewrite
llm_thesis: str | None = None
if ollama_config is not None:
deterministic_thesis = build_thesis(summary, result)
deterministic_thesis = build_thesis(summary, result, projection=effective_projection)
llm_thesis = await rewrite_thesis_with_llm(
deterministic_thesis=deterministic_thesis,
summary=summary,
@@ -630,13 +737,14 @@ async def generate_recommendation(
if llm_thesis == deterministic_thesis:
llm_thesis = None
# 5. Build recommendation
# 6. Build recommendation
rec = build_recommendation(
summary, result, reference_time, llm_thesis=llm_thesis,
suppression_result=suppression,
projection=effective_projection,
)
# 6. Persist recommendation, evidence citations, and risk evaluation
# 7. Persist recommendation, evidence citations, and risk evaluation
rec_id = await persist_recommendation(
pool,
rec,
@@ -645,7 +753,7 @@ async def generate_recommendation(
eligibility_result=result,
)
# 7. Publish prediction facts to analytical tables (Requirement 9.4)
# 8. Publish prediction facts to analytical tables (Requirement 9.4)
if minio_client is not None:
try:
lake_refs = publish_recommendation_facts(
@@ -667,10 +775,11 @@ async def generate_recommendation(
logger.info(
"Generated recommendation %s for %s: action=%s mode=%s confidence=%.3f "
"eligible=%s suppressed=%s quality_score=%.3f llm_thesis=%s",
"eligible=%s suppressed=%s quality_score=%.3f llm_thesis=%s projection=%s",
rec_id, ticker, rec.action.value, rec.mode.value, rec.confidence,
result.eligible, suppression.suppressed, suppression.data_quality_score,
llm_thesis is not None,
projection.projected_direction if projection else "none",
)
# Prometheus metrics
+78 -5
View File
@@ -50,6 +50,7 @@ DEFAULT_CADENCES: dict[str, int] = {
"filings_api": 3600,
"web_scrape": 1800,
"broker": 30,
"macro_news": 600,
}
# Default rate limits per source type (requests per minute)
@@ -59,6 +60,7 @@ DEFAULT_RATE_LIMITS: dict[str, int] = {
"filings_api": 10,
"web_scrape": 10,
"broker": 60,
"macro_news": 10,
}
# How long to wait before retrying a failed source (seconds)
@@ -141,9 +143,9 @@ def build_job_payload(
"""Build the ingestion job payload for a source."""
return {
"source_id": str(source["source_id"]),
"company_id": str(source["company_id"]),
"ticker": source["ticker"],
"legal_name": source["legal_name"],
"company_id": str(source["company_id"]) if source.get("company_id") else None,
"ticker": source.get("ticker") or "",
"legal_name": source.get("legal_name") or "",
"aliases": aliases,
"source_type": source["source_type"],
"source_name": source["source_name"],
@@ -183,7 +185,7 @@ async def check_rate_limit(
async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
"""Fetch all active sources joined with their active companies."""
"""Fetch all active company-specific sources joined with their active companies."""
return await pool.fetch(
"""SELECT s.id AS source_id,
s.company_id,
@@ -196,10 +198,33 @@ async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
FROM sources s
JOIN companies c ON s.company_id = c.id
WHERE s.active = TRUE AND c.active = TRUE
AND s.source_type != 'macro_news'
ORDER BY s.source_type, c.ticker"""
)
async def fetch_macro_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
"""Fetch all active macro news sources.
Macro sources are not company-specific — they have source_type='macro_news'
and may have company_id NULL. They are scheduled independently from
company-specific sources.
Requirements: 1.1
"""
return await pool.fetch(
"""SELECT s.id AS source_id,
s.company_id,
s.source_type,
s.source_name,
s.config,
s.credibility_score
FROM sources s
WHERE s.active = TRUE AND s.source_type = 'macro_news'
ORDER BY s.source_name"""
)
async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
"""Fetch all aliases for a company."""
rows = await pool.fetch(
@@ -287,9 +312,57 @@ async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
source_type, src["ticker"], src["source_name"],
)
# --- Schedule macro news sources (Requirement 1.1) ---
macro_sources = await fetch_macro_sources(pool)
for src in macro_sources:
source_id = src["source_id"]
source_type = src["source_type"]
source_config = _ensure_dict(src["config"])
last_run = await fetch_last_run(pool, source_id)
last_completed_at = None
last_status = None
retry_count = 0
next_retry_at = None
if last_run:
last_status = last_run["status"]
last_completed_at = last_run["completed_at"] or last_run["started_at"]
retry_count = last_run["retry_count"] or 0
next_retry_at = last_run["next_retry_at"]
if not is_source_due(
source_type=source_type,
source_config=source_config,
last_completed_at=last_completed_at,
last_status=last_status,
retry_count=retry_count,
next_retry_at=next_retry_at,
now=now,
):
skipped_not_due += 1
continue
if not await check_rate_limit(rds, source_type, now):
logger.warning(
"Rate limit hit for macro_news, skipping %s",
src["source_name"],
)
skipped_rate_limit += 1
continue
job = build_job_payload(src, [], now)
await rds.rpush(queue_key(QUEUE_INGESTION), json.dumps(job))
enqueued += 1
logger.debug(
"Enqueued macro_news job for %s", src["source_name"],
)
logger.info(
"Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
enqueued, skipped_not_due, skipped_rate_limit, len(sources),
enqueued, skipped_not_due, skipped_rate_limit, len(sources) + len(macro_sources),
)
return enqueued
+56
View File
@@ -110,6 +110,19 @@ BUCKET_RETENTION_FIELDS: dict[str, str] = {
}
@dataclass
class MacroConfig:
"""Configuration for the macro news interpolation layer.
Requirements: 5.6, 10.1, 10.2, 12.9
"""
macro_signal_weight: float = 0.3 # relative weight of macro vs company signals
macro_enabled: bool = True # runtime toggle state (default on)
macro_confidence_threshold: float = 0.4 # minimum confidence for event inclusion
macro_short_term_staleness_hours: int = 48 # hours after which short-term events get accelerated decay
projection_confidence_threshold: float = 0.3 # minimum confidence for projections to influence recommendations
@dataclass
class AlertingConfig:
"""Thresholds for operational alerting rules.
@@ -135,6 +148,26 @@ class AlertingConfig:
check_interval_seconds: int = 120
@dataclass
class CompetitiveConfig:
"""Configuration for the competitive intelligence & historical pattern matching layer.
Requirements: 5.6, 6.1, 9.1, 9.2, 11.2, 11.3
"""
competitive_signal_weight: float = 0.2
competitive_enabled: bool = True
pattern_confidence_threshold: float = 0.3
propagation_strength_threshold: float = 0.2
routine_lookback_days: int = 180
major_decision_lookback_days: int = 365
major_decision_weight_multiplier: float = 1.3
staleness_window_days: int = 180
staleness_recent_days: int = 90
staleness_decay_penalty: float = 0.5
min_pattern_samples: int = 3
propagation_failure_threshold: int = 5 # consecutive failures before operator alert
@dataclass
class AppConfig:
postgres: PostgresConfig = field(default_factory=PostgresConfig)
@@ -146,6 +179,8 @@ class AppConfig:
broker: BrokerConfig = field(default_factory=BrokerConfig)
retention: RetentionConfig = field(default_factory=RetentionConfig)
alerting: AlertingConfig = field(default_factory=AlertingConfig)
macro: MacroConfig = field(default_factory=MacroConfig)
competitive: CompetitiveConfig = field(default_factory=CompetitiveConfig)
log_level: str = "INFO"
json_logs: bool = True
@@ -222,6 +257,27 @@ def load_config() -> AppConfig:
broker_error_window_hours=int(os.getenv("ALERT_BROKER_ERROR_WINDOW_HOURS", "1")),
check_interval_seconds=int(os.getenv("ALERT_CHECK_INTERVAL_SECONDS", "120")),
),
macro=MacroConfig(
macro_signal_weight=float(os.getenv("MACRO_SIGNAL_WEIGHT", "0.3")),
macro_enabled=os.getenv("MACRO_ENABLED", "true").lower() == "true",
macro_confidence_threshold=float(os.getenv("MACRO_CONFIDENCE_THRESHOLD", "0.4")),
macro_short_term_staleness_hours=int(os.getenv("MACRO_SHORT_TERM_STALENESS_HOURS", "48")),
projection_confidence_threshold=float(os.getenv("PROJECTION_CONFIDENCE_THRESHOLD", "0.3")),
),
competitive=CompetitiveConfig(
competitive_signal_weight=float(os.getenv("COMPETITIVE_SIGNAL_WEIGHT", "0.2")),
competitive_enabled=os.getenv("COMPETITIVE_ENABLED", "true").lower() == "true",
pattern_confidence_threshold=float(os.getenv("COMPETITIVE_PATTERN_CONFIDENCE_THRESHOLD", "0.3")),
propagation_strength_threshold=float(os.getenv("COMPETITIVE_PROPAGATION_STRENGTH_THRESHOLD", "0.2")),
routine_lookback_days=int(os.getenv("COMPETITIVE_ROUTINE_LOOKBACK_DAYS", "180")),
major_decision_lookback_days=int(os.getenv("COMPETITIVE_MAJOR_DECISION_LOOKBACK_DAYS", "365")),
major_decision_weight_multiplier=float(os.getenv("COMPETITIVE_MAJOR_DECISION_WEIGHT_MULTIPLIER", "1.3")),
staleness_window_days=int(os.getenv("COMPETITIVE_STALENESS_WINDOW_DAYS", "180")),
staleness_recent_days=int(os.getenv("COMPETITIVE_STALENESS_RECENT_DAYS", "90")),
staleness_decay_penalty=float(os.getenv("COMPETITIVE_STALENESS_DECAY_PENALTY", "0.5")),
min_pattern_samples=int(os.getenv("COMPETITIVE_MIN_PATTERN_SAMPLES", "3")),
propagation_failure_threshold=int(os.getenv("COMPETITIVE_PROPAGATION_FAILURE_THRESHOLD", "5")),
),
log_level=os.getenv("LOG_LEVEL", "INFO"),
json_logs=os.getenv("JSON_LOGS", "true").lower() == "true",
)
+1
View File
@@ -214,6 +214,7 @@ def _resolve_document_type(source_type: str) -> str:
"news_api": "article",
"filings_api": "filing",
"web_scrape": "press_release",
"macro_news": "macro_event",
}
return mapping.get(source_type, "article")
+1
View File
@@ -64,3 +64,4 @@ QUEUE_RECOMMENDATION = "recommendation"
QUEUE_LAKE_PUBLISH = "lake_publish"
QUEUE_TRADE = "trade"
QUEUE_BROKER = "broker_orders"
QUEUE_MACRO_CLASSIFICATION = "macro_classification"
+159
View File
@@ -15,6 +15,7 @@ class DocumentType(str, Enum):
FILING = "filing"
TRANSCRIPT = "transcript"
PRESS_RELEASE = "press_release"
MACRO_EVENT = "macro_event"
class SourceType(str, Enum):
@@ -71,6 +72,37 @@ class TrendWindow(str, Enum):
NINETY_DAY = "90d"
class ImpactType(str, Enum):
SUPPLY_DISRUPTION = "supply_disruption"
DEMAND_SHIFT = "demand_shift"
COST_INCREASE = "cost_increase"
REGULATORY_PRESSURE = "regulatory_pressure"
CURRENCY_IMPACT = "currency_impact"
COMMODITY_SHOCK = "commodity_shock"
TRADE_BARRIER = "trade_barrier"
GEOPOLITICAL_RISK = "geopolitical_risk"
class SeverityLevel(str, Enum):
LOW = "low"
MODERATE = "moderate"
HIGH = "high"
CRITICAL = "critical"
class MarketPositionTier(str, Enum):
GLOBAL_LEADER = "global_leader"
MULTINATIONAL = "multinational"
REGIONAL = "regional"
DOMESTIC = "domestic"
class EstimatedDuration(str, Enum):
SHORT_TERM = "short_term"
MEDIUM_TERM = "medium_term"
LONG_TERM = "long_term"
# --- Document Intelligence ---
class CompanyImpact(BaseModel):
@@ -182,6 +214,63 @@ class Recommendation(BaseModel):
generated_at: datetime = Field(default_factory=datetime.utcnow)
# --- Global News Interpolation ---
class GlobalEventSchema(BaseModel):
event_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
event_types: List[ImpactType] = Field(default_factory=list)
severity: SeverityLevel = SeverityLevel.LOW
affected_regions: List[str] = Field(default_factory=list)
affected_sectors: List[str] = Field(default_factory=list)
affected_commodities: List[str] = Field(default_factory=list)
summary: str = ""
key_facts: List[str] = Field(default_factory=list)
estimated_duration: EstimatedDuration = EstimatedDuration.SHORT_TERM
confidence: float = Field(ge=0, le=1, default=0.5)
source_document_id: str = ""
model_metadata: ModelMetadata = Field(default_factory=ModelMetadata)
created_at: datetime = Field(default_factory=datetime.utcnow)
class MacroImpactRecordSchema(BaseModel):
event_id: str = ""
company_id: str = ""
ticker: str = ""
macro_impact_score: float = Field(ge=0, le=1, default=0.0)
impact_direction: str = "neutral"
contributing_factors: List[str] = Field(default_factory=list)
confidence: float = Field(ge=0, le=1, default=0.5)
computed_at: datetime = Field(default_factory=datetime.utcnow)
class ExposureProfileSchema(BaseModel):
company_id: str = ""
geographic_revenue_mix: dict[str, float] = Field(default_factory=dict)
supply_chain_regions: List[str] = Field(default_factory=list)
key_input_commodities: List[str] = Field(default_factory=list)
regulatory_jurisdictions: List[str] = Field(default_factory=list)
market_position_tier: MarketPositionTier = MarketPositionTier.REGIONAL
export_dependency_pct: float = Field(ge=0, le=1, default=0.0)
source: str = "manual"
confidence: float = Field(ge=0, le=1, default=1.0)
version: int = 1
active: bool = True
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class TrendProjectionSchema(BaseModel):
trend_window_id: str = ""
projected_direction: TrendDirection = TrendDirection.NEUTRAL
projected_strength: float = Field(ge=0, le=1, default=0.5)
projected_confidence: float = Field(ge=0, le=1, default=0.5)
projection_horizon: str = "7d"
driving_factors: List[str] = Field(default_factory=list)
macro_contribution_pct: float = Field(ge=0, le=1, default=0.0)
diverges_from_current: bool = False
computed_at: datetime = Field(default_factory=datetime.utcnow)
# --- Document Metadata ---
class StorageRefs(BaseModel):
@@ -204,3 +293,73 @@ class DocumentMetadata(BaseModel):
language: str = "en"
content_hash: str = ""
storage_refs: StorageRefs = Field(default_factory=StorageRefs)
# --- Competitive Intelligence & Historical Patterns ---
class RelationshipType(str, Enum):
DIRECT_RIVAL = "direct_rival"
SAME_SECTOR = "same_sector"
OVERLAPPING_PRODUCTS = "overlapping_products"
SUPPLY_CHAIN_ADJACENT = "supply_chain_adjacent"
class CatalystTier(str, Enum):
MAJOR_CORPORATE_DECISION = "major_corporate_decision"
ROUTINE_SIGNAL = "routine_signal"
# Major corporate decision catalyst types (Req 11.1)
MAJOR_DECISION_CATALYSTS: frozenset[str] = frozenset({
"m_and_a",
"legal",
"restructuring",
"leadership_change",
"strategic_pivot",
"buyback",
"dividend_change",
})
class CompetitorRelationshipSchema(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
company_a_id: str = ""
company_b_id: str = ""
relationship_type: RelationshipType = RelationshipType.DIRECT_RIVAL
strength: float = Field(ge=0, le=1, default=0.5)
bidirectional: bool = True
source: str = "manual"
active: bool = True
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
class CompetitiveSignalRecordSchema(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
source_document_id: str = ""
source_ticker: str = ""
target_ticker: str = ""
catalyst_type: str = ""
pattern_confidence: float = Field(ge=0, le=1, default=0.0)
signal_direction: str = "neutral"
signal_strength: float = Field(ge=0, le=1, default=0.0)
relationship_strength: float = Field(ge=0, le=1, default=0.0)
computed_at: datetime = Field(default_factory=datetime.utcnow)
class HistoricalPatternSchema(BaseModel):
source_ticker: str = ""
target_ticker: str = ""
catalyst_type: str = ""
time_horizon: str = "7d"
sample_count: int = 0
bullish_pct: float = Field(ge=0, le=1, default=0.0)
bearish_pct: float = Field(ge=0, le=1, default=0.0)
avg_strength: float = Field(ge=0, le=1, default=0.0)
avg_time_to_resolution: float = 0.0
pattern_confidence: float = Field(ge=0, le=1, default=0.0)
data_start: Optional[datetime] = None
data_end: Optional[datetime] = None
tier: CatalystTier = CatalystTier.ROUTINE_SIGNAL
insufficient_data: bool = False
+6 -1
View File
@@ -48,6 +48,7 @@ SOURCE_BUCKET_MAP: dict[str, str] = {
"filings_api": "stonks-raw-filings",
"web_scrape": "stonks-raw-news",
"broker": "stonks-raw-market",
"macro_news": "stonks-raw-news",
}
# Map artifact type to content type and file extension
@@ -75,10 +76,14 @@ def build_artifact_path(
"""Build a MinIO object path following the design convention.
Pattern: {source_type}/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/{artifact_name}.{ext}
For macro_news sources, uses macro/ prefix instead of ticker:
macro/{yyyy}/{mm}/{dd}/{document_id}/{artifact_name}.{ext}
"""
ts = timestamp or datetime.now(timezone.utc)
# Macro sources use macro/ prefix instead of ticker (Requirement 1.1)
path_prefix = "macro" if source_type == "macro_news" else f"{source_type}/{ticker}"
return (
f"{source_type}/{ticker}/"
f"{path_prefix}/"
f"{ts.year}/{ts.month:02d}/{ts.day:02d}/"
f"{document_id}/{artifact_name}.{ext}"
)
+6
View File
@@ -12,6 +12,9 @@ from pydantic import BaseModel, field_validator
from services.shared.config import load_config
from services.shared.db import get_pg_pool
from services.shared.logging import setup_logging
from services.symbol_registry.exposure import router as exposure_router
from services.symbol_registry.competitors import router as competitors_router
from services.symbol_registry.competitor_inference import router as inference_router
config = load_config()
pool: Optional[asyncpg.Pool] = None
@@ -36,6 +39,9 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="Stonks Oracle - Symbol Registry", lifespan=lifespan)
app.include_router(exposure_router)
app.include_router(competitors_router)
app.include_router(inference_router)
@app.get("/health")
@@ -0,0 +1,149 @@
"""Competitor auto-inference engine for the Symbol Registry API.
Identifies candidate competitors by sector/industry match and
document co-mention frequency, then upserts inferred relationships.
"""
import uuid
from datetime import datetime
from typing import Any, List
import asyncpg
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
router = APIRouter()
# --- Response Model ---
class CompetitorRelationship(BaseModel):
"""Response model for a competitor relationship."""
id: str
company_a_id: str
company_b_id: str
relationship_type: str
strength: float
bidirectional: bool
source: str
active: bool
created_at: datetime
updated_at: datetime
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
"""Convert asyncpg Record to dict with UUID→str coercion."""
d = dict(row)
for k, v in d.items():
if isinstance(v, uuid.UUID):
d[k] = str(v)
return d
def _get_pool(request: Request) -> asyncpg.Pool:
"""Get the database pool from the app module."""
from services.symbol_registry.app import pool
return pool
async def infer_competitors(
pool: asyncpg.Pool, company_id: str
) -> list[dict[str, Any]]:
"""Infer competitor relationships based on sector/industry match and co-mentions.
1. Fetch target company's sector and industry.
2. Find other active companies with the same sector AND industry.
3. Count co-mentions in document_company_mentions for each candidate.
4. Compute strength = 0.3 * sector_match + 0.7 * normalized_co_mention_count.
5. Upsert relationships with source='inferred'.
Returns the list of upserted relationship rows.
"""
# Fetch target company
target = await pool.fetchrow(
"SELECT id, sector, industry FROM companies WHERE id = $1 AND active = TRUE",
company_id,
)
if not target:
raise HTTPException(404, "Company not found")
if target["sector"] is None or target["industry"] is None:
raise HTTPException(
400,
"Company must have both sector and industry defined for auto-inference",
)
sector = target["sector"]
industry = target["industry"]
# Find candidates: other active companies with same sector AND industry
candidates = await pool.fetch(
"""SELECT id FROM companies
WHERE sector = $1 AND industry = $2 AND active = TRUE AND id != $3""",
sector, industry, company_id,
)
if not candidates:
return []
candidate_ids = [r["id"] for r in candidates]
# Count co-mentions for each candidate
co_mention_rows = await pool.fetch(
"""SELECT dcm2.company_id AS candidate_id, COUNT(DISTINCT dcm1.document_id) AS co_count
FROM document_company_mentions dcm1
JOIN document_company_mentions dcm2
ON dcm1.document_id = dcm2.document_id
WHERE dcm1.company_id = $1
AND dcm2.company_id = ANY($2::uuid[])
GROUP BY dcm2.company_id""",
company_id, candidate_ids,
)
co_mention_map: dict[Any, int] = {}
for row in co_mention_rows:
co_mention_map[row["candidate_id"]] = row["co_count"]
# Normalize co-mention counts
max_count = max(co_mention_map.values()) if co_mention_map else 1
if max_count == 0:
max_count = 1
# Compute strength and upsert for each candidate
results: list[dict[str, Any]] = []
for cid in candidate_ids:
co_count = co_mention_map.get(cid, 0)
normalized = co_count / max_count
# sector_match is always 1.0 since we filter by sector+industry
strength = 0.3 * 1.0 + 0.7 * normalized
# Order IDs for the unique index: LEAST/GREATEST
a_id = min(company_id, str(cid), key=lambda x: x)
b_id = max(company_id, str(cid), key=lambda x: x)
row = await pool.fetchrow(
"""INSERT INTO competitor_relationships
(company_a_id, company_b_id, relationship_type, strength,
bidirectional, source)
VALUES ($1, $2, 'same_sector', $3, TRUE, 'inferred')
ON CONFLICT (LEAST(company_a_id, company_b_id), GREATEST(company_a_id, company_b_id))
WHERE active = TRUE
DO UPDATE SET strength = EXCLUDED.strength, updated_at = NOW()
RETURNING id, company_a_id, company_b_id, relationship_type, strength,
bidirectional, source, active, created_at, updated_at""",
a_id, b_id, strength,
)
results.append(_row_dict(row))
# Sort by strength descending before returning
results.sort(key=lambda r: r["strength"], reverse=True)
return results
@router.post(
"/companies/{company_id}/competitors/infer",
response_model=List[CompetitorRelationship],
)
async def infer_competitors_endpoint(company_id: str, request: Request):
"""Trigger auto-inference of competitor relationships for a company."""
pool = _get_pool(request)
return await infer_competitors(pool, company_id)
+226
View File
@@ -0,0 +1,226 @@
"""Competitor Relationship management endpoints for the Symbol Registry API."""
import uuid
from datetime import datetime
from typing import Any, List
import asyncpg
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field, field_validator
from services.shared.audit import record_audit_event
router = APIRouter()
# --- Valid values ---
VALID_RELATIONSHIP_TYPES = {"direct_rival", "same_sector", "overlapping_products", "supply_chain_adjacent"}
VALID_SOURCES = {"manual", "inferred"}
# --- Request/Response Models ---
class CompetitorRelationshipCreate(BaseModel):
"""Request body for creating a competitor relationship."""
company_b_id: str
relationship_type: str
strength: float = Field(default=0.5, ge=0, le=1)
bidirectional: bool = True
source: str = "manual"
@field_validator("relationship_type")
@classmethod
def validate_relationship_type(cls, v: str) -> str:
if v not in VALID_RELATIONSHIP_TYPES:
raise ValueError(f"relationship_type must be one of {VALID_RELATIONSHIP_TYPES}")
return v
@field_validator("source")
@classmethod
def validate_source(cls, v: str) -> str:
if v not in VALID_SOURCES:
raise ValueError(f"source must be one of {VALID_SOURCES}")
return v
class CompetitorRelationship(BaseModel):
"""Response model for a competitor relationship."""
id: str
company_a_id: str
company_b_id: str
relationship_type: str
strength: float
bidirectional: bool
source: str
active: bool
created_at: datetime
updated_at: datetime
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
"""Convert asyncpg Record to dict with UUID→str coercion."""
d = dict(row)
for k, v in d.items():
if isinstance(v, uuid.UUID):
d[k] = str(v)
return d
def _get_pool(request: Request) -> asyncpg.Pool:
"""Get the database pool from the app module."""
from services.symbol_registry.app import pool
return pool
async def _company_exists(pool: asyncpg.Pool, company_id: str) -> bool:
"""Check if a company exists."""
return await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id) is not None
# --- Endpoints ---
@router.post("/companies/{company_id}/competitors", response_model=CompetitorRelationship, status_code=201)
async def create_competitor(company_id: str, body: CompetitorRelationshipCreate, request: Request):
"""Create a competitor relationship for a company."""
pool = _get_pool(request)
# Self-referencing check
if company_id == body.company_b_id:
raise HTTPException(400, "A company cannot be its own competitor")
# Check both companies exist
if not await _company_exists(pool, company_id):
raise HTTPException(404, "Company not found")
if not await _company_exists(pool, body.company_b_id):
raise HTTPException(404, "Competitor company not found")
try:
row = await pool.fetchrow(
"""INSERT INTO competitor_relationships
(company_a_id, company_b_id, relationship_type, strength, bidirectional, source)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, company_a_id, company_b_id, relationship_type, strength,
bidirectional, source, active, created_at, updated_at""",
company_id, body.company_b_id, body.relationship_type,
body.strength, body.bidirectional, body.source,
)
except asyncpg.UniqueViolationError:
raise HTTPException(409, "An active competitor relationship already exists between these companies")
result = _row_dict(row)
await record_audit_event(
pool,
event_type="competitor_relationship.created",
entity_type="competitor_relationship",
entity_id=result["id"],
data={
"company_a_id": company_id,
"company_b_id": body.company_b_id,
"relationship_type": body.relationship_type,
"strength": body.strength,
"bidirectional": body.bidirectional,
"source": body.source,
},
actor="operator",
)
return result
@router.get("/companies/{company_id}/competitors", response_model=List[CompetitorRelationship])
async def list_competitors(company_id: str, request: Request):
"""List active competitor relationships for a company, ordered by strength descending."""
pool = _get_pool(request)
if not await _company_exists(pool, company_id):
raise HTTPException(404, "Company not found")
rows = await pool.fetch(
"""SELECT id, company_a_id, company_b_id, relationship_type, strength,
bidirectional, source, active, created_at, updated_at
FROM competitor_relationships
WHERE (company_a_id = $1 OR company_b_id = $1) AND active = TRUE
ORDER BY strength DESC""",
company_id,
)
return [_row_dict(r) for r in rows]
@router.put("/companies/{company_id}/competitors/{relationship_id}", response_model=CompetitorRelationship)
async def update_competitor(company_id: str, relationship_id: str, body: CompetitorRelationshipCreate, request: Request):
"""Update a competitor relationship with audit event recording previous state."""
pool = _get_pool(request)
# Fetch existing relationship
existing = await pool.fetchrow(
"""SELECT id, company_a_id, company_b_id, relationship_type, strength,
bidirectional, source, active, created_at, updated_at
FROM competitor_relationships
WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2)""",
relationship_id, company_id,
)
if not existing:
raise HTTPException(404, "Competitor relationship not found")
previous_state = _row_dict(existing)
row = await pool.fetchrow(
"""UPDATE competitor_relationships
SET relationship_type = $2, strength = $3, bidirectional = $4, source = $5, updated_at = NOW()
WHERE id = $1
RETURNING id, company_a_id, company_b_id, relationship_type, strength,
bidirectional, source, active, created_at, updated_at""",
relationship_id, body.relationship_type, body.strength, body.bidirectional, body.source,
)
result = _row_dict(row)
await record_audit_event(
pool,
event_type="competitor_relationship.updated",
entity_type="competitor_relationship",
entity_id=result["id"],
data={
"previous_state": {
"relationship_type": previous_state["relationship_type"],
"strength": previous_state["strength"],
"bidirectional": previous_state["bidirectional"],
"source": previous_state["source"],
},
"new_state": {
"relationship_type": body.relationship_type,
"strength": body.strength,
"bidirectional": body.bidirectional,
"source": body.source,
},
},
actor="operator",
)
return result
@router.delete("/companies/{company_id}/competitors/{relationship_id}", status_code=200)
async def delete_competitor(company_id: str, relationship_id: str, request: Request):
"""Soft-delete a competitor relationship (set active=False), preserve row."""
pool = _get_pool(request)
row = await pool.fetchrow(
"""UPDATE competitor_relationships
SET active = FALSE, updated_at = NOW()
WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2) AND active = TRUE
RETURNING id""",
relationship_id, company_id,
)
if not row:
raise HTTPException(404, "Active competitor relationship not found")
await record_audit_event(
pool,
event_type="competitor_relationship.deleted",
entity_type="competitor_relationship",
entity_id=str(row["id"]),
data={"company_id": company_id, "soft_deleted": True},
actor="operator",
)
return {"status": "deleted", "id": str(row["id"])}
+183
View File
@@ -0,0 +1,183 @@
"""Exposure Profile management endpoints for the Symbol Registry API."""
import json
import uuid
from datetime import datetime
from typing import Any, List
import asyncpg
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field, field_validator
router = APIRouter()
# --- Valid values ---
VALID_MARKET_POSITION_TIERS = {"global_leader", "multinational", "regional", "domestic"}
VALID_SOURCES = {"manual", "inferred"}
# --- Request/Response Models ---
class ExposureProfileCreate(BaseModel):
"""Request body for creating/updating an exposure profile."""
geographic_revenue_mix: dict[str, float] = Field(default_factory=dict)
supply_chain_regions: List[str] = Field(default_factory=list)
key_input_commodities: List[str] = Field(default_factory=list)
regulatory_jurisdictions: List[str] = Field(default_factory=list)
market_position_tier: str = "regional"
export_dependency_pct: float = 0.0
source: str = "manual"
confidence: float = 1.0
@field_validator("market_position_tier")
@classmethod
def validate_tier(cls, v: str) -> str:
if v not in VALID_MARKET_POSITION_TIERS:
raise ValueError(f"market_position_tier must be one of {VALID_MARKET_POSITION_TIERS}")
return v
@field_validator("source")
@classmethod
def validate_source(cls, v: str) -> str:
if v not in VALID_SOURCES:
raise ValueError(f"source must be one of {VALID_SOURCES}")
return v
@field_validator("export_dependency_pct", "confidence")
@classmethod
def validate_pct(cls, v: float) -> float:
if not 0.0 <= v <= 1.0:
raise ValueError("Value must be between 0.0 and 1.0")
return v
class ExposureProfileResponse(BaseModel):
"""Response model for an exposure profile."""
id: str
company_id: str
geographic_revenue_mix: dict[str, float]
supply_chain_regions: List[str]
key_input_commodities: List[str]
regulatory_jurisdictions: List[str]
market_position_tier: str
export_dependency_pct: float
source: str
confidence: float
version: int
active: bool
created_at: datetime
updated_at: datetime
def _row_to_profile(row: asyncpg.Record) -> dict[str, Any]:
"""Convert an asyncpg Record to a profile response dict."""
d = dict(row)
for k, v in d.items():
if isinstance(v, uuid.UUID):
d[k] = str(v)
# geographic_revenue_mix is stored as JSONB string, parse if needed
if isinstance(d.get("geographic_revenue_mix"), str):
d["geographic_revenue_mix"] = json.loads(d["geographic_revenue_mix"])
return d
def _get_pool(request: Request) -> asyncpg.Pool:
"""Get the database pool from the app module."""
from services.symbol_registry.app import pool
return pool
# --- Endpoints ---
@router.get("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
async def get_exposure_profile(company_id: str, request: Request):
"""Get the current active exposure profile for a company."""
pool = _get_pool(request)
row = await pool.fetchrow(
"""SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
key_input_commodities, regulatory_jurisdictions, market_position_tier,
export_dependency_pct, source, confidence, version, active,
created_at, updated_at
FROM exposure_profiles
WHERE company_id = $1 AND active = TRUE
ORDER BY version DESC
LIMIT 1""",
company_id,
)
if not row:
raise HTTPException(404, "No active exposure profile found for this company")
return _row_to_profile(row)
@router.put("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
async def upsert_exposure_profile(company_id: str, body: ExposureProfileCreate, request: Request):
"""Create or update an exposure profile. Archives the previous active version."""
pool = _get_pool(request)
# Verify company exists
exists = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id)
if not exists:
raise HTTPException(404, "Company not found")
async with pool.acquire() as conn:
async with conn.transaction():
# Fetch current active profile to get latest version
current = await conn.fetchrow(
"""SELECT version FROM exposure_profiles
WHERE company_id = $1 AND active = TRUE
ORDER BY version DESC LIMIT 1""",
company_id,
)
if current:
new_version = current["version"] + 1
# Archive the current active profile
await conn.execute(
"""UPDATE exposure_profiles
SET active = FALSE, updated_at = NOW()
WHERE company_id = $1 AND active = TRUE""",
company_id,
)
else:
new_version = 1
# Insert new profile
row = await conn.fetchrow(
"""INSERT INTO exposure_profiles
(company_id, geographic_revenue_mix, supply_chain_regions,
key_input_commodities, regulatory_jurisdictions, market_position_tier,
export_dependency_pct, source, confidence, version, active)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, TRUE)
RETURNING id, company_id, geographic_revenue_mix, supply_chain_regions,
key_input_commodities, regulatory_jurisdictions, market_position_tier,
export_dependency_pct, source, confidence, version, active,
created_at, updated_at""",
company_id,
json.dumps(body.geographic_revenue_mix),
body.supply_chain_regions,
body.key_input_commodities,
body.regulatory_jurisdictions,
body.market_position_tier,
body.export_dependency_pct,
body.source,
body.confidence,
new_version,
)
return _row_to_profile(row)
@router.get("/companies/{company_id}/exposure/history", response_model=List[ExposureProfileResponse])
async def get_exposure_history(company_id: str, request: Request):
"""Get all exposure profile versions for a company, ordered by version descending."""
pool = _get_pool(request)
rows = await pool.fetch(
"""SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
key_input_commodities, regulatory_jurisdictions, market_position_tier,
export_dependency_pct, source, confidence, version, active,
created_at, updated_at
FROM exposure_profiles
WHERE company_id = $1
ORDER BY version DESC""",
company_id,
)
return [_row_to_profile(r) for r in rows]