feat: competitive intelligence & historical pattern matching layer
This commit is contained in:
@@ -0,0 +1,136 @@
|
||||
"""Macro news adapter for global/geopolitical news ingestion.
|
||||
|
||||
Fetches macro-level news articles from configured sources for global event
|
||||
classification. Reuses the same adapter pattern as company-specific news
|
||||
but targets macro-focused endpoints and does not require a ticker.
|
||||
|
||||
Requirements: 1.1, 1.2, 1.3, 1.4
|
||||
"""
|
||||
import hashlib
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import AdapterResult, BaseAdapter
|
||||
|
||||
logger = logging.getLogger("macro_news_adapter")
|
||||
|
||||
|
||||
class MacroNewsAdapter(BaseAdapter):
|
||||
"""Adapter for fetching macro/geopolitical news from configured sources.
|
||||
|
||||
Supports fetching from any HTTP endpoint that returns JSON with a list
|
||||
of news articles. The endpoint URL and response parsing are configured
|
||||
via the source config dict.
|
||||
|
||||
Config options:
|
||||
url: The endpoint URL to fetch from
|
||||
limit: Max articles to return per request (default 20)
|
||||
params: Additional query parameters as a dict
|
||||
results_key: JSON key containing the article list (default "results")
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str = "", base_url: str = "") -> None:
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url.rstrip("/") if base_url else ""
|
||||
|
||||
def source_type(self) -> str:
|
||||
return "macro_news"
|
||||
|
||||
async def fetch(self, ticker: str, config: dict[str, Any]) -> AdapterResult:
|
||||
"""Fetch macro news articles from the configured endpoint.
|
||||
|
||||
The ticker parameter is ignored for macro sources — these are
|
||||
global/geopolitical news, not company-specific.
|
||||
|
||||
Args:
|
||||
ticker: Ignored for macro sources (may be empty string).
|
||||
config: Source-specific configuration with url, params, etc.
|
||||
|
||||
Returns:
|
||||
AdapterResult with raw payload and parsed article items.
|
||||
"""
|
||||
url = config.get("url", "")
|
||||
if not url and self.base_url:
|
||||
url = self.base_url
|
||||
|
||||
if not url:
|
||||
return self._error_result("No URL configured for macro news source")
|
||||
|
||||
params = dict(config.get("params", {}))
|
||||
if self.api_key:
|
||||
params["apiKey"] = self.api_key
|
||||
|
||||
limit = config.get("limit", 20)
|
||||
params["limit"] = str(min(int(limit), 1000))
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
resp = await client.get(url, params=params)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
resp.raise_for_status()
|
||||
|
||||
raw = resp.content
|
||||
data = resp.json()
|
||||
content_hash = hashlib.sha256(raw).hexdigest()
|
||||
|
||||
results_key = config.get("results_key", "results")
|
||||
items = data.get(results_key, [])
|
||||
if not isinstance(items, list):
|
||||
items = []
|
||||
|
||||
return AdapterResult(
|
||||
source_type="macro_news",
|
||||
ticker="",
|
||||
items=items,
|
||||
raw_payload=raw,
|
||||
content_hash=content_hash,
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
http_status=resp.status_code,
|
||||
response_time_ms=round(elapsed_ms, 1),
|
||||
metadata={
|
||||
"provider": config.get("provider", "macro"),
|
||||
"results_count": len(items),
|
||||
},
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("Macro news HTTP error: %s", e)
|
||||
return self._error_result(
|
||||
str(e), elapsed_ms,
|
||||
http_status=e.response.status_code if e.response else None,
|
||||
raw=e.response.content if e.response else b"",
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("Macro news timeout: %s", e)
|
||||
return self._error_result(f"timeout: {e}", elapsed_ms)
|
||||
except Exception as e:
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
logger.error("Macro news fetch failed: %s", e)
|
||||
return self._error_result(str(e), elapsed_ms)
|
||||
|
||||
def _error_result(
|
||||
self,
|
||||
error: str,
|
||||
elapsed_ms: float = 0.0,
|
||||
http_status: int | None = None,
|
||||
raw: bytes = b"",
|
||||
) -> AdapterResult:
|
||||
"""Build an error AdapterResult for macro news fetches."""
|
||||
return AdapterResult(
|
||||
source_type="macro_news",
|
||||
ticker="",
|
||||
items=[],
|
||||
raw_payload=raw,
|
||||
content_hash="",
|
||||
fetched_at=datetime.now(timezone.utc),
|
||||
error=error,
|
||||
http_status=http_status,
|
||||
response_time_ms=round(elapsed_ms, 1),
|
||||
metadata={"provider": "macro"},
|
||||
)
|
||||
@@ -0,0 +1,741 @@
|
||||
"""Interpolation engine — macro-to-company impact scoring.
|
||||
|
||||
Computes per-company macro impact scores by evaluating overlap between
|
||||
global event classifications and company exposure profiles. Produces
|
||||
MacroImpactRecord objects that feed into the aggregation engine as
|
||||
additional weighted signals.
|
||||
|
||||
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 3.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.extractor.event_classifier import GlobalEvent
|
||||
from services.shared.schemas import (
|
||||
ExposureProfileSchema,
|
||||
MarketPositionTier,
|
||||
SeverityLevel,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("interpolation")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default configuration constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_CONFIDENCE_THRESHOLD = 0.4
|
||||
DEFAULT_SHORT_TERM_STALENESS_HOURS = 48
|
||||
ACCELERATED_DECAY_MULTIPLIER = 0.5 # applied on top of standard recency decay
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Severity weights
|
||||
SEVERITY_WEIGHTS: dict[str, float] = {
|
||||
SeverityLevel.CRITICAL.value: 1.0,
|
||||
SeverityLevel.HIGH.value: 0.75,
|
||||
SeverityLevel.MODERATE.value: 0.5,
|
||||
SeverityLevel.LOW.value: 0.25,
|
||||
}
|
||||
|
||||
# Component weights in the scoring formula
|
||||
GEO_WEIGHT = 0.35
|
||||
SUPPLY_WEIGHT = 0.25
|
||||
COMMODITY_WEIGHT = 0.25
|
||||
SECTOR_WEIGHT = 0.15
|
||||
|
||||
# Resilience modifiers for international events
|
||||
RESILIENCE_MODIFIERS: dict[str, float] = {
|
||||
MarketPositionTier.GLOBAL_LEADER.value: 0.7,
|
||||
MarketPositionTier.MULTINATIONAL.value: 0.85,
|
||||
MarketPositionTier.REGIONAL.value: 1.0,
|
||||
MarketPositionTier.DOMESTIC.value: 1.2,
|
||||
}
|
||||
|
||||
# Event types that are typically negative
|
||||
_NEGATIVE_EVENT_TYPES = frozenset({
|
||||
"supply_disruption",
|
||||
"cost_increase",
|
||||
"regulatory_pressure",
|
||||
"geopolitical_risk",
|
||||
"trade_barrier",
|
||||
})
|
||||
|
||||
# Event types that can be positive
|
||||
_POSITIVE_EVENT_TYPES = frozenset({
|
||||
"demand_shift",
|
||||
})
|
||||
|
||||
# Event types that can go either way
|
||||
_AMBIGUOUS_EVENT_TYPES = frozenset({
|
||||
"commodity_shock",
|
||||
"currency_impact",
|
||||
})
|
||||
|
||||
# Market cap bucket → market position tier mapping for default profiles
|
||||
_CAP_TO_TIER: dict[str, str] = {
|
||||
"large_cap": MarketPositionTier.GLOBAL_LEADER.value,
|
||||
"mid_cap": MarketPositionTier.MULTINATIONAL.value,
|
||||
"small_cap": MarketPositionTier.REGIONAL.value,
|
||||
"micro_cap": MarketPositionTier.DOMESTIC.value,
|
||||
}
|
||||
|
||||
# Sector-based default geographic revenue mixes
|
||||
_SECTOR_DEFAULT_GEO: dict[str, dict[str, float]] = {
|
||||
"Information Technology": {"US": 0.45, "CN": 0.15, "EU": 0.15, "JP": 0.10, "KR": 0.15},
|
||||
"Health Care": {"US": 0.50, "EU": 0.25, "JP": 0.10, "CN": 0.15},
|
||||
"Financials": {"US": 0.55, "EU": 0.20, "GB": 0.15, "JP": 0.10},
|
||||
"Energy": {"US": 0.30, "SA": 0.20, "RU": 0.15, "CA": 0.15, "AE": 0.20},
|
||||
"Materials": {"US": 0.25, "CN": 0.25, "AU": 0.20, "BR": 0.15, "IN": 0.15},
|
||||
"Industrials": {"US": 0.40, "DE": 0.15, "CN": 0.15, "JP": 0.15, "KR": 0.15},
|
||||
"Consumer Discretionary": {"US": 0.45, "CN": 0.20, "EU": 0.15, "JP": 0.10, "IN": 0.10},
|
||||
"Consumer Staples": {"US": 0.45, "EU": 0.20, "CN": 0.15, "IN": 0.10, "BR": 0.10},
|
||||
"Communication Services": {"US": 0.50, "CN": 0.15, "EU": 0.15, "JP": 0.10, "IN": 0.10},
|
||||
"Utilities": {"US": 0.70, "EU": 0.15, "JP": 0.15},
|
||||
"Real Estate": {"US": 0.60, "CN": 0.15, "EU": 0.15, "JP": 0.10},
|
||||
}
|
||||
|
||||
_DEFAULT_GEO = {"US": 0.50, "EU": 0.20, "CN": 0.15, "JP": 0.15}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MacroImpactRecord dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class MacroImpactRecord:
|
||||
"""A computed macro impact score for a specific company-event pair."""
|
||||
|
||||
event_id: str = ""
|
||||
company_id: str = ""
|
||||
ticker: str = ""
|
||||
macro_impact_score: float = 0.0 # [0, 1]
|
||||
impact_direction: str = "neutral" # positive|negative|mixed
|
||||
contributing_factors: list[str] = field(default_factory=list)
|
||||
confidence: float = 0.5 # [0, 1]
|
||||
computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Overlap computation functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_geographic_overlap(
|
||||
event_regions: list[str],
|
||||
revenue_mix: dict[str, float],
|
||||
) -> float:
|
||||
"""Compute geographic overlap using revenue percentage weighting.
|
||||
|
||||
For each event region that appears in the company's revenue mix,
|
||||
sum the revenue percentage. Returns a value in [0, 1].
|
||||
|
||||
Args:
|
||||
event_regions: Region codes from the global event.
|
||||
revenue_mix: Company's geographic_revenue_mix (region -> pct).
|
||||
|
||||
Returns:
|
||||
Sum of revenue percentages for overlapping regions, clamped to [0, 1].
|
||||
"""
|
||||
if not event_regions or not revenue_mix:
|
||||
return 0.0
|
||||
|
||||
event_set = {r.upper() for r in event_regions}
|
||||
overlap = 0.0
|
||||
for region, pct in revenue_mix.items():
|
||||
if region.upper() in event_set:
|
||||
overlap += pct
|
||||
|
||||
return min(max(overlap, 0.0), 1.0)
|
||||
|
||||
|
||||
def compute_supply_chain_overlap(
|
||||
event_regions: list[str],
|
||||
supply_regions: list[str],
|
||||
) -> float:
|
||||
"""Compute supply chain overlap using set intersection ratio.
|
||||
|
||||
Returns the fraction of the company's supply chain regions that
|
||||
overlap with the event's affected regions.
|
||||
|
||||
Args:
|
||||
event_regions: Region codes from the global event.
|
||||
supply_regions: Company's supply_chain_regions.
|
||||
|
||||
Returns:
|
||||
Intersection ratio in [0, 1]. 0.0 if supply_regions is empty.
|
||||
"""
|
||||
if not event_regions or not supply_regions:
|
||||
return 0.0
|
||||
|
||||
event_set = {r.upper() for r in event_regions}
|
||||
supply_set = {r.upper() for r in supply_regions}
|
||||
|
||||
intersection = event_set & supply_set
|
||||
return len(intersection) / len(supply_set)
|
||||
|
||||
|
||||
def compute_commodity_overlap(
|
||||
event_commodities: list[str],
|
||||
company_commodities: list[str],
|
||||
) -> float:
|
||||
"""Compute commodity overlap using set intersection ratio.
|
||||
|
||||
Returns the fraction of the company's key commodities that overlap
|
||||
with the event's affected commodities.
|
||||
|
||||
Args:
|
||||
event_commodities: Commodity identifiers from the global event.
|
||||
company_commodities: Company's key_input_commodities.
|
||||
|
||||
Returns:
|
||||
Intersection ratio in [0, 1]. 0.0 if company_commodities is empty.
|
||||
"""
|
||||
if not event_commodities or not company_commodities:
|
||||
return 0.0
|
||||
|
||||
event_set = {c.lower() for c in event_commodities}
|
||||
company_set = {c.lower() for c in company_commodities}
|
||||
|
||||
intersection = event_set & company_set
|
||||
return len(intersection) / len(company_set)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Resilience modifier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_resilience_modifier(
|
||||
raw_score: float,
|
||||
tier: str,
|
||||
event_is_international: bool = True,
|
||||
) -> float:
|
||||
"""Apply a resilience modifier based on market position tier.
|
||||
|
||||
For international events, global leaders get a dampening factor (0.7)
|
||||
while domestic companies get an amplification factor (1.2).
|
||||
For domestic-only events, no modifier is applied.
|
||||
|
||||
Args:
|
||||
raw_score: The raw impact score before resilience adjustment.
|
||||
tier: Market position tier value.
|
||||
event_is_international: Whether the event affects multiple countries.
|
||||
|
||||
Returns:
|
||||
Modified score clamped to [0, 1].
|
||||
"""
|
||||
if not event_is_international:
|
||||
return min(max(raw_score, 0.0), 1.0)
|
||||
|
||||
modifier = RESILIENCE_MODIFIERS.get(tier, 1.0)
|
||||
return min(max(raw_score * modifier, 0.0), 1.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Impact direction determination
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _determine_impact_direction(
|
||||
event_types: list[str],
|
||||
) -> tuple[str, list[str], list[str]]:
|
||||
"""Determine impact direction from event types.
|
||||
|
||||
Returns:
|
||||
Tuple of (direction, positive_factors, negative_factors).
|
||||
"""
|
||||
positive_factors: list[str] = []
|
||||
negative_factors: list[str] = []
|
||||
|
||||
for et in event_types:
|
||||
if et in _NEGATIVE_EVENT_TYPES:
|
||||
negative_factors.append(et)
|
||||
elif et in _POSITIVE_EVENT_TYPES:
|
||||
positive_factors.append(et)
|
||||
elif et in _AMBIGUOUS_EVENT_TYPES:
|
||||
# Ambiguous types contribute to both sides
|
||||
positive_factors.append(et)
|
||||
negative_factors.append(et)
|
||||
|
||||
has_positive = len(positive_factors) > 0
|
||||
has_negative = len(negative_factors) > 0
|
||||
|
||||
if has_positive and has_negative:
|
||||
return "mixed", positive_factors, negative_factors
|
||||
elif has_positive:
|
||||
return "positive", positive_factors, negative_factors
|
||||
elif has_negative:
|
||||
return "negative", positive_factors, negative_factors
|
||||
else:
|
||||
return "negative", positive_factors, negative_factors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core scoring function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_macro_impact(
|
||||
event: GlobalEvent,
|
||||
profile: ExposureProfileSchema,
|
||||
) -> MacroImpactRecord:
|
||||
"""Compute the macro impact of a global event on a company.
|
||||
|
||||
Scoring formula:
|
||||
raw_score = severity_weight * (
|
||||
0.35 * geographic_overlap +
|
||||
0.25 * supply_chain_overlap +
|
||||
0.25 * commodity_overlap +
|
||||
0.15 * sector_match
|
||||
)
|
||||
final_score = apply_resilience_modifier(raw_score, tier, is_international)
|
||||
|
||||
Args:
|
||||
event: The classified global event.
|
||||
profile: The company's exposure profile.
|
||||
|
||||
Returns:
|
||||
A MacroImpactRecord with the computed score and metadata.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Compute overlaps
|
||||
geo_overlap = compute_geographic_overlap(
|
||||
event.affected_regions,
|
||||
profile.geographic_revenue_mix,
|
||||
)
|
||||
supply_overlap = compute_supply_chain_overlap(
|
||||
event.affected_regions,
|
||||
profile.supply_chain_regions,
|
||||
)
|
||||
commodity_overlap = compute_commodity_overlap(
|
||||
event.affected_commodities,
|
||||
profile.key_input_commodities,
|
||||
)
|
||||
|
||||
# Sector match: 1.0 if any event sector matches the company's sector
|
||||
# We check against the profile's regulatory_jurisdictions as a proxy,
|
||||
# but the real sector comes from the company data. For now, we use
|
||||
# a simple heuristic: check if any affected_sectors appear in the
|
||||
# profile's geographic_revenue_mix keys or supply_chain_regions.
|
||||
# The actual sector is not stored in ExposureProfileSchema, so we
|
||||
# check if any event sectors match. This will be 0.0 unless the
|
||||
# caller provides sector info through contributing_factors.
|
||||
sector_match = 0.0
|
||||
# We'll compute sector_match based on event sectors — the caller
|
||||
# should ensure the profile has relevant sector info. For the
|
||||
# default implementation, we always set sector_match to 0.0 here
|
||||
# and let the caller override if needed.
|
||||
|
||||
# Check zero-overlap case
|
||||
contributing = []
|
||||
if geo_overlap > 0:
|
||||
contributing.append(f"geographic_overlap:{geo_overlap:.3f}")
|
||||
if supply_overlap > 0:
|
||||
contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}")
|
||||
if commodity_overlap > 0:
|
||||
contributing.append(f"commodity_overlap:{commodity_overlap:.3f}")
|
||||
|
||||
total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match
|
||||
if total_overlap == 0.0:
|
||||
return MacroImpactRecord(
|
||||
event_id=event.event_id,
|
||||
company_id=profile.company_id,
|
||||
ticker="",
|
||||
macro_impact_score=0.0,
|
||||
impact_direction="neutral",
|
||||
contributing_factors=[],
|
||||
confidence=0.0,
|
||||
computed_at=now,
|
||||
)
|
||||
|
||||
# Severity weight
|
||||
severity_weight = SEVERITY_WEIGHTS.get(event.severity, 0.25)
|
||||
|
||||
# Raw score
|
||||
raw_score = severity_weight * (
|
||||
GEO_WEIGHT * geo_overlap
|
||||
+ SUPPLY_WEIGHT * supply_overlap
|
||||
+ COMMODITY_WEIGHT * commodity_overlap
|
||||
+ SECTOR_WEIGHT * sector_match
|
||||
)
|
||||
|
||||
# Determine if event is international (affects multiple regions)
|
||||
is_international = len(event.affected_regions) > 1
|
||||
|
||||
# Apply resilience modifier
|
||||
tier = profile.market_position_tier
|
||||
if isinstance(tier, MarketPositionTier):
|
||||
tier = tier.value
|
||||
final_score = apply_resilience_modifier(raw_score, tier, is_international)
|
||||
|
||||
# Determine impact direction
|
||||
direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types)
|
||||
|
||||
# Build contributing factors list
|
||||
all_factors = list(contributing)
|
||||
if pos_factors:
|
||||
all_factors.append(f"positive_types:{','.join(pos_factors)}")
|
||||
if neg_factors:
|
||||
all_factors.append(f"negative_types:{','.join(neg_factors)}")
|
||||
|
||||
# Confidence: combine event confidence with overlap strength
|
||||
confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0)
|
||||
|
||||
return MacroImpactRecord(
|
||||
event_id=event.event_id,
|
||||
company_id=profile.company_id,
|
||||
ticker="",
|
||||
macro_impact_score=round(min(final_score, 1.0), 6),
|
||||
impact_direction=direction,
|
||||
contributing_factors=all_factors,
|
||||
confidence=round(confidence, 6),
|
||||
computed_at=now,
|
||||
)
|
||||
|
||||
|
||||
def compute_macro_impact_with_sector(
|
||||
event: GlobalEvent,
|
||||
profile: ExposureProfileSchema,
|
||||
company_sector: str = "",
|
||||
) -> MacroImpactRecord:
|
||||
"""Compute macro impact with explicit sector matching.
|
||||
|
||||
Like compute_macro_impact but accepts a company_sector parameter
|
||||
for proper sector_match computation.
|
||||
|
||||
Args:
|
||||
event: The classified global event.
|
||||
profile: The company's exposure profile.
|
||||
company_sector: The company's GICS sector name.
|
||||
|
||||
Returns:
|
||||
A MacroImpactRecord with the computed score and metadata.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Compute overlaps
|
||||
geo_overlap = compute_geographic_overlap(
|
||||
event.affected_regions,
|
||||
profile.geographic_revenue_mix,
|
||||
)
|
||||
supply_overlap = compute_supply_chain_overlap(
|
||||
event.affected_regions,
|
||||
profile.supply_chain_regions,
|
||||
)
|
||||
commodity_overlap = compute_commodity_overlap(
|
||||
event.affected_commodities,
|
||||
profile.key_input_commodities,
|
||||
)
|
||||
|
||||
# Sector match
|
||||
sector_match = 0.0
|
||||
if company_sector and event.affected_sectors:
|
||||
company_sector_lower = company_sector.lower().strip()
|
||||
for es in event.affected_sectors:
|
||||
if es.lower().strip() == company_sector_lower:
|
||||
sector_match = 1.0
|
||||
break
|
||||
|
||||
# Contributing factors
|
||||
contributing: list[str] = []
|
||||
if geo_overlap > 0:
|
||||
contributing.append(f"geographic_overlap:{geo_overlap:.3f}")
|
||||
if supply_overlap > 0:
|
||||
contributing.append(f"supply_chain_overlap:{supply_overlap:.3f}")
|
||||
if commodity_overlap > 0:
|
||||
contributing.append(f"commodity_overlap:{commodity_overlap:.3f}")
|
||||
if sector_match > 0:
|
||||
contributing.append(f"sector_match:{company_sector}")
|
||||
|
||||
total_overlap = geo_overlap + supply_overlap + commodity_overlap + sector_match
|
||||
if total_overlap == 0.0:
|
||||
return MacroImpactRecord(
|
||||
event_id=event.event_id,
|
||||
company_id=profile.company_id,
|
||||
ticker="",
|
||||
macro_impact_score=0.0,
|
||||
impact_direction="neutral",
|
||||
contributing_factors=[],
|
||||
confidence=0.0,
|
||||
computed_at=now,
|
||||
)
|
||||
|
||||
# Severity weight
|
||||
severity_weight = SEVERITY_WEIGHTS.get(event.severity, 0.25)
|
||||
|
||||
# Raw score
|
||||
raw_score = severity_weight * (
|
||||
GEO_WEIGHT * geo_overlap
|
||||
+ SUPPLY_WEIGHT * supply_overlap
|
||||
+ COMMODITY_WEIGHT * commodity_overlap
|
||||
+ SECTOR_WEIGHT * sector_match
|
||||
)
|
||||
|
||||
# International check
|
||||
is_international = len(event.affected_regions) > 1
|
||||
|
||||
# Resilience modifier
|
||||
tier = profile.market_position_tier
|
||||
if isinstance(tier, MarketPositionTier):
|
||||
tier = tier.value
|
||||
final_score = apply_resilience_modifier(raw_score, tier, is_international)
|
||||
|
||||
# Direction
|
||||
direction, pos_factors, neg_factors = _determine_impact_direction(event.event_types)
|
||||
|
||||
all_factors = list(contributing)
|
||||
if pos_factors:
|
||||
all_factors.append(f"positive_types:{','.join(pos_factors)}")
|
||||
if neg_factors:
|
||||
all_factors.append(f"negative_types:{','.join(neg_factors)}")
|
||||
|
||||
confidence = min(event.confidence * min(total_overlap + 0.3, 1.0), 1.0)
|
||||
|
||||
return MacroImpactRecord(
|
||||
event_id=event.event_id,
|
||||
company_id=profile.company_id,
|
||||
ticker="",
|
||||
macro_impact_score=round(min(final_score, 1.0), 6),
|
||||
impact_direction=direction,
|
||||
contributing_factors=all_factors,
|
||||
confidence=round(confidence, 6),
|
||||
computed_at=now,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default profile builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def build_default_profile(
|
||||
sector: str,
|
||||
industry: str,
|
||||
market_cap_bucket: str,
|
||||
) -> ExposureProfileSchema:
|
||||
"""Build a default ExposureProfile for companies without manual profiles.
|
||||
|
||||
Uses sector-based geographic revenue defaults and maps market_cap_bucket
|
||||
to market_position_tier:
|
||||
large_cap → global_leader
|
||||
mid_cap → multinational
|
||||
small_cap → regional
|
||||
micro_cap → domestic
|
||||
|
||||
Args:
|
||||
sector: GICS sector name.
|
||||
industry: Industry name (used for commodity defaults).
|
||||
market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap.
|
||||
|
||||
Returns:
|
||||
An ExposureProfileSchema with source='inferred'.
|
||||
"""
|
||||
tier = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value)
|
||||
geo_mix = _SECTOR_DEFAULT_GEO.get(sector, _DEFAULT_GEO)
|
||||
|
||||
# Derive supply chain regions from geo mix keys
|
||||
supply_regions = list(geo_mix.keys())
|
||||
|
||||
# Derive commodities from sector/industry
|
||||
commodities = _infer_commodities(sector, industry)
|
||||
|
||||
# Export dependency based on tier
|
||||
export_pct = {
|
||||
MarketPositionTier.GLOBAL_LEADER.value: 0.5,
|
||||
MarketPositionTier.MULTINATIONAL.value: 0.35,
|
||||
MarketPositionTier.REGIONAL.value: 0.15,
|
||||
MarketPositionTier.DOMESTIC.value: 0.05,
|
||||
}.get(tier, 0.15)
|
||||
|
||||
return ExposureProfileSchema(
|
||||
company_id="",
|
||||
geographic_revenue_mix=dict(geo_mix),
|
||||
supply_chain_regions=supply_regions,
|
||||
key_input_commodities=commodities,
|
||||
regulatory_jurisdictions=list(geo_mix.keys())[:3],
|
||||
market_position_tier=MarketPositionTier(tier),
|
||||
export_dependency_pct=export_pct,
|
||||
source="inferred",
|
||||
confidence=0.5,
|
||||
version=1,
|
||||
)
|
||||
|
||||
|
||||
def _infer_commodities(sector: str, industry: str) -> list[str]:
|
||||
"""Infer key input commodities from sector and industry."""
|
||||
sector_commodities: dict[str, list[str]] = {
|
||||
"Energy": ["crude_oil", "natural_gas"],
|
||||
"Materials": ["copper", "steel", "lithium"],
|
||||
"Industrials": ["steel", "copper"],
|
||||
"Information Technology": ["semiconductors", "lithium"],
|
||||
"Consumer Staples": ["wheat", "corn"],
|
||||
"Consumer Discretionary": ["steel", "semiconductors"],
|
||||
"Health Care": [],
|
||||
"Financials": [],
|
||||
"Communication Services": [],
|
||||
"Utilities": ["natural_gas"],
|
||||
"Real Estate": ["steel"],
|
||||
}
|
||||
return sector_commodities.get(sector, [])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PostgreSQL persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def persist_macro_impact_record(
|
||||
pool: asyncpg.Pool,
|
||||
record: MacroImpactRecord,
|
||||
) -> str:
|
||||
"""Persist a MacroImpactRecord to the macro_impact_records table.
|
||||
|
||||
Returns the row UUID.
|
||||
"""
|
||||
row_id = await pool.fetchval(
|
||||
"""INSERT INTO macro_impact_records
|
||||
(event_id, company_id, ticker, macro_impact_score,
|
||||
impact_direction, contributing_factors, confidence, computed_at)
|
||||
VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6::jsonb, $7, $8)
|
||||
RETURNING id""",
|
||||
record.event_id,
|
||||
record.company_id,
|
||||
record.ticker,
|
||||
record.macro_impact_score,
|
||||
record.impact_direction,
|
||||
json.dumps(record.contributing_factors),
|
||||
record.confidence,
|
||||
record.computed_at,
|
||||
)
|
||||
logger.info(
|
||||
"Persisted macro impact record for event=%s company=%s score=%.4f direction=%s",
|
||||
record.event_id,
|
||||
record.company_id,
|
||||
record.macro_impact_score,
|
||||
record.impact_direction,
|
||||
)
|
||||
return str(row_id)
|
||||
|
||||
|
||||
async def persist_macro_impact_records(
|
||||
pool: asyncpg.Pool,
|
||||
records: list[MacroImpactRecord],
|
||||
) -> list[str]:
|
||||
"""Persist multiple MacroImpactRecords. Returns list of row UUIDs."""
|
||||
ids: list[str] = []
|
||||
for record in records:
|
||||
if record.macro_impact_score > 0.0:
|
||||
row_id = await persist_macro_impact_record(pool, record)
|
||||
ids.append(row_id)
|
||||
return ids
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-confidence event exclusion (Requirements: 10.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def filter_low_confidence_events(
|
||||
events: list[GlobalEvent],
|
||||
confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
|
||||
) -> list[GlobalEvent]:
|
||||
"""Filter out events with confidence below the configurable threshold.
|
||||
|
||||
Events with confidence below the threshold are excluded from macro
|
||||
impact computation and the exclusion reason is logged.
|
||||
|
||||
Args:
|
||||
events: List of GlobalEvent classifications.
|
||||
confidence_threshold: Minimum confidence for inclusion (default 0.4).
|
||||
|
||||
Returns:
|
||||
List of events that pass the confidence threshold.
|
||||
|
||||
Requirements: 10.1
|
||||
"""
|
||||
included: list[GlobalEvent] = []
|
||||
for event in events:
|
||||
if event.confidence < confidence_threshold:
|
||||
logger.info(
|
||||
"Excluding low-confidence event %s: confidence=%.3f < threshold=%.3f",
|
||||
event.event_id,
|
||||
event.confidence,
|
||||
confidence_threshold,
|
||||
)
|
||||
else:
|
||||
included.append(event)
|
||||
return included
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Accelerated decay for stale short-term events (Requirements: 10.2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_standard_recency_decay(
|
||||
age_hours: float,
|
||||
half_life_hours: float = 168.0,
|
||||
) -> float:
|
||||
"""Compute standard exponential recency decay.
|
||||
|
||||
Args:
|
||||
age_hours: Age of the event in hours.
|
||||
half_life_hours: Half-life for the decay function (default 7 days).
|
||||
|
||||
Returns:
|
||||
Decay factor in (0, 1].
|
||||
"""
|
||||
import math
|
||||
if age_hours <= 0:
|
||||
return 1.0
|
||||
return math.exp(-0.693 * age_hours / half_life_hours)
|
||||
|
||||
|
||||
def apply_accelerated_decay(
|
||||
age_hours: float,
|
||||
estimated_duration: str,
|
||||
staleness_hours: float = DEFAULT_SHORT_TERM_STALENESS_HOURS,
|
||||
half_life_hours: float = 168.0,
|
||||
) -> float:
|
||||
"""Apply accelerated decay for stale short-term events.
|
||||
|
||||
For short_term events older than staleness_hours (default 48h),
|
||||
the effective weight is strictly less than standard recency decay.
|
||||
|
||||
For non-short-term events or events within the staleness window,
|
||||
standard recency decay is applied.
|
||||
|
||||
Args:
|
||||
age_hours: Age of the event in hours.
|
||||
estimated_duration: Event's estimated_duration field.
|
||||
staleness_hours: Hours after which short-term events get accelerated decay.
|
||||
half_life_hours: Half-life for the standard decay function.
|
||||
|
||||
Returns:
|
||||
Effective signal weight in (0, 1].
|
||||
|
||||
Requirements: 10.2
|
||||
"""
|
||||
standard_decay = compute_standard_recency_decay(age_hours, half_life_hours)
|
||||
|
||||
if estimated_duration == "short_term" and age_hours > staleness_hours:
|
||||
# Apply accelerated decay: multiply standard decay by a factor < 1
|
||||
accelerated = standard_decay * ACCELERATED_DECAY_MULTIPLIER
|
||||
logger.debug(
|
||||
"Accelerated decay for short_term event: age=%.1fh, "
|
||||
"standard=%.4f, accelerated=%.4f",
|
||||
age_hours, standard_decay, accelerated,
|
||||
)
|
||||
return accelerated
|
||||
|
||||
return standard_decay
|
||||
@@ -1,4 +1,12 @@
|
||||
"""Aggregation worker entrypoint - polls Redis for aggregation jobs."""
|
||||
"""Aggregation worker entrypoint - polls Redis for aggregation jobs.
|
||||
|
||||
After computing trend summaries for a ticker, the worker also triggers
|
||||
competitive signal propagation for the ticker's competitors when the
|
||||
competitive layer is enabled. This ensures that document intelligence
|
||||
for one company produces competitive signals for related companies.
|
||||
|
||||
Requirements: 4.1, 5.1, 9.4
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@@ -8,8 +16,9 @@ import logging
|
||||
import asyncpg
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
from services.aggregation.worker import aggregate_company
|
||||
from services.shared.config import load_config
|
||||
from services.aggregation.signal_propagation import propagate_signals
|
||||
from services.aggregation.worker import aggregate_company, fetch_competitive_enabled
|
||||
from services.shared.config import CompetitiveConfig, load_config
|
||||
from services.shared.logging import inject_trace_context, setup_logging
|
||||
from services.shared.redis_keys import (
|
||||
QUEUE_AGGREGATION,
|
||||
@@ -20,6 +29,92 @@ from services.shared.redis_keys import (
|
||||
logger = logging.getLogger("aggregation_main")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Query to fetch recent document intelligence records for a ticker.
|
||||
# Used to trigger signal propagation after aggregation completes.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_RECENT_INTELLIGENCE_QUERY = """
|
||||
SELECT
|
||||
di.document_id,
|
||||
dir.catalyst_type,
|
||||
dir.impact_score
|
||||
FROM document_impact_records dir
|
||||
JOIN document_intelligence di ON di.id = dir.intelligence_id
|
||||
JOIN documents d ON d.id = di.document_id
|
||||
WHERE dir.ticker = $1
|
||||
AND di.validation_status = 'valid'
|
||||
AND d.status != 'rejected'
|
||||
ORDER BY d.published_at DESC
|
||||
LIMIT 50
|
||||
"""
|
||||
|
||||
|
||||
# Track consecutive propagation failures for alerting (Requirement 9.4)
|
||||
_propagation_consecutive_failures = 0
|
||||
|
||||
|
||||
async def _trigger_signal_propagation(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
competitive_config: CompetitiveConfig,
|
||||
) -> int:
|
||||
"""Trigger competitive signal propagation for a ticker's recent documents.
|
||||
|
||||
Fetches recent document intelligence records for the ticker and calls
|
||||
propagate_signals for each, producing competitive signals for the
|
||||
ticker's competitors.
|
||||
|
||||
Returns the total number of competitive signals produced.
|
||||
"""
|
||||
global _propagation_consecutive_failures
|
||||
|
||||
rows = await pool.fetch(_RECENT_INTELLIGENCE_QUERY, ticker)
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
total_signals = 0
|
||||
for row in rows:
|
||||
document_id = str(row["document_id"])
|
||||
catalyst_type = row["catalyst_type"] or "other"
|
||||
impact_score = float(row["impact_score"] or 0.0)
|
||||
|
||||
if impact_score <= 0.0:
|
||||
continue
|
||||
|
||||
try:
|
||||
records = await propagate_signals(
|
||||
pool=pool,
|
||||
ticker=ticker,
|
||||
catalyst_type=catalyst_type,
|
||||
impact_score=impact_score,
|
||||
document_id=document_id,
|
||||
config=competitive_config,
|
||||
)
|
||||
total_signals += len(records)
|
||||
|
||||
# Reset failure counter on success
|
||||
_propagation_consecutive_failures = 0
|
||||
|
||||
except Exception:
|
||||
_propagation_consecutive_failures += 1
|
||||
logger.exception(
|
||||
"Signal propagation failed for %s doc %s/%s",
|
||||
ticker, document_id, catalyst_type,
|
||||
)
|
||||
if _propagation_consecutive_failures >= competitive_config.propagation_failure_threshold:
|
||||
logger.critical(
|
||||
"ALERT: Sustained signal propagation failures (%d consecutive). "
|
||||
"Continuing with company-specific + macro signals only. "
|
||||
"Operator action required.",
|
||||
_propagation_consecutive_failures,
|
||||
)
|
||||
# Stop trying propagation for this ticker after threshold
|
||||
break
|
||||
|
||||
return total_signals
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("aggregation", level=config.log_level, json_output=config.json_logs)
|
||||
@@ -28,6 +123,7 @@ async def main() -> None:
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_AGGREGATION)
|
||||
rec_queue = queue_key(QUEUE_RECOMMENDATION)
|
||||
competitive_config = config.competitive
|
||||
logger.info("Aggregation worker started, polling %s", queue)
|
||||
|
||||
try:
|
||||
@@ -49,6 +145,32 @@ async def main() -> None:
|
||||
ticker, len(summaries),
|
||||
)
|
||||
|
||||
# Trigger competitive signal propagation after aggregation
|
||||
# (Requirement 4.1): When new document intelligence is
|
||||
# produced for a company, propagate signals to competitors.
|
||||
# Check toggle state from DB (same pattern as macro toggle).
|
||||
competitive_enabled = competitive_config.competitive_enabled
|
||||
db_toggle = await fetch_competitive_enabled(pool)
|
||||
if db_toggle is not None:
|
||||
competitive_enabled = db_toggle
|
||||
|
||||
if competitive_enabled:
|
||||
try:
|
||||
sig_count = await _trigger_signal_propagation(
|
||||
pool, ticker, competitive_config,
|
||||
)
|
||||
if sig_count > 0:
|
||||
logger.info(
|
||||
"Propagated %d competitive signals for %s",
|
||||
sig_count, ticker,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Signal propagation failed for %s — "
|
||||
"continuing with company+macro signals only",
|
||||
ticker,
|
||||
)
|
||||
|
||||
# Enqueue recommendation job for each window that produced a trend
|
||||
for summary in summaries:
|
||||
if summary.trend_strength > 0:
|
||||
|
||||
@@ -0,0 +1,414 @@
|
||||
"""Historical pattern mining for competitive intelligence.
|
||||
|
||||
Queries document_impact_records joined with trend_windows to find how
|
||||
similar catalyst types resolved historically for a company or its
|
||||
competitors. Produces HistoricalPattern objects consumed by the signal
|
||||
propagation engine and the aggregation worker.
|
||||
|
||||
Requirements: 3.1, 3.2, 3.3, 3.4, 3.5, 11.1, 11.2, 11.3, 11.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.shared.config import CompetitiveConfig
|
||||
from services.shared.schemas import MAJOR_DECISION_CATALYSTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_HORIZONS = ["1d", "7d", "30d"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class HistoricalPattern:
|
||||
"""Statistical summary of how a catalyst type resolved historically."""
|
||||
|
||||
source_ticker: str
|
||||
target_ticker: str
|
||||
catalyst_type: str
|
||||
time_horizon: str # 1d | 7d | 30d
|
||||
sample_count: int
|
||||
bullish_pct: float # [0, 1]
|
||||
bearish_pct: float # [0, 1]
|
||||
avg_strength: float # [0, 1]
|
||||
avg_time_to_resolution: float # days
|
||||
pattern_confidence: float # [0, 1]
|
||||
data_start: datetime
|
||||
data_end: datetime
|
||||
tier: str # major_corporate_decision | routine_signal
|
||||
insufficient_data: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Catalyst tier classification (Req 11.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def classify_catalyst_tier(catalyst_type: str) -> str:
|
||||
"""Deterministic mapping of catalyst_type to tier.
|
||||
|
||||
Returns ``"major_corporate_decision"`` for catalyst types in
|
||||
MAJOR_DECISION_CATALYSTS, otherwise ``"routine_signal"``.
|
||||
"""
|
||||
if catalyst_type in MAJOR_DECISION_CATALYSTS:
|
||||
return "major_corporate_decision"
|
||||
return "routine_signal"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pattern confidence (Req 3.3, 11.2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compute_pattern_confidence(
|
||||
sample_count: int,
|
||||
outcome_consistency: float,
|
||||
data_recency_days: float,
|
||||
tier: str,
|
||||
config: Optional[CompetitiveConfig] = None,
|
||||
) -> float:
|
||||
"""Compute pattern confidence score in [0, 1].
|
||||
|
||||
Formula:
|
||||
sample_factor * 0.4 + consistency * 0.4 + recency_factor * 0.2
|
||||
|
||||
With a 1.3× multiplier for ``major_corporate_decision`` tier,
|
||||
insufficient-data cap, and staleness decay.
|
||||
"""
|
||||
cfg = config or CompetitiveConfig()
|
||||
|
||||
# --- component factors ---
|
||||
sample_factor = min(sample_count / 20.0, 1.0)
|
||||
consistency = outcome_consistency # already max(bullish_pct, bearish_pct)
|
||||
|
||||
if data_recency_days <= cfg.staleness_recent_days:
|
||||
recency_factor = 1.0
|
||||
elif data_recency_days <= cfg.staleness_window_days:
|
||||
recency_factor = 0.7
|
||||
else:
|
||||
recency_factor = 0.4
|
||||
|
||||
confidence = sample_factor * 0.4 + consistency * 0.4 + recency_factor * 0.2
|
||||
|
||||
# Major-decision multiplier (Req 11.2)
|
||||
if tier == "major_corporate_decision":
|
||||
confidence *= cfg.major_decision_weight_multiplier
|
||||
|
||||
# Clamp to [0, 1]
|
||||
confidence = min(max(confidence, 0.0), 1.0)
|
||||
|
||||
# Insufficient data cap (Req 3.4)
|
||||
if sample_count < cfg.min_pattern_samples:
|
||||
confidence = min(confidence, 0.25)
|
||||
|
||||
# Staleness decay (Req 9.2)
|
||||
if data_recency_days > cfg.staleness_window_days:
|
||||
confidence *= cfg.staleness_decay_penalty
|
||||
|
||||
return confidence
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lookback helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _lookback_days(tier: str, config: Optional[CompetitiveConfig] = None) -> int:
|
||||
"""Return the lookback window in days for the given tier."""
|
||||
cfg = config or CompetitiveConfig()
|
||||
if tier == "major_corporate_decision":
|
||||
return cfg.major_decision_lookback_days
|
||||
return cfg.routine_lookback_days
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL: self-company pattern query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SELF_PATTERN_QUERY = """
|
||||
WITH matched_docs AS (
|
||||
SELECT
|
||||
dir.id AS dir_id,
|
||||
d.published_at,
|
||||
dir.sentiment
|
||||
FROM document_impact_records dir
|
||||
JOIN document_intelligence di ON di.id = dir.intelligence_id
|
||||
JOIN documents d ON d.id = di.document_id
|
||||
WHERE dir.ticker = $1
|
||||
AND dir.catalyst_type = $2
|
||||
AND di.validation_status = 'valid'
|
||||
AND d.status != 'rejected'
|
||||
AND d.published_at >= $3
|
||||
AND d.published_at <= $4
|
||||
)
|
||||
SELECT
|
||||
md.dir_id,
|
||||
md.published_at,
|
||||
md.sentiment,
|
||||
tw.trend_direction,
|
||||
tw.trend_strength,
|
||||
tw.generated_at,
|
||||
tw."window" AS tw_window
|
||||
FROM matched_docs md
|
||||
JOIN trend_windows tw
|
||||
ON tw.entity_type = 'company'
|
||||
AND tw.entity_id = $1
|
||||
AND tw."window" = $5
|
||||
AND tw.generated_at >= md.published_at
|
||||
AND tw.generated_at <= md.published_at + $6::interval
|
||||
ORDER BY md.published_at DESC
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL: cross-company pattern query
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CROSS_PATTERN_QUERY = """
|
||||
WITH matched_docs AS (
|
||||
SELECT
|
||||
dir.id AS dir_id,
|
||||
d.published_at,
|
||||
dir.sentiment
|
||||
FROM document_impact_records dir
|
||||
JOIN document_intelligence di ON di.id = dir.intelligence_id
|
||||
JOIN documents d ON d.id = di.document_id
|
||||
WHERE dir.ticker = $1
|
||||
AND dir.catalyst_type = $2
|
||||
AND di.validation_status = 'valid'
|
||||
AND d.status != 'rejected'
|
||||
AND d.published_at >= $3
|
||||
AND d.published_at <= $4
|
||||
)
|
||||
SELECT
|
||||
md.dir_id,
|
||||
md.published_at,
|
||||
md.sentiment,
|
||||
tw.trend_direction,
|
||||
tw.trend_strength,
|
||||
tw.generated_at,
|
||||
tw."window" AS tw_window
|
||||
FROM matched_docs md
|
||||
JOIN trend_windows tw
|
||||
ON tw.entity_type = 'company'
|
||||
AND tw.entity_id = $5
|
||||
AND tw."window" = $6
|
||||
AND tw.generated_at >= md.published_at
|
||||
AND tw.generated_at <= md.published_at + $7::interval
|
||||
ORDER BY md.published_at DESC
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Horizon → interval mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HORIZON_INTERVALS: dict[str, str] = {
|
||||
"1d": "1 day",
|
||||
"7d": "7 days",
|
||||
"30d": "30 days",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build HistoricalPattern from query rows
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_pattern(
|
||||
rows: list[asyncpg.Record],
|
||||
source_ticker: str,
|
||||
target_ticker: str,
|
||||
catalyst_type: str,
|
||||
horizon: str,
|
||||
tier: str,
|
||||
config: Optional[CompetitiveConfig] = None,
|
||||
) -> Optional[HistoricalPattern]:
|
||||
"""Aggregate query rows into a single HistoricalPattern."""
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
# De-duplicate by dir_id — keep the first (closest) trend_window per doc
|
||||
seen: set[str] = set()
|
||||
unique_rows: list[asyncpg.Record] = []
|
||||
for r in rows:
|
||||
rid = str(r["dir_id"])
|
||||
if rid not in seen:
|
||||
seen.add(rid)
|
||||
unique_rows.append(r)
|
||||
|
||||
sample_count = len(unique_rows)
|
||||
|
||||
bullish = sum(1 for r in unique_rows if r["trend_direction"] == "bullish")
|
||||
bearish = sum(1 for r in unique_rows if r["trend_direction"] == "bearish")
|
||||
bullish_pct = bullish / sample_count
|
||||
bearish_pct = bearish / sample_count
|
||||
|
||||
strengths = [float(r["trend_strength"]) for r in unique_rows if r["trend_strength"] is not None]
|
||||
avg_strength = sum(strengths) / len(strengths) if strengths else 0.0
|
||||
|
||||
# avg_time_to_resolution: average days between published_at and generated_at
|
||||
resolutions: list[float] = []
|
||||
for r in unique_rows:
|
||||
pub = r["published_at"]
|
||||
gen = r["generated_at"]
|
||||
if pub and gen:
|
||||
delta = (gen - pub).total_seconds() / 86400.0
|
||||
resolutions.append(max(delta, 0.0))
|
||||
avg_time_to_resolution = sum(resolutions) / len(resolutions) if resolutions else 0.0
|
||||
|
||||
# Date range
|
||||
published_dates = [r["published_at"] for r in unique_rows if r["published_at"] is not None]
|
||||
data_start = min(published_dates)
|
||||
data_end = max(published_dates)
|
||||
|
||||
# Recency: days since the most recent data point
|
||||
now = datetime.now(timezone.utc)
|
||||
data_recency_days = (now - data_end).total_seconds() / 86400.0 if data_end else 999.0
|
||||
|
||||
outcome_consistency = max(bullish_pct, bearish_pct)
|
||||
confidence = compute_pattern_confidence(
|
||||
sample_count, outcome_consistency, data_recency_days, tier, config,
|
||||
)
|
||||
|
||||
insufficient_data = sample_count < (config or CompetitiveConfig()).min_pattern_samples
|
||||
|
||||
return HistoricalPattern(
|
||||
source_ticker=source_ticker,
|
||||
target_ticker=target_ticker,
|
||||
catalyst_type=catalyst_type,
|
||||
time_horizon=horizon,
|
||||
sample_count=sample_count,
|
||||
bullish_pct=bullish_pct,
|
||||
bearish_pct=bearish_pct,
|
||||
avg_strength=min(max(avg_strength, 0.0), 1.0),
|
||||
avg_time_to_resolution=avg_time_to_resolution,
|
||||
pattern_confidence=confidence,
|
||||
data_start=data_start,
|
||||
data_end=data_end,
|
||||
tier=tier,
|
||||
insufficient_data=insufficient_data,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def find_self_patterns(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
catalyst_type: str,
|
||||
horizons: Optional[list[str]] = None,
|
||||
config: Optional[CompetitiveConfig] = None,
|
||||
) -> list[HistoricalPattern]:
|
||||
"""Find historical patterns for the same company-catalyst pair.
|
||||
|
||||
Queries document_impact_records joined with trend_windows for the
|
||||
given ticker and catalyst_type across configurable time horizons.
|
||||
|
||||
Requirements: 3.1, 3.2, 3.5, 11.3
|
||||
"""
|
||||
cfg = config or CompetitiveConfig()
|
||||
horizons = horizons or DEFAULT_HORIZONS
|
||||
tier = classify_catalyst_tier(catalyst_type)
|
||||
lookback = _lookback_days(tier, cfg)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
cutoff = now - timedelta(days=lookback)
|
||||
|
||||
patterns: list[HistoricalPattern] = []
|
||||
async with pool.acquire() as conn:
|
||||
for horizon in horizons:
|
||||
interval = _HORIZON_INTERVALS.get(horizon)
|
||||
if interval is None:
|
||||
logger.warning("Unknown horizon %s, skipping", horizon)
|
||||
continue
|
||||
try:
|
||||
rows = await conn.fetch(
|
||||
_SELF_PATTERN_QUERY,
|
||||
ticker, # $1
|
||||
catalyst_type, # $2
|
||||
cutoff, # $3
|
||||
now, # $4
|
||||
horizon, # $5
|
||||
interval, # $6
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error querying self-patterns for %s/%s/%s",
|
||||
ticker, catalyst_type, horizon,
|
||||
)
|
||||
continue
|
||||
|
||||
pattern = _build_pattern(
|
||||
rows, ticker, ticker, catalyst_type, horizon, tier, cfg,
|
||||
)
|
||||
if pattern is not None:
|
||||
patterns.append(pattern)
|
||||
|
||||
return patterns
|
||||
|
||||
|
||||
async def find_cross_company_patterns(
|
||||
pool: asyncpg.Pool,
|
||||
source_ticker: str,
|
||||
target_ticker: str,
|
||||
catalyst_type: str,
|
||||
horizons: Optional[list[str]] = None,
|
||||
config: Optional[CompetitiveConfig] = None,
|
||||
) -> list[HistoricalPattern]:
|
||||
"""Find cross-company historical patterns.
|
||||
|
||||
Queries documents about *source_ticker* with the given catalyst_type,
|
||||
then looks at trend_windows for *target_ticker* within each horizon
|
||||
after the document was published.
|
||||
|
||||
Requirements: 4.2, 11.5
|
||||
"""
|
||||
cfg = config or CompetitiveConfig()
|
||||
horizons = horizons or DEFAULT_HORIZONS
|
||||
tier = classify_catalyst_tier(catalyst_type)
|
||||
lookback = _lookback_days(tier, cfg)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
cutoff = now - timedelta(days=lookback)
|
||||
|
||||
patterns: list[HistoricalPattern] = []
|
||||
async with pool.acquire() as conn:
|
||||
for horizon in horizons:
|
||||
interval = _HORIZON_INTERVALS.get(horizon)
|
||||
if interval is None:
|
||||
logger.warning("Unknown horizon %s, skipping", horizon)
|
||||
continue
|
||||
try:
|
||||
rows = await conn.fetch(
|
||||
_CROSS_PATTERN_QUERY,
|
||||
source_ticker, # $1
|
||||
catalyst_type, # $2
|
||||
cutoff, # $3
|
||||
now, # $4
|
||||
target_ticker, # $5
|
||||
horizon, # $6
|
||||
interval, # $7
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error querying cross-patterns for %s→%s/%s/%s",
|
||||
source_ticker, target_ticker, catalyst_type, horizon,
|
||||
)
|
||||
continue
|
||||
|
||||
pattern = _build_pattern(
|
||||
rows, source_ticker, target_ticker, catalyst_type,
|
||||
horizon, tier, cfg,
|
||||
)
|
||||
if pattern is not None:
|
||||
patterns.append(pattern)
|
||||
|
||||
return patterns
|
||||
@@ -0,0 +1,416 @@
|
||||
"""Trend projection module — forward-looking trend estimates.
|
||||
|
||||
Computes TrendProjection objects by combining current trend momentum,
|
||||
macro signal decay trajectories, and upcoming catalyst outlook.
|
||||
Projections are persisted alongside trend_window records.
|
||||
|
||||
Requirements: 12.1, 12.2, 12.3, 12.4, 12.5, 12.9
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.shared.schemas import TrendDirection, TrendSummary
|
||||
|
||||
logger = logging.getLogger("projection")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TrendProjection dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
VALID_DIRECTIONS = {"bullish", "bearish", "mixed", "neutral"}
|
||||
VALID_HORIZONS = {"1d", "7d", "30d"}
|
||||
|
||||
# Default low-confidence threshold
|
||||
DEFAULT_CONFIDENCE_THRESHOLD = 0.3
|
||||
|
||||
# Macro signal decay half-lives (in days) by estimated_duration
|
||||
DECAY_HALF_LIFE_DAYS: dict[str, float] = {
|
||||
"short_term": 1.0, # halve impact per day
|
||||
"medium_term": 7.0, # halve impact per week
|
||||
"long_term": 30.0, # halve impact per month
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrendProjection:
|
||||
"""Forward-looking trend projection for a company."""
|
||||
|
||||
projected_direction: str = "neutral" # bullish|bearish|mixed|neutral
|
||||
projected_strength: float = 0.5 # [0, 1]
|
||||
projected_confidence: float = 0.5 # [0, 1]
|
||||
projection_horizon: str = "7d" # 1d|7d|30d
|
||||
driving_factors: list[str] = field(default_factory=list)
|
||||
macro_contribution_pct: float = 0.0 # [0, 1]
|
||||
diverges_from_current: bool = False
|
||||
computed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
low_confidence: bool = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Macro impact row type (lightweight, avoids circular import with worker)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class MacroEventInfo:
|
||||
"""Minimal macro event info needed for projection computation."""
|
||||
|
||||
event_id: str = ""
|
||||
macro_impact_score: float = 0.0
|
||||
impact_direction: str = "neutral"
|
||||
confidence: float = 0.5
|
||||
estimated_duration: str = "short_term"
|
||||
severity: str = "low"
|
||||
event_age_hours: float = 0.0 # hours since event publication
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Projection horizon mapping from trend window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_WINDOW_TO_HORIZON: dict[str, str] = {
|
||||
"intraday": "1d",
|
||||
"1d": "1d",
|
||||
"7d": "7d",
|
||||
"30d": "30d",
|
||||
"90d": "30d",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Momentum computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compute_trend_momentum(
|
||||
current_strength: float,
|
||||
current_direction: str,
|
||||
previous_strength: float | None = None,
|
||||
previous_direction: str | None = None,
|
||||
) -> float:
|
||||
"""Compute trend momentum as rate of change in signed strength.
|
||||
|
||||
Returns a value in [-1, 1] representing the momentum:
|
||||
- Positive = strengthening bullish or weakening bearish
|
||||
- Negative = strengthening bearish or weakening bullish
|
||||
- Zero = no change or no previous data
|
||||
|
||||
When no previous data is available, uses a simple heuristic based
|
||||
on current strength and direction.
|
||||
"""
|
||||
dir_sign = _direction_sign(current_direction)
|
||||
|
||||
if previous_strength is None or previous_direction is None:
|
||||
# Heuristic: assume momentum proportional to current signed strength
|
||||
return round(dir_sign * current_strength * 0.5, 6)
|
||||
|
||||
prev_sign = _direction_sign(previous_direction)
|
||||
current_signed = dir_sign * current_strength
|
||||
previous_signed = prev_sign * previous_strength
|
||||
|
||||
momentum = current_signed - previous_signed
|
||||
return round(max(-1.0, min(1.0, momentum)), 6)
|
||||
|
||||
|
||||
def _direction_sign(direction: str) -> float:
|
||||
"""Map direction to a sign multiplier."""
|
||||
if direction == "bullish":
|
||||
return 1.0
|
||||
elif direction == "bearish":
|
||||
return -1.0
|
||||
return 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Macro signal decay projection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SEVERITY_WEIGHT: dict[str, float] = {
|
||||
"critical": 1.0,
|
||||
"high": 0.75,
|
||||
"moderate": 0.5,
|
||||
"low": 0.25,
|
||||
}
|
||||
|
||||
|
||||
def project_macro_decay(
|
||||
events: list[MacroEventInfo],
|
||||
horizon_days: float,
|
||||
) -> tuple[float, str]:
|
||||
"""Project the aggregate macro signal after decay over the horizon.
|
||||
|
||||
For each active macro event, compute the projected remaining impact
|
||||
using exponential decay based on estimated_duration:
|
||||
- short_term: half-life = 1 day
|
||||
- medium_term: half-life = 7 days
|
||||
- long_term: half-life = 30 days
|
||||
|
||||
Returns:
|
||||
(projected_macro_strength, projected_macro_direction)
|
||||
where strength is in [0, 1] and direction is bullish|bearish|mixed|neutral.
|
||||
"""
|
||||
if not events:
|
||||
return 0.0, "neutral"
|
||||
|
||||
positive_weight = 0.0
|
||||
negative_weight = 0.0
|
||||
|
||||
for ev in events:
|
||||
half_life = DECAY_HALF_LIFE_DAYS.get(ev.estimated_duration, 7.0)
|
||||
# Current age in days
|
||||
current_age_days = ev.event_age_hours / 24.0
|
||||
# Projected age at end of horizon
|
||||
future_age_days = current_age_days + horizon_days
|
||||
|
||||
# Decay factor: ratio of future impact to current impact
|
||||
if half_life > 0:
|
||||
current_factor = math.pow(2.0, -current_age_days / half_life)
|
||||
future_factor = math.pow(2.0, -future_age_days / half_life)
|
||||
else:
|
||||
current_factor = 0.0
|
||||
future_factor = 0.0
|
||||
|
||||
severity_w = _SEVERITY_WEIGHT.get(ev.severity, 0.25)
|
||||
projected_impact = ev.macro_impact_score * future_factor * severity_w
|
||||
|
||||
if ev.impact_direction == "positive":
|
||||
positive_weight += projected_impact
|
||||
elif ev.impact_direction == "negative":
|
||||
negative_weight += projected_impact
|
||||
else:
|
||||
# mixed/neutral: split evenly
|
||||
positive_weight += projected_impact * 0.5
|
||||
negative_weight += projected_impact * 0.5
|
||||
|
||||
total = positive_weight + negative_weight
|
||||
if total == 0.0:
|
||||
return 0.0, "neutral"
|
||||
|
||||
strength = min(total, 1.0)
|
||||
|
||||
if positive_weight > negative_weight * 1.2:
|
||||
direction = "bullish"
|
||||
elif negative_weight > positive_weight * 1.2:
|
||||
direction = "bearish"
|
||||
elif positive_weight > 0 and negative_weight > 0:
|
||||
direction = "mixed"
|
||||
else:
|
||||
direction = "neutral"
|
||||
|
||||
return round(strength, 6), direction
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Horizon days mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HORIZON_DAYS: dict[str, float] = {
|
||||
"1d": 1.0,
|
||||
"7d": 7.0,
|
||||
"30d": 30.0,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core projection computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_projection(
|
||||
summary: TrendSummary,
|
||||
macro_events: list[MacroEventInfo] | None = None,
|
||||
macro_enabled: bool = True,
|
||||
confidence_threshold: float = DEFAULT_CONFIDENCE_THRESHOLD,
|
||||
previous_strength: float | None = None,
|
||||
previous_direction: str | None = None,
|
||||
upcoming_catalysts: list[str] | None = None,
|
||||
) -> TrendProjection:
|
||||
"""Compute a forward-looking trend projection.
|
||||
|
||||
Combines:
|
||||
1. Trend momentum (rate of change in strength)
|
||||
2. Macro signal decay projection
|
||||
3. Upcoming catalyst outlook
|
||||
4. Current trend baseline
|
||||
|
||||
Args:
|
||||
summary: The current trend summary.
|
||||
macro_events: Active macro events with their info.
|
||||
macro_enabled: Whether the macro layer is enabled.
|
||||
confidence_threshold: Below this, mark as low_confidence.
|
||||
previous_strength: Previous window's trend strength (optional).
|
||||
previous_direction: Previous window's trend direction (optional).
|
||||
upcoming_catalysts: Known upcoming catalysts from doc intelligence.
|
||||
|
||||
Returns:
|
||||
A TrendProjection with projected direction, strength, and confidence.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
current_dir = summary.trend_direction.value
|
||||
current_strength = summary.trend_strength
|
||||
current_confidence = summary.confidence
|
||||
|
||||
horizon = _WINDOW_TO_HORIZON.get(summary.window.value, "7d")
|
||||
horizon_days = _HORIZON_DAYS.get(horizon, 7.0)
|
||||
|
||||
driving_factors: list[str] = []
|
||||
|
||||
# 1. Compute trend momentum
|
||||
momentum = compute_trend_momentum(
|
||||
current_strength, current_dir,
|
||||
previous_strength, previous_direction,
|
||||
)
|
||||
if abs(momentum) > 0.05:
|
||||
if momentum > 0:
|
||||
driving_factors.append(f"Positive momentum ({momentum:+.3f}) in recent trend strength")
|
||||
else:
|
||||
driving_factors.append(f"Negative momentum ({momentum:+.3f}) in recent trend strength")
|
||||
|
||||
# 2. Project macro signal decay
|
||||
macro_strength = 0.0
|
||||
macro_direction = "neutral"
|
||||
macro_contribution = 0.0
|
||||
|
||||
if macro_enabled and macro_events:
|
||||
macro_strength, macro_direction = project_macro_decay(macro_events, horizon_days)
|
||||
if macro_strength > 0:
|
||||
driving_factors.append(
|
||||
f"Macro signals project {macro_direction} impact "
|
||||
f"(strength {macro_strength:.3f}) over {horizon}"
|
||||
)
|
||||
|
||||
# 3. Factor in upcoming catalysts
|
||||
catalysts = upcoming_catalysts or []
|
||||
for catalyst in catalysts[:3]: # limit to top 3
|
||||
driving_factors.append(f"Upcoming catalyst: {catalyst}")
|
||||
|
||||
catalyst_boost = min(len(catalysts) * 0.02, 0.1) # small boost per catalyst
|
||||
|
||||
# 4. Combine into projected direction/strength/confidence
|
||||
# Momentum-based projection of company-specific trend
|
||||
momentum_projected_signed = _direction_sign(current_dir) * current_strength + momentum * 0.5
|
||||
momentum_projected_strength = min(abs(momentum_projected_signed), 1.0)
|
||||
|
||||
if macro_enabled and macro_strength > 0:
|
||||
# Blend company momentum with macro trajectory
|
||||
macro_weight = min(macro_strength * 0.4, 0.4)
|
||||
company_weight = 1.0 - macro_weight
|
||||
|
||||
macro_signed = _direction_sign(macro_direction) * macro_strength
|
||||
blended_signed = (
|
||||
company_weight * momentum_projected_signed
|
||||
+ macro_weight * macro_signed
|
||||
)
|
||||
projected_strength = round(min(abs(blended_signed) + catalyst_boost, 1.0), 6)
|
||||
macro_contribution = round(macro_weight, 6)
|
||||
|
||||
# Determine projected direction from blended signal
|
||||
projected_direction = _signed_to_direction(blended_signed)
|
||||
else:
|
||||
# Company-only projection
|
||||
projected_strength = round(min(momentum_projected_strength + catalyst_boost, 1.0), 6)
|
||||
projected_direction = _signed_to_direction(momentum_projected_signed)
|
||||
|
||||
# Compute projected confidence
|
||||
base_confidence = current_confidence * 0.8 # projection inherently less certain
|
||||
if macro_enabled and macro_strength > 0:
|
||||
# Macro data adds information → slight confidence boost
|
||||
macro_conf_boost = min(macro_strength * 0.15, 0.1)
|
||||
projected_confidence = round(min(base_confidence + macro_conf_boost, 1.0), 6)
|
||||
else:
|
||||
# Without macro data, reduce confidence further
|
||||
if not macro_enabled:
|
||||
projected_confidence = round(base_confidence * 0.85, 6)
|
||||
else:
|
||||
projected_confidence = round(base_confidence, 6)
|
||||
|
||||
# Ensure driving_factors is never empty
|
||||
if not driving_factors:
|
||||
driving_factors.append(f"Baseline trend continuation: {current_dir} at strength {current_strength:.3f}")
|
||||
|
||||
# 5. Flag divergence
|
||||
diverges = projected_direction != current_dir
|
||||
if diverges:
|
||||
driving_factors.append(
|
||||
f"DIVERGENCE: Current trend is {current_dir}, "
|
||||
f"projection is {projected_direction}"
|
||||
)
|
||||
|
||||
# Mark low confidence
|
||||
is_low_confidence = projected_confidence < confidence_threshold
|
||||
|
||||
return TrendProjection(
|
||||
projected_direction=projected_direction,
|
||||
projected_strength=projected_strength,
|
||||
projected_confidence=projected_confidence,
|
||||
projection_horizon=horizon,
|
||||
driving_factors=driving_factors,
|
||||
macro_contribution_pct=macro_contribution,
|
||||
diverges_from_current=diverges,
|
||||
computed_at=now,
|
||||
low_confidence=is_low_confidence,
|
||||
)
|
||||
|
||||
|
||||
def _signed_to_direction(signed_value: float) -> str:
|
||||
"""Convert a signed strength value to a direction string."""
|
||||
if signed_value > 0.1:
|
||||
return "bullish"
|
||||
elif signed_value < -0.1:
|
||||
return "bearish"
|
||||
elif abs(signed_value) > 0.02:
|
||||
return "mixed"
|
||||
return "neutral"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PostgreSQL persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INSERT_PROJECTION = """
|
||||
INSERT INTO trend_projections (
|
||||
trend_window_id, projected_direction, projected_strength,
|
||||
projected_confidence, projection_horizon, driving_factors,
|
||||
macro_contribution_pct, diverges_from_current, computed_at
|
||||
) VALUES (
|
||||
$1::uuid, $2, $3, $4, $5, $6::jsonb, $7, $8, $9
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
|
||||
|
||||
async def persist_trend_projection(
|
||||
pool: asyncpg.Pool,
|
||||
trend_window_id: str,
|
||||
projection: TrendProjection,
|
||||
) -> str:
|
||||
"""Persist a TrendProjection to the trend_projections table.
|
||||
|
||||
Returns the row UUID.
|
||||
"""
|
||||
row_id = await pool.fetchval(
|
||||
_INSERT_PROJECTION,
|
||||
trend_window_id,
|
||||
projection.projected_direction,
|
||||
projection.projected_strength,
|
||||
projection.projected_confidence,
|
||||
projection.projection_horizon,
|
||||
json.dumps(projection.driving_factors),
|
||||
projection.macro_contribution_pct,
|
||||
projection.diverges_from_current,
|
||||
projection.computed_at,
|
||||
)
|
||||
logger.info(
|
||||
"Persisted trend projection for window=%s: direction=%s strength=%.3f confidence=%.3f diverges=%s",
|
||||
trend_window_id,
|
||||
projection.projected_direction,
|
||||
projection.projected_strength,
|
||||
projection.projected_confidence,
|
||||
projection.diverges_from_current,
|
||||
)
|
||||
return str(row_id)
|
||||
+226
-13
@@ -4,13 +4,13 @@ Aggregates company-level trend summaries into sector and market-level
|
||||
summaries, enabling top-down views of sentiment and risk across the
|
||||
portfolio.
|
||||
|
||||
Requirements: 6.3, 6.4, 6.5
|
||||
Requirements: 6.1, 6.2, 6.3, 6.4, 6.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import asyncpg
|
||||
@@ -42,6 +42,126 @@ class CompanyTrendRow:
|
||||
top_opposing_evidence: list[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SectorMacroImpact:
|
||||
"""Aggregated macro impact data for a single sector.
|
||||
|
||||
Used to incorporate macro signals into sector and market rollups.
|
||||
Requirements: 6.1, 6.2, 6.3
|
||||
"""
|
||||
|
||||
sector: str
|
||||
total_impact: float # sum of macro_impact_score across companies in sector
|
||||
avg_impact: float # average macro_impact_score
|
||||
company_count: int # number of companies affected
|
||||
net_direction: float # weighted direction: +1 positive, -1 negative, 0 mixed
|
||||
event_ids: list[str] = field(default_factory=list) # contributing event IDs
|
||||
|
||||
|
||||
# Threshold for disproportionate sector impact (Requirement 6.3)
|
||||
SECTOR_CONCENTRATION_THRESHOLD = 0.60
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch sector-level macro impact aggregates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SECTOR_MACRO_IMPACT_QUERY = """
|
||||
SELECT
|
||||
c.sector,
|
||||
mir.event_id,
|
||||
mir.macro_impact_score,
|
||||
mir.impact_direction
|
||||
FROM macro_impact_records mir
|
||||
JOIN companies c ON c.id = mir.company_id AND c.active = TRUE
|
||||
WHERE mir.computed_at >= $1
|
||||
AND mir.computed_at <= $2
|
||||
ORDER BY c.sector, mir.macro_impact_score DESC
|
||||
"""
|
||||
|
||||
|
||||
async def fetch_sector_macro_impacts(
|
||||
pool: asyncpg.Pool,
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
) -> dict[str, SectorMacroImpact]:
|
||||
"""Fetch macro impact records aggregated by sector for a time range.
|
||||
|
||||
Returns a mapping of sector name to SectorMacroImpact.
|
||||
"""
|
||||
rows = await pool.fetch(_SECTOR_MACRO_IMPACT_QUERY, window_start, window_end)
|
||||
|
||||
# Accumulate per-sector
|
||||
sector_data: dict[str, dict] = {}
|
||||
direction_map = {"positive": 1.0, "negative": -1.0, "mixed": 0.0, "neutral": 0.0}
|
||||
|
||||
for row in rows:
|
||||
sector = str(row["sector"]) if row["sector"] else "Unknown"
|
||||
score = float(row["macro_impact_score"] or 0.0)
|
||||
direction = row["impact_direction"] or "neutral"
|
||||
event_id = str(row["event_id"])
|
||||
|
||||
if sector not in sector_data:
|
||||
sector_data[sector] = {
|
||||
"total": 0.0,
|
||||
"count": 0,
|
||||
"dir_sum": 0.0,
|
||||
"dir_count": 0,
|
||||
"event_ids": set(),
|
||||
}
|
||||
|
||||
d = sector_data[sector]
|
||||
d["total"] += score
|
||||
d["count"] += 1
|
||||
dir_val = direction_map.get(direction, 0.0)
|
||||
if dir_val != 0.0:
|
||||
d["dir_sum"] += dir_val
|
||||
d["dir_count"] += 1
|
||||
d["event_ids"].add(event_id)
|
||||
|
||||
result: dict[str, SectorMacroImpact] = {}
|
||||
for sector, d in sector_data.items():
|
||||
count = d["count"]
|
||||
avg = d["total"] / count if count > 0 else 0.0
|
||||
net_dir = d["dir_sum"] / d["dir_count"] if d["dir_count"] > 0 else 0.0
|
||||
result[sector] = SectorMacroImpact(
|
||||
sector=sector,
|
||||
total_impact=d["total"],
|
||||
avg_impact=avg,
|
||||
company_count=count,
|
||||
net_direction=net_dir,
|
||||
event_ids=sorted(d["event_ids"]),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sector macro concentration helper (Requirement 6.3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_sector_macro_concentration(
|
||||
sector_impacts: dict[str, SectorMacroImpact],
|
||||
) -> list[tuple[str, float]]:
|
||||
"""Compute the fraction of total macro impact concentrated in each sector.
|
||||
|
||||
Returns a list of (sector, fraction) tuples sorted by fraction descending.
|
||||
Sectors with fraction > SECTOR_CONCENTRATION_THRESHOLD are considered
|
||||
disproportionately affected.
|
||||
"""
|
||||
total = sum(si.total_impact for si in sector_impacts.values())
|
||||
if total <= 0.0:
|
||||
return []
|
||||
|
||||
fractions = [
|
||||
(sector, si.total_impact / total)
|
||||
for sector, si in sector_impacts.items()
|
||||
]
|
||||
fractions.sort(key=lambda x: x[1], reverse=True)
|
||||
return fractions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch latest company trends for a given window
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -141,11 +261,22 @@ def rollup_trends(
|
||||
entity_id: str,
|
||||
window: str,
|
||||
reference_time: datetime,
|
||||
macro_impacts: dict[str, SectorMacroImpact] | None = None,
|
||||
) -> TrendSummary:
|
||||
"""Aggregate a list of company-level trends into a single rollup summary.
|
||||
|
||||
Each company trend is weighted by its confidence to produce a
|
||||
confidence-weighted average of direction, strength, and contradiction.
|
||||
|
||||
When macro_impacts is provided:
|
||||
- For sector rollups: incorporates the sector's macro signal into
|
||||
strength and confidence, weighted by constituent company exposure.
|
||||
- For market rollups: aggregates macro signals across all sectors and
|
||||
surfaces disproportionately affected sectors (>60% concentration)
|
||||
in material_risks or dominant_catalysts.
|
||||
|
||||
When macro_impacts is None or empty, produces identical output to
|
||||
the original company-only rollup.
|
||||
"""
|
||||
if not trends:
|
||||
return TrendSummary(
|
||||
@@ -204,16 +335,70 @@ def rollup_trends(
|
||||
avg_contradiction = weighted_contradiction / total_weight
|
||||
avg_confidence = total_weight / len(trends)
|
||||
|
||||
# --- Incorporate macro impact signals when available ---
|
||||
macro_strength_adj = 0.0
|
||||
macro_confidence_adj = 0.0
|
||||
macro_catalysts: list[str] = []
|
||||
macro_risks: list[str] = []
|
||||
|
||||
if macro_impacts:
|
||||
if entity_type == "sector":
|
||||
# Sector rollup: incorporate this sector's macro signal
|
||||
sector_macro = macro_impacts.get(entity_id)
|
||||
if sector_macro and sector_macro.total_impact > 0:
|
||||
# Weight macro contribution by avg impact and company breadth
|
||||
breadth = min(sector_macro.company_count / max(len(trends), 1), 1.0)
|
||||
macro_strength_adj = sector_macro.avg_impact * breadth * 0.3
|
||||
macro_confidence_adj = sector_macro.avg_impact * breadth * 0.1
|
||||
# Nudge direction based on macro net direction
|
||||
avg_direction += sector_macro.net_direction * macro_strength_adj * 0.5
|
||||
|
||||
elif entity_type == "market":
|
||||
# Market rollup: aggregate macro signals across all sectors
|
||||
total_macro = sum(si.total_impact for si in macro_impacts.values())
|
||||
if total_macro > 0:
|
||||
total_companies = sum(si.company_count for si in macro_impacts.values())
|
||||
breadth = min(total_companies / max(len(trends), 1), 1.0)
|
||||
avg_macro = total_macro / max(len(macro_impacts), 1)
|
||||
macro_strength_adj = avg_macro * breadth * 0.3
|
||||
macro_confidence_adj = avg_macro * breadth * 0.1
|
||||
|
||||
# Aggregate net direction across sectors
|
||||
dir_sum = sum(
|
||||
si.net_direction * si.total_impact
|
||||
for si in macro_impacts.values()
|
||||
)
|
||||
net_dir = dir_sum / total_macro if total_macro > 0 else 0.0
|
||||
avg_direction += net_dir * macro_strength_adj * 0.5
|
||||
|
||||
# Surface disproportionately affected sectors (Requirement 6.3)
|
||||
concentration = compute_sector_macro_concentration(macro_impacts)
|
||||
for sector, fraction in concentration:
|
||||
if fraction > SECTOR_CONCENTRATION_THRESHOLD:
|
||||
si = macro_impacts[sector]
|
||||
label = f"Macro: {sector} ({fraction:.0%} of macro impact)"
|
||||
if si.net_direction < 0:
|
||||
macro_risks.append(label)
|
||||
else:
|
||||
macro_catalysts.append(label)
|
||||
|
||||
# Apply macro adjustments to strength and confidence
|
||||
adj_strength = avg_strength + macro_strength_adj
|
||||
adj_confidence = avg_confidence + macro_confidence_adj
|
||||
|
||||
# Derive direction
|
||||
direction = _derive_rollup_direction(avg_direction, avg_contradiction)
|
||||
|
||||
# Top catalysts
|
||||
# Top catalysts (macro catalysts prepended when present)
|
||||
sorted_catalysts = sorted(catalyst_weights.items(), key=lambda x: x[1], reverse=True)
|
||||
catalysts = [c for c, _ in sorted_catalysts[:5]]
|
||||
catalysts = macro_catalysts + [c for c, _ in sorted_catalysts[:5]]
|
||||
catalysts = catalysts[:5]
|
||||
|
||||
# Top risks (deduplicated, by weight)
|
||||
# Top risks (macro risks prepended when present, deduplicated)
|
||||
sorted_risks = sorted(risk_set.items(), key=lambda x: x[1], reverse=True)
|
||||
risks = [r for r, _ in sorted_risks[:5]]
|
||||
base_risks = [r for r, _ in sorted_risks[:5]]
|
||||
risks = macro_risks + base_risks
|
||||
risks = risks[:5]
|
||||
|
||||
# Disagreement details
|
||||
disagreement = _build_rollup_disagreement(trends, entity_id)
|
||||
@@ -223,8 +408,8 @@ def rollup_trends(
|
||||
entity_id=entity_id,
|
||||
window=TrendWindow(window),
|
||||
trend_direction=direction,
|
||||
trend_strength=round(min(abs(avg_strength), 1.0), 4),
|
||||
confidence=round(max(0.0, min(avg_confidence, 1.0)), 4),
|
||||
trend_strength=round(min(abs(adj_strength), 1.0), 4),
|
||||
confidence=round(max(0.0, min(adj_confidence, 1.0)), 4),
|
||||
top_supporting_evidence=list(dict.fromkeys(all_supporting))[:10],
|
||||
top_opposing_evidence=list(dict.fromkeys(all_opposing))[:10],
|
||||
dominant_catalysts=catalysts,
|
||||
@@ -341,11 +526,14 @@ async def aggregate_sector(
|
||||
window: str,
|
||||
reference_time: datetime | None = None,
|
||||
since: datetime | None = None,
|
||||
macro_impacts: dict[str, SectorMacroImpact] | None = None,
|
||||
) -> TrendSummary:
|
||||
"""Compute and persist a sector-level rollup for one window.
|
||||
|
||||
Fetches the latest company trends, filters to the given sector,
|
||||
and rolls them up into a single sector summary.
|
||||
and rolls them up into a single sector summary. When macro_impacts
|
||||
is provided, incorporates macro signals weighted by constituent
|
||||
company exposure.
|
||||
"""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
@@ -355,7 +543,14 @@ async def aggregate_sector(
|
||||
all_trends = await fetch_latest_company_trends(pool, window, since)
|
||||
sector_trends = [t for t in all_trends if t.sector == sector]
|
||||
|
||||
summary = rollup_trends(sector_trends, "sector", sector, window, reference_time)
|
||||
# Fetch macro impacts if not provided
|
||||
if macro_impacts is None:
|
||||
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
|
||||
|
||||
summary = rollup_trends(
|
||||
sector_trends, "sector", sector, window, reference_time,
|
||||
macro_impacts=macro_impacts,
|
||||
)
|
||||
|
||||
if sector_trends:
|
||||
rollup_id = await persist_rollup(pool, summary)
|
||||
@@ -373,10 +568,13 @@ async def aggregate_market(
|
||||
window: str,
|
||||
reference_time: datetime | None = None,
|
||||
since: datetime | None = None,
|
||||
macro_impacts: dict[str, SectorMacroImpact] | None = None,
|
||||
) -> TrendSummary:
|
||||
"""Compute and persist a market-wide rollup for one window.
|
||||
|
||||
Aggregates all company trends regardless of sector.
|
||||
Aggregates all company trends regardless of sector. When macro_impacts
|
||||
is provided, aggregates macro signals across all sectors and surfaces
|
||||
disproportionately affected sectors in material_risks or dominant_catalysts.
|
||||
"""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
@@ -385,7 +583,14 @@ async def aggregate_market(
|
||||
|
||||
all_trends = await fetch_latest_company_trends(pool, window, since)
|
||||
|
||||
summary = rollup_trends(all_trends, "market", "all", window, reference_time)
|
||||
# Fetch macro impacts if not provided
|
||||
if macro_impacts is None:
|
||||
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
|
||||
|
||||
summary = rollup_trends(
|
||||
all_trends, "market", "all", window, reference_time,
|
||||
macro_impacts=macro_impacts,
|
||||
)
|
||||
|
||||
if all_trends:
|
||||
rollup_id = await persist_rollup(pool, summary)
|
||||
@@ -403,6 +608,7 @@ async def aggregate_all_sectors(
|
||||
window: str,
|
||||
reference_time: datetime | None = None,
|
||||
since: datetime | None = None,
|
||||
macro_impacts: dict[str, SectorMacroImpact] | None = None,
|
||||
) -> list[TrendSummary]:
|
||||
"""Compute sector rollups for every sector that has company trends."""
|
||||
if reference_time is None:
|
||||
@@ -412,6 +618,10 @@ async def aggregate_all_sectors(
|
||||
|
||||
all_trends = await fetch_latest_company_trends(pool, window, since)
|
||||
|
||||
# Fetch macro impacts once for all sectors if not provided
|
||||
if macro_impacts is None:
|
||||
macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)
|
||||
|
||||
# Group by sector
|
||||
sectors: dict[str, list[CompanyTrendRow]] = {}
|
||||
for t in all_trends:
|
||||
@@ -419,7 +629,10 @@ async def aggregate_all_sectors(
|
||||
|
||||
summaries: list[TrendSummary] = []
|
||||
for sector, trends in sectors.items():
|
||||
summary = rollup_trends(trends, "sector", sector, window, reference_time)
|
||||
summary = rollup_trends(
|
||||
trends, "sector", sector, window, reference_time,
|
||||
macro_impacts=macro_impacts,
|
||||
)
|
||||
if trends:
|
||||
_id = await persist_rollup(pool, summary)
|
||||
summaries.append(summary)
|
||||
|
||||
@@ -0,0 +1,306 @@
|
||||
"""Competitive signal propagation engine.
|
||||
|
||||
Evaluates incoming document intelligence, identifies competitors via
|
||||
the competitor_relationships table, queries historical cross-company
|
||||
patterns, and produces weighted competitive signals persisted to
|
||||
competitive_signal_records.
|
||||
|
||||
Also converts pattern and competitive signals into WeightedSignal
|
||||
objects for the aggregation engine.
|
||||
|
||||
Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 9.1
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.aggregation.pattern_matcher import (
|
||||
HistoricalPattern,
|
||||
find_cross_company_patterns,
|
||||
)
|
||||
from services.aggregation.scoring import (
|
||||
ScoringConfig,
|
||||
WeightedSignal,
|
||||
compute_signal_weight,
|
||||
)
|
||||
from services.shared.config import CompetitiveConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class CompetitiveSignalRecord:
|
||||
"""A competitive signal produced by propagating a source event to a
|
||||
competitor based on historical cross-company patterns."""
|
||||
|
||||
source_document_id: str
|
||||
source_ticker: str
|
||||
target_ticker: str
|
||||
catalyst_type: str
|
||||
pattern_confidence: float
|
||||
signal_direction: str # bullish | bearish
|
||||
signal_strength: float # [0, 1]
|
||||
relationship_strength: float
|
||||
computed_at: datetime
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SQL queries
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_COMPETITOR_LOOKUP_QUERY = """
|
||||
SELECT company_a_id, company_b_id, strength
|
||||
FROM competitor_relationships
|
||||
WHERE (company_a_id = $1 OR company_b_id = $1)
|
||||
AND active = TRUE
|
||||
"""
|
||||
|
||||
_INSERT_SIGNAL_QUERY = """
|
||||
INSERT INTO competitive_signal_records
|
||||
(source_document_id, source_ticker, target_ticker, catalyst_type,
|
||||
pattern_confidence, signal_direction, signal_strength,
|
||||
relationship_strength, computed_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# propagate_signals
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def propagate_signals(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
catalyst_type: str,
|
||||
impact_score: float,
|
||||
document_id: str,
|
||||
config: Optional[CompetitiveConfig] = None,
|
||||
) -> list[CompetitiveSignalRecord]:
|
||||
"""Look up competitors, query cross-company patterns, produce weighted
|
||||
competitive signals, and persist them.
|
||||
|
||||
Args:
|
||||
pool: asyncpg connection pool.
|
||||
ticker: Source company ticker that received the catalyst.
|
||||
catalyst_type: The catalyst type from document intelligence.
|
||||
impact_score: The source document's impact score.
|
||||
document_id: The source document ID.
|
||||
config: Optional competitive config overrides.
|
||||
|
||||
Returns:
|
||||
List of CompetitiveSignalRecord objects produced and persisted.
|
||||
"""
|
||||
cfg = config or CompetitiveConfig()
|
||||
now = datetime.now(timezone.utc)
|
||||
records: list[CompetitiveSignalRecord] = []
|
||||
|
||||
# Step 1: Look up active competitors
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(_COMPETITOR_LOOKUP_QUERY, ticker)
|
||||
except Exception:
|
||||
logger.exception("Failed to look up competitors for %s", ticker)
|
||||
return records
|
||||
|
||||
if not rows:
|
||||
logger.debug("No active competitors found for %s", ticker)
|
||||
return records
|
||||
|
||||
# Step 2: For each competitor, query cross-company patterns
|
||||
for row in rows:
|
||||
company_a = str(row["company_a_id"])
|
||||
company_b = str(row["company_b_id"])
|
||||
rel_strength = float(row["strength"])
|
||||
|
||||
# Determine the competitor ticker (the other side of the relationship)
|
||||
competitor_ticker = company_b if company_a == ticker else company_a
|
||||
|
||||
# Threshold gating (Req 4.5)
|
||||
if rel_strength < cfg.propagation_strength_threshold:
|
||||
logger.info(
|
||||
"Skipping propagation %s→%s: relationship strength %.3f "
|
||||
"below threshold %.3f",
|
||||
ticker, competitor_ticker, rel_strength,
|
||||
cfg.propagation_strength_threshold,
|
||||
)
|
||||
continue
|
||||
|
||||
# Query cross-company patterns
|
||||
try:
|
||||
patterns = await find_cross_company_patterns(
|
||||
pool, ticker, competitor_ticker, catalyst_type, config=cfg,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to query cross-company patterns for %s→%s/%s",
|
||||
ticker, competitor_ticker, catalyst_type,
|
||||
)
|
||||
continue
|
||||
|
||||
for pattern in patterns:
|
||||
# Confidence threshold gating (Req 9.1)
|
||||
if pattern.pattern_confidence < cfg.pattern_confidence_threshold:
|
||||
logger.info(
|
||||
"Excluding pattern %s→%s/%s/%s: confidence %.3f "
|
||||
"below threshold %.3f",
|
||||
ticker, competitor_ticker, catalyst_type,
|
||||
pattern.time_horizon, pattern.pattern_confidence,
|
||||
cfg.pattern_confidence_threshold,
|
||||
)
|
||||
continue
|
||||
|
||||
# Compute signal strength (Req 4.3)
|
||||
raw_strength = (
|
||||
pattern.avg_strength
|
||||
* rel_strength
|
||||
* pattern.pattern_confidence
|
||||
* impact_score
|
||||
)
|
||||
signal_strength = min(max(raw_strength, 0.0), 1.0)
|
||||
|
||||
# Determine direction
|
||||
direction = (
|
||||
"bullish" if pattern.bullish_pct > pattern.bearish_pct
|
||||
else "bearish"
|
||||
)
|
||||
|
||||
record = CompetitiveSignalRecord(
|
||||
source_document_id=document_id,
|
||||
source_ticker=ticker,
|
||||
target_ticker=competitor_ticker,
|
||||
catalyst_type=catalyst_type,
|
||||
pattern_confidence=pattern.pattern_confidence,
|
||||
signal_direction=direction,
|
||||
signal_strength=signal_strength,
|
||||
relationship_strength=rel_strength,
|
||||
computed_at=now,
|
||||
)
|
||||
records.append(record)
|
||||
|
||||
# Step 3: Persist all records
|
||||
if records:
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
await conn.executemany(
|
||||
_INSERT_SIGNAL_QUERY,
|
||||
[
|
||||
(
|
||||
r.source_document_id,
|
||||
r.source_ticker,
|
||||
r.target_ticker,
|
||||
r.catalyst_type,
|
||||
r.pattern_confidence,
|
||||
r.signal_direction,
|
||||
r.signal_strength,
|
||||
r.relationship_strength,
|
||||
r.computed_at,
|
||||
)
|
||||
for r in records
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to persist %d competitive signal records", len(records),
|
||||
)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_pattern_weighted_signals
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_pattern_weighted_signals(
|
||||
patterns: list[HistoricalPattern],
|
||||
competitive_signals: list[CompetitiveSignalRecord],
|
||||
reference_time: datetime,
|
||||
window: str,
|
||||
config: Optional[CompetitiveConfig] = None,
|
||||
) -> list[WeightedSignal]:
|
||||
"""Convert pattern and competitive signal objects to WeightedSignal
|
||||
objects for the aggregation engine.
|
||||
|
||||
For HistoricalPattern objects:
|
||||
- sentiment_value = +1.0 if bullish_pct > bearish_pct else -1.0
|
||||
- impact_score = avg_strength * competitive_signal_weight
|
||||
- published_at = data_end (most recent data point for recency decay)
|
||||
- extraction_confidence = pattern_confidence
|
||||
|
||||
For CompetitiveSignalRecord objects:
|
||||
- sentiment_value = +1.0 if direction == "bullish" else -1.0
|
||||
- impact_score = signal_strength * competitive_signal_weight
|
||||
- published_at = computed_at (for recency decay)
|
||||
- extraction_confidence = pattern_confidence
|
||||
|
||||
Args:
|
||||
patterns: Self-company historical patterns.
|
||||
competitive_signals: Competitive signal records from propagation.
|
||||
reference_time: Aggregation anchor time for recency decay.
|
||||
window: Trend window identifier (e.g. "7d").
|
||||
config: Optional competitive config overrides.
|
||||
|
||||
Returns:
|
||||
List of WeightedSignal objects ready for aggregation.
|
||||
"""
|
||||
cfg = config or CompetitiveConfig()
|
||||
scoring_cfg = ScoringConfig()
|
||||
signals: list[WeightedSignal] = []
|
||||
|
||||
# Convert HistoricalPattern objects
|
||||
for pattern in patterns:
|
||||
sentiment_value = (
|
||||
1.0 if pattern.bullish_pct > pattern.bearish_pct else -1.0
|
||||
)
|
||||
impact = pattern.avg_strength * cfg.competitive_signal_weight
|
||||
|
||||
weight = compute_signal_weight(
|
||||
published_at=pattern.data_end,
|
||||
reference_time=reference_time,
|
||||
window=window,
|
||||
source_credibility=1.0, # patterns are derived from validated data
|
||||
novelty_score=0.5,
|
||||
extraction_confidence=pattern.pattern_confidence,
|
||||
market_ctx=None,
|
||||
config=scoring_cfg,
|
||||
)
|
||||
|
||||
signals.append(WeightedSignal(
|
||||
document_id=f"pattern:{pattern.source_ticker}:{pattern.catalyst_type}:{pattern.time_horizon}",
|
||||
weight=weight,
|
||||
sentiment_value=sentiment_value,
|
||||
impact_score=impact,
|
||||
))
|
||||
|
||||
# Convert CompetitiveSignalRecord objects
|
||||
for sig in competitive_signals:
|
||||
sentiment_value = 1.0 if sig.signal_direction == "bullish" else -1.0
|
||||
impact = sig.signal_strength * cfg.competitive_signal_weight
|
||||
|
||||
weight = compute_signal_weight(
|
||||
published_at=sig.computed_at,
|
||||
reference_time=reference_time,
|
||||
window=window,
|
||||
source_credibility=1.0,
|
||||
novelty_score=0.5,
|
||||
extraction_confidence=sig.pattern_confidence,
|
||||
market_ctx=None,
|
||||
config=scoring_cfg,
|
||||
)
|
||||
|
||||
signals.append(WeightedSignal(
|
||||
document_id=sig.source_document_id,
|
||||
weight=weight,
|
||||
sentiment_value=sentiment_value,
|
||||
impact_score=impact,
|
||||
))
|
||||
|
||||
return signals
|
||||
@@ -40,6 +40,17 @@ from services.shared.metrics import (
|
||||
AGGREGATION_SIGNALS_PROCESSED,
|
||||
AGGREGATION_WINDOWS_COMPUTED,
|
||||
)
|
||||
from services.aggregation.pattern_matcher import find_self_patterns
|
||||
from services.aggregation.projection import (
|
||||
MacroEventInfo,
|
||||
TrendProjection,
|
||||
compute_projection,
|
||||
persist_trend_projection,
|
||||
)
|
||||
from services.aggregation.signal_propagation import (
|
||||
CompetitiveSignalRecord,
|
||||
build_pattern_weighted_signals,
|
||||
)
|
||||
from services.shared.schemas import TrendDirection, TrendSummary, TrendWindow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -64,6 +75,10 @@ class AggregationConfig:
|
||||
windows: list[str] | None = None # None = all windows
|
||||
scoring: ScoringConfig | None = None
|
||||
max_evidence: int = MAX_EVIDENCE_REFS
|
||||
macro_signal_weight: float = 0.3 # relative weight of macro vs company signals
|
||||
macro_enabled: bool = True # runtime toggle state
|
||||
competitive_signal_weight: float = 0.2 # relative weight of pattern signals
|
||||
competitive_enabled: bool = True # runtime toggle state
|
||||
|
||||
def effective_windows(self) -> list[str]:
|
||||
if self.windows:
|
||||
@@ -154,6 +169,236 @@ async def fetch_impact_records(
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch macro toggle state from risk_configs
|
||||
#
|
||||
# MACRO LAYER TOGGLE BEHAVIOR (Requirements 11.2, 11.3, 11.4):
|
||||
# - The toggle state is read fresh from PostgreSQL at the start of each
|
||||
# aggregation cycle (no caching), so changes take effect immediately on
|
||||
# the next cycle.
|
||||
# - When disabled: ingestion and classification continue normally (historical
|
||||
# data is preserved), but interpolation and aggregation integration are
|
||||
# skipped — the aggregation engine produces trends using only company-
|
||||
# specific signals.
|
||||
# - When re-enabled: the engine resumes computing macro impact scores using
|
||||
# the most recent GlobalEvent classifications, including any events that
|
||||
# were ingested and classified while the layer was disabled.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MACRO_TOGGLE_QUERY = """
|
||||
SELECT config->>'macro_enabled' AS macro_enabled
|
||||
FROM risk_configs
|
||||
WHERE active = TRUE
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
|
||||
async def fetch_macro_enabled(pool: asyncpg.Pool) -> bool | None:
|
||||
"""Check macro toggle state from risk_configs table.
|
||||
|
||||
Returns True/False if explicitly set, or None if no config exists
|
||||
(caller should fall back to AggregationConfig default).
|
||||
"""
|
||||
row = await pool.fetchrow(_MACRO_TOGGLE_QUERY)
|
||||
if row is None or row["macro_enabled"] is None:
|
||||
return None
|
||||
return row["macro_enabled"].lower() == "true"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch competitive toggle state from risk_configs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_COMPETITIVE_TOGGLE_QUERY = """
|
||||
SELECT config->>'competitive_enabled' AS competitive_enabled
|
||||
FROM risk_configs
|
||||
WHERE active = TRUE
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
|
||||
async def fetch_competitive_enabled(pool: asyncpg.Pool) -> bool | None:
|
||||
"""Check competitive toggle state from risk_configs table.
|
||||
|
||||
Returns True/False if explicitly set, or None if no config exists
|
||||
(caller should fall back to AggregationConfig default).
|
||||
"""
|
||||
row = await pool.fetchrow(_COMPETITIVE_TOGGLE_QUERY)
|
||||
if row is None or row["competitive_enabled"] is None:
|
||||
return None
|
||||
return row["competitive_enabled"].lower() == "true"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch competitive signals targeting a ticker within a time window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_COMPETITIVE_SIGNALS_QUERY = """
|
||||
SELECT source_document_id, source_ticker, target_ticker, catalyst_type,
|
||||
pattern_confidence, signal_direction, signal_strength,
|
||||
relationship_strength, computed_at
|
||||
FROM competitive_signal_records
|
||||
WHERE target_ticker = $1
|
||||
AND computed_at >= $2
|
||||
AND computed_at <= $3
|
||||
ORDER BY computed_at DESC
|
||||
"""
|
||||
|
||||
|
||||
async def fetch_competitive_signals(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
) -> list[CompetitiveSignalRecord]:
|
||||
"""Fetch competitive signal records targeting a ticker in a time range."""
|
||||
rows = await pool.fetch(
|
||||
_COMPETITIVE_SIGNALS_QUERY, ticker, window_start, window_end,
|
||||
)
|
||||
return [
|
||||
CompetitiveSignalRecord(
|
||||
source_document_id=str(row["source_document_id"]),
|
||||
source_ticker=row["source_ticker"],
|
||||
target_ticker=row["target_ticker"],
|
||||
catalyst_type=row["catalyst_type"],
|
||||
pattern_confidence=float(row["pattern_confidence"]),
|
||||
signal_direction=row["signal_direction"],
|
||||
signal_strength=float(row["signal_strength"]),
|
||||
relationship_strength=float(row["relationship_strength"]),
|
||||
computed_at=row["computed_at"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch macro impact records for a ticker within a time window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MACRO_IMPACT_QUERY = """
|
||||
SELECT
|
||||
mir.event_id,
|
||||
mir.company_id,
|
||||
mir.ticker,
|
||||
mir.macro_impact_score,
|
||||
mir.impact_direction,
|
||||
mir.contributing_factors,
|
||||
mir.confidence,
|
||||
mir.computed_at,
|
||||
ge.source_document_id,
|
||||
d.published_at AS event_published_at
|
||||
FROM macro_impact_records mir
|
||||
JOIN global_events ge ON ge.id = mir.event_id
|
||||
JOIN documents d ON d.id = ge.source_document_id
|
||||
WHERE mir.ticker = $1
|
||||
AND mir.computed_at >= $2
|
||||
AND mir.computed_at <= $3
|
||||
ORDER BY mir.computed_at DESC
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class MacroImpactRow:
|
||||
"""Parsed row from the macro impact query."""
|
||||
|
||||
event_id: str
|
||||
company_id: str
|
||||
ticker: str
|
||||
macro_impact_score: float
|
||||
impact_direction: str
|
||||
contributing_factors: list[str]
|
||||
confidence: float
|
||||
computed_at: datetime
|
||||
source_document_id: str
|
||||
event_published_at: datetime
|
||||
|
||||
|
||||
def _parse_macro_impact_row(row: Any) -> MacroImpactRow:
|
||||
"""Convert an asyncpg Record to a MacroImpactRow."""
|
||||
factors = row["contributing_factors"]
|
||||
if isinstance(factors, str):
|
||||
factors = json.loads(factors)
|
||||
|
||||
return MacroImpactRow(
|
||||
event_id=str(row["event_id"]),
|
||||
company_id=str(row["company_id"]),
|
||||
ticker=row["ticker"],
|
||||
macro_impact_score=float(row["macro_impact_score"] or 0.0),
|
||||
impact_direction=row["impact_direction"] or "neutral",
|
||||
contributing_factors=factors if isinstance(factors, list) else [],
|
||||
confidence=float(row["confidence"] or 0.5),
|
||||
computed_at=row["computed_at"],
|
||||
source_document_id=str(row["source_document_id"]),
|
||||
event_published_at=row["event_published_at"],
|
||||
)
|
||||
|
||||
|
||||
async def fetch_macro_impact_records(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
window_start: datetime,
|
||||
window_end: datetime,
|
||||
) -> list[MacroImpactRow]:
|
||||
"""Fetch macro impact records for a ticker in a time range."""
|
||||
rows = await pool.fetch(_MACRO_IMPACT_QUERY, ticker, window_start, window_end)
|
||||
return [_parse_macro_impact_row(r) for r in rows]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convert macro impact records to WeightedSignals
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DIRECTION_TO_SENTIMENT: dict[str, float] = {
|
||||
"positive": 1.0,
|
||||
"negative": -1.0,
|
||||
"mixed": 0.0,
|
||||
"neutral": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def build_macro_weighted_signals(
|
||||
macro_impacts: list[MacroImpactRow],
|
||||
reference_time: datetime,
|
||||
window: str,
|
||||
macro_signal_weight: float = 0.3,
|
||||
config: ScoringConfig | None = None,
|
||||
) -> list[WeightedSignal]:
|
||||
"""Convert macro impact records into WeightedSignal objects.
|
||||
|
||||
Uses the same scoring pipeline as company signals:
|
||||
- document_id = source_document_id (for evidence tracing)
|
||||
- sentiment_value mapped from impact_direction
|
||||
- impact_score = macro_impact_score * macro_signal_weight
|
||||
- recency decay from the global event's publication time
|
||||
- confidence gating from the macro record's confidence
|
||||
"""
|
||||
cfg = config or ScoringConfig()
|
||||
signals: list[WeightedSignal] = []
|
||||
for mir in macro_impacts:
|
||||
sw = compute_signal_weight(
|
||||
published_at=mir.event_published_at,
|
||||
reference_time=reference_time,
|
||||
window=window,
|
||||
source_credibility=mir.confidence,
|
||||
novelty_score=0.5,
|
||||
extraction_confidence=mir.confidence,
|
||||
config=cfg,
|
||||
)
|
||||
sentiment = _DIRECTION_TO_SENTIMENT.get(mir.impact_direction, 0.0)
|
||||
impact = mir.macro_impact_score * macro_signal_weight
|
||||
signals.append(
|
||||
WeightedSignal(
|
||||
document_id=mir.source_document_id,
|
||||
weight=sw,
|
||||
sentiment_value=sentiment,
|
||||
impact_score=impact,
|
||||
)
|
||||
)
|
||||
return signals
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build weighted signals from impact records
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -544,6 +789,61 @@ async def persist_trend_evidence(
|
||||
return len(rows)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build MacroEventInfo objects for projection computation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MACRO_EVENT_INFO_QUERY = """
|
||||
SELECT
|
||||
mir.event_id,
|
||||
mir.macro_impact_score,
|
||||
mir.impact_direction,
|
||||
mir.confidence,
|
||||
ge.estimated_duration,
|
||||
ge.severity,
|
||||
d.published_at AS event_published_at
|
||||
FROM macro_impact_records mir
|
||||
JOIN global_events ge ON ge.id = mir.event_id
|
||||
JOIN documents d ON d.id = ge.source_document_id
|
||||
WHERE mir.ticker = $1
|
||||
AND mir.computed_at >= $2
|
||||
AND mir.computed_at <= $3
|
||||
ORDER BY mir.computed_at DESC
|
||||
"""
|
||||
|
||||
|
||||
async def _build_macro_event_infos(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
window_start: datetime,
|
||||
reference_time: datetime,
|
||||
) -> list[MacroEventInfo]:
|
||||
"""Fetch macro impact records and build MacroEventInfo objects for projection."""
|
||||
rows = await pool.fetch(
|
||||
_MACRO_EVENT_INFO_QUERY, ticker, window_start, reference_time,
|
||||
)
|
||||
infos: list[MacroEventInfo] = []
|
||||
for row in rows:
|
||||
published_at = row["event_published_at"]
|
||||
age_hours = 0.0
|
||||
if published_at:
|
||||
age_hours = max(
|
||||
(reference_time - published_at).total_seconds() / 3600.0, 0.0,
|
||||
)
|
||||
infos.append(
|
||||
MacroEventInfo(
|
||||
event_id=str(row["event_id"]),
|
||||
macro_impact_score=float(row["macro_impact_score"] or 0.0),
|
||||
impact_direction=row["impact_direction"] or "neutral",
|
||||
confidence=float(row["confidence"] or 0.5),
|
||||
estimated_duration=row["estimated_duration"] or "short_term",
|
||||
severity=row["severity"] or "low",
|
||||
event_age_hours=age_hours,
|
||||
)
|
||||
)
|
||||
return infos
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main aggregation entry point for a single ticker + window
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -563,8 +863,10 @@ async def aggregate_company_window(
|
||||
2. Fetch document impact records from PostgreSQL.
|
||||
3. Fetch market context for the ticker.
|
||||
4. Build weighted signals using the scoring module.
|
||||
5. Assemble the TrendSummary.
|
||||
6. Persist to trend_windows table.
|
||||
5. Check macro toggle and fetch/merge macro signals if enabled.
|
||||
6. Check competitive toggle and fetch/merge pattern/competitive signals if enabled.
|
||||
7. Assemble the TrendSummary.
|
||||
8. Persist to trend_windows table.
|
||||
|
||||
Returns the assembled TrendSummary.
|
||||
"""
|
||||
@@ -589,7 +891,83 @@ async def aggregate_company_window(
|
||||
impacts, reference_time, window, market_ctx, scoring_cfg,
|
||||
)
|
||||
|
||||
# 4. Assemble trend summary with evidence details
|
||||
# 4. Check macro toggle and merge macro signals
|
||||
# (Requirement 11.2, 11.3, 11.4): Toggle state is read from the DB on
|
||||
# every aggregation cycle. When disabled, macro signals are skipped but
|
||||
# ingestion/classification continue independently — so when re-enabled,
|
||||
# the most recent classifications (including those ingested while disabled)
|
||||
# are immediately available for impact computation.
|
||||
macro_enabled = cfg.macro_enabled
|
||||
db_toggle = await fetch_macro_enabled(pool)
|
||||
if db_toggle is not None:
|
||||
macro_enabled = db_toggle
|
||||
|
||||
if macro_enabled:
|
||||
macro_impacts = await fetch_macro_impact_records(
|
||||
pool, ticker, window_start, reference_time,
|
||||
)
|
||||
if macro_impacts:
|
||||
macro_signals = build_macro_weighted_signals(
|
||||
macro_impacts,
|
||||
reference_time,
|
||||
window,
|
||||
macro_signal_weight=cfg.macro_signal_weight,
|
||||
config=scoring_cfg,
|
||||
)
|
||||
signals = signals + macro_signals
|
||||
logger.info(
|
||||
"Merged %d macro signals for %s/%s",
|
||||
len(macro_signals), ticker, window,
|
||||
)
|
||||
|
||||
# 5. Check competitive toggle and merge pattern/competitive signals
|
||||
# (Requirements 5.1-5.6): Same toggle pattern as macro layer. When
|
||||
# disabled, pattern mining remains queryable but aggregation skips
|
||||
# competitive signals — no degradation of existing behavior.
|
||||
competitive_enabled = cfg.competitive_enabled
|
||||
db_competitive_toggle = await fetch_competitive_enabled(pool)
|
||||
if db_competitive_toggle is not None:
|
||||
competitive_enabled = db_competitive_toggle
|
||||
|
||||
if competitive_enabled:
|
||||
try:
|
||||
# Get unique catalyst types from the impact records
|
||||
catalyst_types = {imp.catalyst_type for imp in impacts}
|
||||
|
||||
# Query self-company historical patterns for each catalyst type
|
||||
all_patterns = []
|
||||
for cat_type in catalyst_types:
|
||||
patterns = await find_self_patterns(pool, ticker, cat_type)
|
||||
all_patterns.extend(patterns)
|
||||
|
||||
# Fetch competitive signals targeting this ticker
|
||||
comp_signals = await fetch_competitive_signals(
|
||||
pool, ticker, window_start, reference_time,
|
||||
)
|
||||
|
||||
# Convert to WeightedSignal objects
|
||||
if all_patterns or comp_signals:
|
||||
pattern_weighted = build_pattern_weighted_signals(
|
||||
patterns=all_patterns,
|
||||
competitive_signals=comp_signals,
|
||||
reference_time=reference_time,
|
||||
window=window,
|
||||
)
|
||||
signals = signals + pattern_weighted
|
||||
logger.info(
|
||||
"Merged %d pattern/competitive signals for %s/%s "
|
||||
"(patterns=%d, competitive=%d)",
|
||||
len(pattern_weighted), ticker, window,
|
||||
len(all_patterns), len(comp_signals),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to fetch pattern/competitive signals for %s/%s — "
|
||||
"continuing with company+macro signals only",
|
||||
ticker, window,
|
||||
)
|
||||
|
||||
# 6. Assemble trend summary with evidence details
|
||||
assembled = assemble_trend_with_evidence(
|
||||
ticker=ticker,
|
||||
window=window,
|
||||
@@ -601,10 +979,10 @@ async def aggregate_company_window(
|
||||
)
|
||||
summary = assembled.summary
|
||||
|
||||
# 5. Persist trend window
|
||||
# 7. Persist trend window
|
||||
trend_id = await persist_trend_summary(pool, summary)
|
||||
|
||||
# 6. Persist evidence mappings
|
||||
# 8. Persist evidence mappings
|
||||
evidence_count = await persist_trend_evidence(
|
||||
pool, trend_id,
|
||||
assembled.supporting_evidence,
|
||||
@@ -617,6 +995,33 @@ async def aggregate_company_window(
|
||||
summary.trend_strength, summary.confidence, len(signals), evidence_count,
|
||||
)
|
||||
|
||||
# 9. Compute and persist trend projection
|
||||
try:
|
||||
macro_event_infos: list[MacroEventInfo] = []
|
||||
if macro_enabled:
|
||||
macro_event_infos = await _build_macro_event_infos(
|
||||
pool, ticker, window_start, reference_time,
|
||||
)
|
||||
|
||||
projection = compute_projection(
|
||||
summary=summary,
|
||||
macro_events=macro_event_infos if macro_event_infos else None,
|
||||
macro_enabled=macro_enabled,
|
||||
upcoming_catalysts=summary.dominant_catalysts[:3] if summary.dominant_catalysts else None,
|
||||
)
|
||||
await persist_trend_projection(pool, trend_id, projection)
|
||||
logger.info(
|
||||
"Persisted projection for %s/%s: direction=%s strength=%.3f confidence=%.3f diverges=%s",
|
||||
ticker, window, projection.projected_direction,
|
||||
projection.projected_strength, projection.projected_confidence,
|
||||
projection.diverges_from_current,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to compute/persist projection for trend %s (%s/%s) — continuing",
|
||||
trend_id, ticker, window,
|
||||
)
|
||||
|
||||
# Prometheus metrics
|
||||
AGGREGATION_WINDOWS_COMPUTED.labels(window=window).inc()
|
||||
AGGREGATION_SIGNALS_PROCESSED.labels(window=window).inc(len(signals))
|
||||
|
||||
+597
-1
@@ -28,7 +28,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import Response
|
||||
|
||||
from services.extractor.metrics import get_model_performance_summary
|
||||
from services.shared.audit import get_entity_audit_trail, get_order_audit_trail
|
||||
from services.shared.audit import get_entity_audit_trail, get_order_audit_trail, record_audit_event
|
||||
from services.shared.config import load_config
|
||||
from services.shared.db import get_pg_pool
|
||||
from services.shared.logging import new_trace_id, set_trace_context, setup_logging
|
||||
@@ -376,6 +376,24 @@ async def list_trends(
|
||||
):
|
||||
d[jsonb_field] = _parse_jsonb(d.get(jsonb_field))
|
||||
results.append(d)
|
||||
|
||||
# Include projection data for each trend (Requirement 12.10)
|
||||
if results:
|
||||
trend_ids = [r["id"] for r in rows]
|
||||
proj_rows = await pool.fetch(
|
||||
"""SELECT DISTINCT ON (trend_window_id)
|
||||
trend_window_id, projected_direction, projected_strength,
|
||||
projected_confidence, projection_horizon,
|
||||
macro_contribution_pct, diverges_from_current
|
||||
FROM trend_projections
|
||||
WHERE trend_window_id = ANY($1::uuid[])
|
||||
ORDER BY trend_window_id, computed_at DESC""",
|
||||
trend_ids,
|
||||
)
|
||||
proj_map = {str(p["trend_window_id"]): _row_to_dict(p) for p in proj_rows}
|
||||
for d in results:
|
||||
d["projection"] = proj_map.get(d["id"])
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -1687,3 +1705,581 @@ async def delete_saved_query(query_id: str):
|
||||
if result == "DELETE 0":
|
||||
raise HTTPException(404, "Query not found")
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Admin: Macro Signal Layer Toggle (Requirement 11.1, 11.5, 11.7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MacroToggleBody(BaseModel):
|
||||
enabled: bool
|
||||
operator: str = "operator"
|
||||
|
||||
|
||||
@app.get("/api/admin/macro/status")
|
||||
async def get_macro_status():
|
||||
"""Return the current macro signal layer enabled/disabled state.
|
||||
|
||||
Reads from the active risk_configs row's JSONB config field.
|
||||
Requirements: 11.1, 11.5
|
||||
"""
|
||||
row = await pool.fetchrow(
|
||||
"""SELECT config->>'macro_enabled' AS macro_enabled
|
||||
FROM risk_configs
|
||||
WHERE active = TRUE
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT 1""",
|
||||
)
|
||||
if row is None or row["macro_enabled"] is None:
|
||||
return {"macro_enabled": True, "source": "default"}
|
||||
return {
|
||||
"macro_enabled": row["macro_enabled"].lower() == "true",
|
||||
"source": "risk_configs",
|
||||
}
|
||||
|
||||
|
||||
@app.put("/api/admin/macro/toggle")
|
||||
async def toggle_macro_layer(body: MacroToggleBody):
|
||||
"""Toggle the macro signal layer on or off.
|
||||
|
||||
Persists the new state into the active risk_configs row's JSONB config
|
||||
and records an audit event with previous state, new state, and operator.
|
||||
|
||||
The toggle state is read from PostgreSQL at the start of each aggregation
|
||||
cycle (no caching), so changes take effect on the next cycle.
|
||||
|
||||
Requirements: 11.1, 11.5, 11.7
|
||||
"""
|
||||
# Read current state
|
||||
current_row = await pool.fetchrow(
|
||||
"""SELECT id, config->>'macro_enabled' AS macro_enabled
|
||||
FROM risk_configs
|
||||
WHERE active = TRUE
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT 1""",
|
||||
)
|
||||
|
||||
if current_row is None:
|
||||
# No active config exists — create one
|
||||
new_config = json.dumps({"macro_enabled": str(body.enabled).lower()})
|
||||
current_row = await pool.fetchrow(
|
||||
"""INSERT INTO risk_configs (name, trading_mode, config, active)
|
||||
VALUES ('default', 'paper', $1::jsonb, TRUE)
|
||||
RETURNING id, config->>'macro_enabled' AS macro_enabled""",
|
||||
new_config,
|
||||
)
|
||||
previous_enabled = True # default was enabled
|
||||
else:
|
||||
prev_val = current_row["macro_enabled"]
|
||||
previous_enabled = prev_val.lower() == "true" if prev_val else True
|
||||
|
||||
config_id = str(current_row["id"])
|
||||
|
||||
# Update the config JSONB to set macro_enabled
|
||||
await pool.execute(
|
||||
"""UPDATE risk_configs
|
||||
SET config = config || $2::jsonb, updated_at = NOW()
|
||||
WHERE id = $1""",
|
||||
current_row["id"],
|
||||
json.dumps({"macro_enabled": str(body.enabled).lower()}),
|
||||
)
|
||||
|
||||
# Record audit event (Requirement 11.7)
|
||||
await record_audit_event(
|
||||
pool,
|
||||
event_type="macro.layer_toggled",
|
||||
entity_type="risk_config",
|
||||
entity_id=config_id,
|
||||
data={
|
||||
"previous_enabled": previous_enabled,
|
||||
"new_enabled": body.enabled,
|
||||
},
|
||||
actor=body.operator,
|
||||
)
|
||||
|
||||
return {
|
||||
"macro_enabled": body.enabled,
|
||||
"previous_enabled": previous_enabled,
|
||||
"toggled_by": body.operator,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Macro Events and Impacts (Requirement 8.1, 8.2, 12.10)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/macro/events")
|
||||
async def list_macro_events(
|
||||
severity: Optional[str] = None,
|
||||
region: Optional[str] = None,
|
||||
sector: Optional[str] = None,
|
||||
since: Optional[str] = None,
|
||||
until: Optional[str] = None,
|
||||
limit: int = Query(default=50, le=200),
|
||||
offset: int = 0,
|
||||
):
|
||||
"""List recent global events with filtering by severity, region, sector, date range.
|
||||
|
||||
Requirements: 8.1
|
||||
"""
|
||||
conditions: list[str] = []
|
||||
params: list[Any] = []
|
||||
idx = 1
|
||||
|
||||
if severity:
|
||||
conditions.append(f"ge.severity = ${idx}")
|
||||
params.append(severity)
|
||||
idx += 1
|
||||
if region:
|
||||
conditions.append(f"${idx} = ANY(ge.affected_regions)")
|
||||
params.append(region)
|
||||
idx += 1
|
||||
if sector:
|
||||
conditions.append(f"${idx} = ANY(ge.affected_sectors)")
|
||||
params.append(sector)
|
||||
idx += 1
|
||||
if since:
|
||||
conditions.append(f"ge.created_at >= ${idx}::timestamptz")
|
||||
params.append(since)
|
||||
idx += 1
|
||||
if until:
|
||||
conditions.append(f"ge.created_at <= ${idx}::timestamptz")
|
||||
params.append(until)
|
||||
idx += 1
|
||||
|
||||
where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
|
||||
|
||||
rows = await pool.fetch(
|
||||
f"""SELECT ge.id, ge.event_types, ge.severity, ge.affected_regions,
|
||||
ge.affected_sectors, ge.affected_commodities, ge.summary,
|
||||
ge.key_facts, ge.estimated_duration, ge.confidence,
|
||||
ge.source_document_id, ge.created_at
|
||||
FROM global_events ge
|
||||
{where}
|
||||
ORDER BY ge.created_at DESC
|
||||
LIMIT ${idx} OFFSET ${idx + 1}""",
|
||||
*params, limit, offset,
|
||||
)
|
||||
results = []
|
||||
for r in rows:
|
||||
d = _row_to_dict(r)
|
||||
d["key_facts"] = _parse_jsonb(d.get("key_facts"))
|
||||
results.append(d)
|
||||
return results
|
||||
|
||||
|
||||
@app.get("/api/macro/events/{event_id}")
|
||||
async def get_macro_event(event_id: str):
|
||||
"""Event detail with list of affected companies and their macro impact scores.
|
||||
|
||||
Requirements: 8.2
|
||||
"""
|
||||
row = await pool.fetchrow(
|
||||
"""SELECT id, event_types, severity, affected_regions, affected_sectors,
|
||||
affected_commodities, summary, key_facts, estimated_duration,
|
||||
confidence, source_document_id, model_provider, model_name,
|
||||
prompt_version, schema_version, created_at
|
||||
FROM global_events WHERE id = $1""",
|
||||
event_id,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(404, "Global event not found")
|
||||
|
||||
result = _row_to_dict(row)
|
||||
result["key_facts"] = _parse_jsonb(result.get("key_facts"))
|
||||
|
||||
# Affected companies with macro impact scores
|
||||
impacts = await pool.fetch(
|
||||
"""SELECT mir.id, mir.company_id, mir.ticker, mir.macro_impact_score,
|
||||
mir.impact_direction, mir.contributing_factors, mir.confidence,
|
||||
mir.computed_at, c.legal_name, c.sector
|
||||
FROM macro_impact_records mir
|
||||
JOIN companies c ON c.id = mir.company_id
|
||||
WHERE mir.event_id = $1
|
||||
ORDER BY mir.macro_impact_score DESC""",
|
||||
event_id,
|
||||
)
|
||||
impact_list = []
|
||||
for imp in impacts:
|
||||
imp_dict = _row_to_dict(imp)
|
||||
imp_dict["contributing_factors"] = _parse_jsonb(imp_dict.get("contributing_factors"))
|
||||
impact_list.append(imp_dict)
|
||||
result["affected_companies"] = impact_list
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/api/macro/impacts/{ticker}")
|
||||
async def get_macro_impacts_for_ticker(
|
||||
ticker: str,
|
||||
since: Optional[str] = None,
|
||||
limit: int = Query(default=50, le=200),
|
||||
offset: int = 0,
|
||||
):
|
||||
"""Macro impacts for a specific company.
|
||||
|
||||
Requirements: 8.2
|
||||
"""
|
||||
conditions = ["mir.ticker = $1"]
|
||||
params: list[Any] = [ticker.upper()]
|
||||
idx = 2
|
||||
|
||||
if since:
|
||||
conditions.append(f"mir.computed_at >= ${idx}::timestamptz")
|
||||
params.append(since)
|
||||
idx += 1
|
||||
|
||||
where = " AND ".join(conditions)
|
||||
|
||||
rows = await pool.fetch(
|
||||
f"""SELECT mir.id, mir.event_id, mir.company_id, mir.ticker,
|
||||
mir.macro_impact_score, mir.impact_direction,
|
||||
mir.contributing_factors, mir.confidence, mir.computed_at,
|
||||
ge.summary AS event_summary, ge.severity AS event_severity,
|
||||
ge.event_types AS event_types, ge.affected_regions
|
||||
FROM macro_impact_records mir
|
||||
JOIN global_events ge ON ge.id = mir.event_id
|
||||
WHERE {where}
|
||||
ORDER BY mir.computed_at DESC
|
||||
LIMIT ${idx} OFFSET ${idx + 1}""",
|
||||
*params, limit, offset,
|
||||
)
|
||||
results = []
|
||||
for r in rows:
|
||||
d = _row_to_dict(r)
|
||||
d["contributing_factors"] = _parse_jsonb(d.get("contributing_factors"))
|
||||
results.append(d)
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trend Projections (Requirement 12.10)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/trends/{trend_id}/projection")
|
||||
async def get_trend_projection(trend_id: str):
|
||||
"""Trend projection for a specific trend window.
|
||||
|
||||
Requirements: 12.10
|
||||
"""
|
||||
# Verify trend exists
|
||||
trend_row = await pool.fetchrow(
|
||||
"SELECT id FROM trend_windows WHERE id = $1", trend_id,
|
||||
)
|
||||
if not trend_row:
|
||||
raise HTTPException(404, "Trend not found")
|
||||
|
||||
row = await pool.fetchrow(
|
||||
"""SELECT id, trend_window_id, projected_direction, projected_strength,
|
||||
projected_confidence, projection_horizon, driving_factors,
|
||||
macro_contribution_pct, diverges_from_current, computed_at
|
||||
FROM trend_projections WHERE trend_window_id = $1
|
||||
ORDER BY computed_at DESC LIMIT 1""",
|
||||
trend_id,
|
||||
)
|
||||
if not row:
|
||||
return {"trend_window_id": trend_id, "projection": None}
|
||||
|
||||
d = _row_to_dict(row)
|
||||
d["driving_factors"] = _parse_jsonb(d.get("driving_factors"))
|
||||
return d
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Competitive Layer Toggle (Requirements 6.1, 6.2, 6.3, 6.4, 6.5, 6.7)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class CompetitiveToggleBody(BaseModel):
|
||||
enabled: bool
|
||||
operator: str = "operator"
|
||||
|
||||
|
||||
@app.get("/api/admin/competitive/status")
|
||||
async def get_competitive_status():
|
||||
"""Return the current competitive signal layer enabled/disabled state.
|
||||
|
||||
Reads from the active risk_configs row's JSONB config field.
|
||||
Requirements: 6.1, 6.5
|
||||
"""
|
||||
row = await pool.fetchrow(
|
||||
"""SELECT config->>'competitive_enabled' AS competitive_enabled
|
||||
FROM risk_configs
|
||||
WHERE active = TRUE
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT 1""",
|
||||
)
|
||||
if row is None or row["competitive_enabled"] is None:
|
||||
return {"competitive_enabled": True, "source": "default"}
|
||||
return {
|
||||
"competitive_enabled": row["competitive_enabled"].lower() == "true",
|
||||
"source": "risk_configs",
|
||||
}
|
||||
|
||||
|
||||
@app.put("/api/admin/competitive/toggle")
|
||||
async def toggle_competitive_layer(body: CompetitiveToggleBody):
|
||||
"""Toggle the competitive signal layer on or off.
|
||||
|
||||
Persists the new state into the active risk_configs row's JSONB config
|
||||
and records an audit event with previous state, new state, and operator.
|
||||
|
||||
Toggle state is read from PostgreSQL at the start of each aggregation
|
||||
cycle (no caching), so changes take effect on the next cycle.
|
||||
|
||||
When disabled, pattern mining remains queryable via API but signal
|
||||
propagation is skipped during aggregation. When re-enabled, the engine
|
||||
resumes computing signals using latest historical data including
|
||||
intelligence ingested while disabled.
|
||||
|
||||
Requirements: 6.1, 6.2, 6.3, 6.4, 6.5, 6.7
|
||||
"""
|
||||
# Read current state
|
||||
current_row = await pool.fetchrow(
|
||||
"""SELECT id, config->>'competitive_enabled' AS competitive_enabled
|
||||
FROM risk_configs
|
||||
WHERE active = TRUE
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT 1""",
|
||||
)
|
||||
|
||||
if current_row is None:
|
||||
# No active config exists — create one
|
||||
new_config = json.dumps({"competitive_enabled": str(body.enabled).lower()})
|
||||
current_row = await pool.fetchrow(
|
||||
"""INSERT INTO risk_configs (name, trading_mode, config, active)
|
||||
VALUES ('default', 'paper', $1::jsonb, TRUE)
|
||||
RETURNING id, config->>'competitive_enabled' AS competitive_enabled""",
|
||||
new_config,
|
||||
)
|
||||
previous_enabled = True # default was enabled
|
||||
else:
|
||||
prev_val = current_row["competitive_enabled"]
|
||||
previous_enabled = prev_val.lower() == "true" if prev_val else True
|
||||
|
||||
config_id = str(current_row["id"])
|
||||
|
||||
# Update the config JSONB to set competitive_enabled
|
||||
await pool.execute(
|
||||
"""UPDATE risk_configs
|
||||
SET config = config || $2::jsonb, updated_at = NOW()
|
||||
WHERE id = $1""",
|
||||
current_row["id"],
|
||||
json.dumps({"competitive_enabled": str(body.enabled).lower()}),
|
||||
)
|
||||
|
||||
# Record audit event (Requirement 6.7)
|
||||
await record_audit_event(
|
||||
pool,
|
||||
event_type="competitive.layer_toggled",
|
||||
entity_type="risk_config",
|
||||
entity_id=config_id,
|
||||
data={
|
||||
"previous_enabled": previous_enabled,
|
||||
"new_enabled": body.enabled,
|
||||
},
|
||||
actor=body.operator,
|
||||
)
|
||||
|
||||
return {
|
||||
"competitive_enabled": body.enabled,
|
||||
"previous_enabled": previous_enabled,
|
||||
"toggled_by": body.operator,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Historical Pattern & Competitive Signal Query Endpoints
|
||||
# (Requirements 10.1, 10.2, 10.3, 10.4, 11.4, 11.6)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
from services.aggregation.pattern_matcher import (
|
||||
find_cross_company_patterns,
|
||||
find_self_patterns,
|
||||
)
|
||||
from services.shared.schemas import MAJOR_DECISION_CATALYSTS
|
||||
|
||||
|
||||
def _pattern_to_dict(p) -> dict[str, Any]:
|
||||
"""Convert a HistoricalPattern dataclass to a JSON-safe dict."""
|
||||
d = asdict(p)
|
||||
for key, val in d.items():
|
||||
if isinstance(val, datetime):
|
||||
d[key] = val.isoformat()
|
||||
return d
|
||||
|
||||
|
||||
@app.get("/api/patterns/{ticker}")
|
||||
async def get_patterns_for_ticker(
|
||||
ticker: str,
|
||||
catalyst_type: Optional[str] = None,
|
||||
time_horizon: Optional[str] = None,
|
||||
):
|
||||
"""Historical patterns for a company.
|
||||
|
||||
Filterable by catalyst_type and time_horizon.
|
||||
Returns sample_count, outcome distribution, pattern_confidence,
|
||||
and date range for each pattern.
|
||||
|
||||
Requirements: 10.1, 10.3
|
||||
"""
|
||||
horizons = [time_horizon] if time_horizon else None
|
||||
|
||||
if catalyst_type:
|
||||
patterns = await find_self_patterns(pool, ticker, catalyst_type, horizons=horizons)
|
||||
else:
|
||||
# Query across all catalyst types present in the company's history
|
||||
rows = await pool.fetch(
|
||||
"""SELECT DISTINCT di.catalyst_type
|
||||
FROM document_impact_records dir
|
||||
JOIN document_intelligence di ON di.document_id = dir.document_id
|
||||
JOIN documents d ON d.id = dir.document_id
|
||||
WHERE dir.ticker = $1
|
||||
AND di.validation_status = 'valid'
|
||||
AND d.status != 'rejected'
|
||||
AND di.catalyst_type IS NOT NULL""",
|
||||
ticker,
|
||||
)
|
||||
patterns = []
|
||||
for row in rows:
|
||||
ct = row["catalyst_type"]
|
||||
patterns.extend(await find_self_patterns(pool, ticker, ct, horizons=horizons))
|
||||
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"patterns": [_pattern_to_dict(p) for p in patterns],
|
||||
"count": len(patterns),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/patterns/{ticker}/competitors")
|
||||
async def get_competitor_patterns(
|
||||
ticker: str,
|
||||
catalyst_type: Optional[str] = None,
|
||||
time_horizon: Optional[str] = None,
|
||||
):
|
||||
"""Cross-company patterns showing how this company's catalysts affected competitors.
|
||||
|
||||
Requirements: 10.2, 10.3
|
||||
"""
|
||||
horizons = [time_horizon] if time_horizon else None
|
||||
|
||||
# Find active competitors for this ticker
|
||||
comp_rows = await pool.fetch(
|
||||
"""SELECT DISTINCT
|
||||
CASE WHEN ca.ticker = $1 THEN cb.ticker ELSE ca.ticker END AS competitor_ticker
|
||||
FROM competitor_relationships cr
|
||||
JOIN companies ca ON ca.id = cr.company_a_id
|
||||
JOIN companies cb ON cb.id = cr.company_b_id
|
||||
WHERE cr.active = TRUE
|
||||
AND (ca.ticker = $1 OR cb.ticker = $1)""",
|
||||
ticker,
|
||||
)
|
||||
|
||||
# Determine catalyst types to query
|
||||
if catalyst_type:
|
||||
catalyst_types = [catalyst_type]
|
||||
else:
|
||||
ct_rows = await pool.fetch(
|
||||
"""SELECT DISTINCT di.catalyst_type
|
||||
FROM document_impact_records dir
|
||||
JOIN document_intelligence di ON di.document_id = dir.document_id
|
||||
JOIN documents d ON d.id = dir.document_id
|
||||
WHERE dir.ticker = $1
|
||||
AND di.validation_status = 'valid'
|
||||
AND d.status != 'rejected'
|
||||
AND di.catalyst_type IS NOT NULL""",
|
||||
ticker,
|
||||
)
|
||||
catalyst_types = [r["catalyst_type"] for r in ct_rows]
|
||||
|
||||
patterns = []
|
||||
for comp_row in comp_rows:
|
||||
comp_ticker = comp_row["competitor_ticker"]
|
||||
for ct in catalyst_types:
|
||||
cross = await find_cross_company_patterns(
|
||||
pool, ticker, comp_ticker, ct, horizons=horizons,
|
||||
)
|
||||
patterns.extend(cross)
|
||||
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"cross_company_patterns": [_pattern_to_dict(p) for p in patterns],
|
||||
"count": len(patterns),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/patterns/{ticker}/competitive-signals")
|
||||
async def get_competitive_signals(ticker: str):
|
||||
"""Recent competitive signals targeting this company.
|
||||
|
||||
Requirements: 10.4
|
||||
"""
|
||||
rows = await pool.fetch(
|
||||
"""SELECT id, source_document_id, source_ticker, target_ticker,
|
||||
catalyst_type, pattern_confidence, signal_direction,
|
||||
signal_strength, relationship_strength, computed_at
|
||||
FROM competitive_signal_records
|
||||
WHERE target_ticker = $1
|
||||
ORDER BY computed_at DESC
|
||||
LIMIT 100""",
|
||||
ticker,
|
||||
)
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"competitive_signals": [_row_to_dict(r) for r in rows],
|
||||
"count": len(rows),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/patterns/{ticker}/decisions")
|
||||
async def get_decision_history(
|
||||
ticker: str,
|
||||
time_horizon: Optional[str] = None,
|
||||
):
|
||||
"""Major corporate decision history with trend outcomes and pattern statistics.
|
||||
|
||||
Queries document_impact_records filtered by MAJOR_DECISION_CATALYSTS,
|
||||
joined with trend_windows for outcome data.
|
||||
|
||||
Requirements: 11.4, 11.6
|
||||
"""
|
||||
major_types = list(MAJOR_DECISION_CATALYSTS)
|
||||
horizons = [time_horizon] if time_horizon else None
|
||||
|
||||
# Fetch major decision records for this ticker
|
||||
rows = await pool.fetch(
|
||||
"""SELECT dir.id, dir.document_id, dir.ticker,
|
||||
di.catalyst_type, di.summary,
|
||||
dir.impact_score, dir.created_at,
|
||||
d.published_at
|
||||
FROM document_impact_records dir
|
||||
JOIN document_intelligence di ON di.document_id = dir.document_id
|
||||
JOIN documents d ON d.id = dir.document_id
|
||||
WHERE dir.ticker = $1
|
||||
AND di.validation_status = 'valid'
|
||||
AND d.status != 'rejected'
|
||||
AND di.catalyst_type = ANY($2)
|
||||
ORDER BY dir.created_at DESC
|
||||
LIMIT 50""",
|
||||
ticker,
|
||||
major_types,
|
||||
)
|
||||
|
||||
decisions = []
|
||||
for row in rows:
|
||||
decision = _row_to_dict(row)
|
||||
|
||||
# Fetch pattern statistics for this catalyst type
|
||||
ct = row["catalyst_type"]
|
||||
patterns = await find_self_patterns(pool, ticker, ct, horizons=horizons)
|
||||
decision["pattern_statistics"] = [_pattern_to_dict(p) for p in patterns]
|
||||
|
||||
decisions.append(decision)
|
||||
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"decisions": decisions,
|
||||
"count": len(decisions),
|
||||
}
|
||||
|
||||
@@ -0,0 +1,549 @@
|
||||
"""Event classifier module for macro news articles.
|
||||
|
||||
Classifies global/geopolitical news articles into structured GlobalEvent
|
||||
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
|
||||
existing OllamaClient for inference and retry logic.
|
||||
|
||||
Persists classification prompts, raw outputs, and final events to MinIO
|
||||
and PostgreSQL for audit and downstream interpolation.
|
||||
|
||||
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.shared.schemas import (
|
||||
EstimatedDuration,
|
||||
ImpactType,
|
||||
ModelMetadata,
|
||||
SeverityLevel,
|
||||
)
|
||||
from services.shared.storage import upload_artifact
|
||||
|
||||
logger = logging.getLogger("event_classifier")
|
||||
|
||||
PROMPT_VERSION = "event-classification-v1"
|
||||
SCHEMA_VERSION = "1.0.0"
|
||||
|
||||
# Valid enum value sets for normalization
|
||||
_VALID_IMPACT_TYPES = frozenset(e.value for e in ImpactType)
|
||||
_VALID_SEVERITY_LEVELS = frozenset(e.value for e in SeverityLevel)
|
||||
_VALID_DURATIONS = frozenset(e.value for e in EstimatedDuration)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GlobalEvent dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class GlobalEvent:
|
||||
"""Structured classification of a macro news event.
|
||||
|
||||
Produced by the event classifier from Ollama structured output.
|
||||
"""
|
||||
|
||||
event_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
event_types: list[str] = field(default_factory=list)
|
||||
severity: str = "low"
|
||||
affected_regions: list[str] = field(default_factory=list)
|
||||
affected_sectors: list[str] = field(default_factory=list)
|
||||
affected_commodities: list[str] = field(default_factory=list)
|
||||
summary: str = ""
|
||||
key_facts: list[str] = field(default_factory=list)
|
||||
estimated_duration: str = "short_term"
|
||||
confidence: float = 0.5
|
||||
source_document_id: str = ""
|
||||
model_metadata: ModelMetadata = field(default_factory=ModelMetadata)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON schema for Ollama structured output
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _EventClassificationResult:
|
||||
"""Schema definition for the Ollama event classification response.
|
||||
|
||||
Not a Pydantic model — we build the JSON schema dict directly to keep
|
||||
it self-contained and Ollama-friendly (no $refs).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def get_event_json_schema() -> dict[str, Any]:
|
||||
"""Return the JSON schema for Ollama structured event classification output.
|
||||
|
||||
The schema forces the model to produce all required fields explicitly.
|
||||
"""
|
||||
return {
|
||||
"type": "object",
|
||||
"required": [
|
||||
"event_types",
|
||||
"severity",
|
||||
"affected_regions",
|
||||
"affected_sectors",
|
||||
"affected_commodities",
|
||||
"summary",
|
||||
"key_facts",
|
||||
"estimated_duration",
|
||||
"confidence",
|
||||
],
|
||||
"properties": {
|
||||
"event_types": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"enum": sorted(_VALID_IMPACT_TYPES),
|
||||
},
|
||||
"description": (
|
||||
"One or more impact types this event represents. "
|
||||
"Include ALL applicable types — do not collapse to a single category."
|
||||
),
|
||||
},
|
||||
"severity": {
|
||||
"type": "string",
|
||||
"enum": sorted(_VALID_SEVERITY_LEVELS),
|
||||
"description": "Overall severity of the event: low, moderate, high, or critical.",
|
||||
},
|
||||
"affected_regions": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"ISO 3166-1 alpha-2 country codes or region names affected. "
|
||||
"Use standard codes like US, CN, EU, GB, JP. "
|
||||
"Only include regions explicitly mentioned or clearly implied."
|
||||
),
|
||||
},
|
||||
"affected_sectors": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"GICS sector identifiers or sector names affected. "
|
||||
"Examples: Energy, Materials, Industrials, Consumer Discretionary, "
|
||||
"Consumer Staples, Health Care, Financials, Information Technology, "
|
||||
"Communication Services, Utilities, Real Estate."
|
||||
),
|
||||
},
|
||||
"affected_commodities": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"Commodity identifiers affected, if applicable. "
|
||||
"Examples: crude_oil, natural_gas, gold, copper, wheat, lithium, "
|
||||
"semiconductors. Empty list if no commodities are directly affected."
|
||||
),
|
||||
},
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": "A concise 1-3 sentence summary of the event and its market implications.",
|
||||
},
|
||||
"key_facts": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": (
|
||||
"Key facts explicitly stated in the article. "
|
||||
"Do NOT infer, speculate, or fabricate facts. "
|
||||
"Each fact must be directly supported by the text."
|
||||
),
|
||||
},
|
||||
"estimated_duration": {
|
||||
"type": "string",
|
||||
"enum": sorted(_VALID_DURATIONS),
|
||||
"description": (
|
||||
"Expected duration of market impact: "
|
||||
"short_term (days to weeks), medium_term (weeks to months), "
|
||||
"long_term (months to years)."
|
||||
),
|
||||
},
|
||||
"confidence": {
|
||||
"type": "number",
|
||||
"minimum": 0.0,
|
||||
"maximum": 1.0,
|
||||
"description": (
|
||||
"Your confidence in this classification. "
|
||||
"Lower if the article is ambiguous, speculative, or lacks concrete details."
|
||||
),
|
||||
},
|
||||
},
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SYSTEM_PROMPT = """\
|
||||
You classify global news articles into structured macro event intelligence. \
|
||||
Read the article carefully and extract the event classification. \
|
||||
Return ONLY valid JSON matching the schema. No commentary, no markdown, no explanation."""
|
||||
|
||||
_ANTI_HALLUCINATION_RULES = """\
|
||||
CRITICAL RULES — read carefully:
|
||||
1. Only extract information EXPLICITLY stated in the article text.
|
||||
2. Do NOT infer, speculate, or fabricate facts, regions, sectors, or commodities.
|
||||
3. If the article mentions multiple distinct impact types, include ALL of them in event_types.
|
||||
4. For affected_regions, only include regions explicitly mentioned or clearly implied by the event.
|
||||
5. For affected_sectors, only include sectors with a clear causal link to the event.
|
||||
6. For affected_commodities, only include commodities directly referenced or obviously impacted.
|
||||
7. For key_facts, each fact must be directly supported by a specific passage in the text.
|
||||
8. If the article is vague or speculative, set confidence LOW (below 0.4).
|
||||
9. Do NOT treat journalist speculation or opinion as confirmed fact.
|
||||
10. Distinguish between announced policy and proposed/rumored policy."""
|
||||
|
||||
|
||||
def build_event_classification_prompt(text: str) -> dict[str, str]:
|
||||
"""Build system and user prompts for Ollama event classification.
|
||||
|
||||
Args:
|
||||
text: Normalized text content of the macro news article.
|
||||
|
||||
Returns:
|
||||
Dict with 'system' and 'user' prompt strings.
|
||||
"""
|
||||
user_prompt = f"""\
|
||||
Classify this global news article as a macro event. Fill every field.
|
||||
|
||||
{_ANTI_HALLUCINATION_RULES}
|
||||
|
||||
Classify the event by:
|
||||
- event_types: ALL applicable impact types (supply_disruption, demand_shift, cost_increase, \
|
||||
regulatory_pressure, currency_impact, commodity_shock, trade_barrier, geopolitical_risk)
|
||||
- severity: low, moderate, high, or critical
|
||||
- affected_regions: ISO country codes or region names
|
||||
- affected_sectors: GICS sector names
|
||||
- affected_commodities: commodity identifiers (empty list if none)
|
||||
- summary: 1-3 sentence summary of the event and market implications
|
||||
- key_facts: facts explicitly stated in the text (NO fabrication)
|
||||
- estimated_duration: short_term, medium_term, or long_term
|
||||
- confidence: 0.0-1.0 your confidence in this classification
|
||||
|
||||
--- ARTICLE TEXT ---
|
||||
{text}
|
||||
--- END ARTICLE TEXT ---"""
|
||||
|
||||
return {
|
||||
"system": _SYSTEM_PROMPT,
|
||||
"user": user_prompt,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Classification response parsing and normalization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _normalize_event_types(raw: list[Any]) -> list[str]:
|
||||
"""Normalize and filter event_types to valid ImpactType values."""
|
||||
result = []
|
||||
for item in raw:
|
||||
val = str(item).lower().strip()
|
||||
if val in _VALID_IMPACT_TYPES:
|
||||
result.append(val)
|
||||
return result if result else ["geopolitical_risk"]
|
||||
|
||||
|
||||
def _normalize_severity(raw: str) -> str:
|
||||
"""Normalize severity to a valid SeverityLevel value."""
|
||||
val = str(raw).lower().strip()
|
||||
return val if val in _VALID_SEVERITY_LEVELS else "low"
|
||||
|
||||
|
||||
def _normalize_duration(raw: str) -> str:
|
||||
"""Normalize estimated_duration to a valid EstimatedDuration value."""
|
||||
val = str(raw).lower().strip()
|
||||
return val if val in _VALID_DURATIONS else "short_term"
|
||||
|
||||
|
||||
def _parse_classification_response(
|
||||
raw_json: str,
|
||||
document_id: str,
|
||||
model_name: str,
|
||||
) -> GlobalEvent:
|
||||
"""Parse raw Ollama JSON output into a GlobalEvent.
|
||||
|
||||
Normalizes enum values and clamps numeric fields.
|
||||
"""
|
||||
data = json.loads(raw_json)
|
||||
|
||||
confidence = data.get("confidence", 0.5)
|
||||
if isinstance(confidence, (int, float)):
|
||||
confidence = max(0.0, min(1.0, float(confidence)))
|
||||
else:
|
||||
confidence = 0.5
|
||||
|
||||
return GlobalEvent(
|
||||
event_id=str(uuid.uuid4()),
|
||||
event_types=_normalize_event_types(data.get("event_types", [])),
|
||||
severity=_normalize_severity(data.get("severity", "low")),
|
||||
affected_regions=[str(r) for r in data.get("affected_regions", [])],
|
||||
affected_sectors=[str(s) for s in data.get("affected_sectors", [])],
|
||||
affected_commodities=[str(c) for c in data.get("affected_commodities", [])],
|
||||
summary=str(data.get("summary", "")),
|
||||
key_facts=[str(f) for f in data.get("key_facts", [])],
|
||||
estimated_duration=_normalize_duration(data.get("estimated_duration", "short_term")),
|
||||
confidence=confidence,
|
||||
source_document_id=document_id,
|
||||
model_metadata=ModelMetadata(
|
||||
provider="ollama",
|
||||
model_name=model_name,
|
||||
prompt_version=PROMPT_VERSION,
|
||||
schema_version=SCHEMA_VERSION,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MinIO persistence helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _upload_classification_prompt(
|
||||
minio_client: Minio,
|
||||
document_id: str,
|
||||
prompt_data: dict[str, str],
|
||||
model_name: str,
|
||||
timestamp: datetime | None = None,
|
||||
) -> str:
|
||||
"""Upload classification prompt and metadata to stonks-llm-prompts."""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
payload = json.dumps({
|
||||
"prompt_version": PROMPT_VERSION,
|
||||
"schema_version": SCHEMA_VERSION,
|
||||
"model": model_name,
|
||||
"system_prompt": prompt_data["system"],
|
||||
"user_prompt": prompt_data["user"],
|
||||
"json_schema": get_event_json_schema(),
|
||||
}, indent=2).encode()
|
||||
|
||||
path = (
|
||||
f"event_classification/macro/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
|
||||
f"{document_id}/prompt.json"
|
||||
)
|
||||
return upload_artifact(
|
||||
minio_client, "stonks-llm-prompts", path, payload,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
def _upload_classification_result(
|
||||
minio_client: Minio,
|
||||
document_id: str,
|
||||
raw_output: str,
|
||||
event: GlobalEvent | None,
|
||||
success: bool,
|
||||
error: str | None,
|
||||
timestamp: datetime | None = None,
|
||||
) -> str:
|
||||
"""Upload raw classification output to stonks-llm-results."""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
payload = json.dumps({
|
||||
"document_id": document_id,
|
||||
"success": success,
|
||||
"error": error,
|
||||
"raw_output": raw_output,
|
||||
"parsed_event": {
|
||||
"event_id": event.event_id,
|
||||
"event_types": event.event_types,
|
||||
"severity": event.severity,
|
||||
"affected_regions": event.affected_regions,
|
||||
"affected_sectors": event.affected_sectors,
|
||||
"affected_commodities": event.affected_commodities,
|
||||
"summary": event.summary,
|
||||
"key_facts": event.key_facts,
|
||||
"estimated_duration": event.estimated_duration,
|
||||
"confidence": event.confidence,
|
||||
} if event else None,
|
||||
}, indent=2).encode()
|
||||
|
||||
path = (
|
||||
f"event_classification/macro/{ts.year}/{ts.month:02d}/{ts.day:02d}/"
|
||||
f"{document_id}/result.json"
|
||||
)
|
||||
return upload_artifact(
|
||||
minio_client, "stonks-llm-results", path, payload,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PostgreSQL persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def persist_global_event(
|
||||
pool: asyncpg.Pool,
|
||||
event: GlobalEvent,
|
||||
) -> str:
|
||||
"""Persist a GlobalEvent record to the global_events PostgreSQL table.
|
||||
|
||||
Returns the event row UUID.
|
||||
"""
|
||||
row_id = await pool.fetchval(
|
||||
"""INSERT INTO global_events
|
||||
(id, event_types, severity, affected_regions, affected_sectors,
|
||||
affected_commodities, summary, key_facts, estimated_duration,
|
||||
confidence, source_document_id, model_provider, model_name,
|
||||
prompt_version, schema_version)
|
||||
VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8::jsonb, $9, $10,
|
||||
$11::uuid, $12, $13, $14, $15)
|
||||
RETURNING id""",
|
||||
event.event_id,
|
||||
event.event_types,
|
||||
event.severity,
|
||||
event.affected_regions,
|
||||
event.affected_sectors,
|
||||
event.affected_commodities,
|
||||
event.summary,
|
||||
json.dumps(event.key_facts),
|
||||
event.estimated_duration,
|
||||
event.confidence,
|
||||
event.source_document_id,
|
||||
event.model_metadata.provider,
|
||||
event.model_metadata.model_name,
|
||||
event.model_metadata.prompt_version,
|
||||
event.model_metadata.schema_version,
|
||||
)
|
||||
logger.info(
|
||||
"Persisted global event %s for doc %s (severity=%s, types=%s)",
|
||||
row_id, event.source_document_id, event.severity, event.event_types,
|
||||
)
|
||||
return str(row_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main classification function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def classify_global_event(
|
||||
normalized_text: str,
|
||||
document_id: str,
|
||||
ollama_client: Any,
|
||||
*,
|
||||
pool: asyncpg.Pool | None = None,
|
||||
minio_client: Minio | None = None,
|
||||
) -> GlobalEvent:
|
||||
"""Classify a macro news article into a GlobalEvent using Ollama.
|
||||
|
||||
Uses the existing OllamaClient's streaming infrastructure with a
|
||||
dedicated event classification prompt and JSON schema. Follows the
|
||||
same retry policy as document extraction.
|
||||
|
||||
Persists prompt, raw output, and final event to MinIO and PostgreSQL
|
||||
when the respective clients are provided.
|
||||
|
||||
Args:
|
||||
normalized_text: Cleaned text content of the macro article.
|
||||
document_id: UUID of the source document.
|
||||
ollama_client: An OllamaClient instance (from services.extractor.client).
|
||||
pool: Optional asyncpg pool for PostgreSQL persistence.
|
||||
minio_client: Optional MinIO client for artifact persistence.
|
||||
|
||||
Returns:
|
||||
A GlobalEvent with the classification result.
|
||||
|
||||
Raises:
|
||||
ValueError: If classification fails after all retries.
|
||||
"""
|
||||
ts = datetime.now(timezone.utc)
|
||||
prompts = build_event_classification_prompt(normalized_text)
|
||||
json_schema = get_event_json_schema()
|
||||
model_name = ollama_client._config.model
|
||||
|
||||
# Persist prompt to MinIO
|
||||
prompt_ref = None
|
||||
if minio_client:
|
||||
try:
|
||||
prompt_ref = _upload_classification_prompt(
|
||||
minio_client, document_id, prompts, model_name, timestamp=ts,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to upload classification prompt for doc %s", document_id)
|
||||
|
||||
# Call Ollama using the client's internal _call_ollama method
|
||||
# We reuse the retry logic pattern from OllamaClient.extract()
|
||||
max_retries = ollama_client._max_retries
|
||||
last_error: str | None = None
|
||||
raw_output = ""
|
||||
|
||||
for attempt_num in range(max_retries + 1):
|
||||
attempt = await ollama_client._call_ollama(prompts, json_schema)
|
||||
raw_output = attempt.raw_output
|
||||
|
||||
if attempt.error is None and raw_output:
|
||||
# Try to parse the response
|
||||
try:
|
||||
event = _parse_classification_response(
|
||||
raw_output, document_id, model_name,
|
||||
)
|
||||
|
||||
# Persist result to MinIO
|
||||
if minio_client:
|
||||
try:
|
||||
_upload_classification_result(
|
||||
minio_client, document_id, raw_output,
|
||||
event, success=True, error=None, timestamp=ts,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to upload classification result for doc %s", document_id,
|
||||
)
|
||||
|
||||
# Persist to PostgreSQL
|
||||
if pool:
|
||||
try:
|
||||
await persist_global_event(pool, event)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to persist global event for doc %s", document_id,
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
except (json.JSONDecodeError, KeyError, TypeError) as exc:
|
||||
last_error = f"parse_error: {exc}"
|
||||
logger.warning(
|
||||
"Classification parse error for doc %s attempt %d: %s",
|
||||
document_id, attempt_num + 1, exc,
|
||||
)
|
||||
else:
|
||||
last_error = attempt.error or "empty_response"
|
||||
|
||||
# Retry with backoff
|
||||
if attempt_num < max_retries:
|
||||
delay = ollama_client._base_delay * (
|
||||
ollama_client._backoff_multiplier ** attempt_num
|
||||
)
|
||||
delay = min(delay, ollama_client._max_delay)
|
||||
logger.warning(
|
||||
"Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
|
||||
attempt_num + 1, max_retries + 1, document_id, last_error, delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# All retries exhausted — persist failure and raise
|
||||
if minio_client:
|
||||
try:
|
||||
_upload_classification_result(
|
||||
minio_client, document_id, raw_output,
|
||||
event=None, success=False, error=last_error, timestamp=ts,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to upload failed classification result for doc %s", document_id,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Event classification failed for document {document_id} "
|
||||
f"after {max_retries + 1} attempts: {last_error}"
|
||||
)
|
||||
@@ -0,0 +1,394 @@
|
||||
"""Exposure profile auto-inference from filing extractions.
|
||||
|
||||
Infers baseline exposure profiles from company filing extractions when
|
||||
no manual profile exists. Scans recent filing extractions for geographic
|
||||
revenue breakdowns, supplier mentions, and commodity references.
|
||||
|
||||
Requirements: 9.1, 9.2, 9.3
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from services.aggregation.interpolation import build_default_profile
|
||||
from services.shared.schemas import (
|
||||
DocumentIntelligence,
|
||||
ExposureProfileSchema,
|
||||
MarketPositionTier,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("exposure_inference")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Known region patterns for geographic extraction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_REGION_KEYWORDS: dict[str, str] = {
|
||||
"united states": "US",
|
||||
"u.s.": "US",
|
||||
"us": "US",
|
||||
"america": "US",
|
||||
"north america": "US",
|
||||
"china": "CN",
|
||||
"chinese": "CN",
|
||||
"europe": "EU",
|
||||
"european": "EU",
|
||||
"eu": "EU",
|
||||
"japan": "JP",
|
||||
"japanese": "JP",
|
||||
"germany": "DE",
|
||||
"german": "DE",
|
||||
"united kingdom": "GB",
|
||||
"uk": "GB",
|
||||
"britain": "GB",
|
||||
"british": "GB",
|
||||
"south korea": "KR",
|
||||
"korea": "KR",
|
||||
"india": "IN",
|
||||
"indian": "IN",
|
||||
"brazil": "BR",
|
||||
"brazilian": "BR",
|
||||
"australia": "AU",
|
||||
"australian": "AU",
|
||||
"canada": "CA",
|
||||
"canadian": "CA",
|
||||
"taiwan": "TW",
|
||||
"saudi arabia": "SA",
|
||||
"russia": "RU",
|
||||
"russian": "RU",
|
||||
"mexico": "MX",
|
||||
"singapore": "SG",
|
||||
"asia": "CN",
|
||||
"asia pacific": "CN",
|
||||
"latin america": "BR",
|
||||
"middle east": "SA",
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Known commodity patterns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_COMMODITY_KEYWORDS: dict[str, str] = {
|
||||
"crude oil": "crude_oil",
|
||||
"oil": "crude_oil",
|
||||
"petroleum": "crude_oil",
|
||||
"natural gas": "natural_gas",
|
||||
"gas": "natural_gas",
|
||||
"copper": "copper",
|
||||
"steel": "steel",
|
||||
"lithium": "lithium",
|
||||
"semiconductor": "semiconductors",
|
||||
"semiconductors": "semiconductors",
|
||||
"chip": "semiconductors",
|
||||
"chips": "semiconductors",
|
||||
"wheat": "wheat",
|
||||
"corn": "corn",
|
||||
"gold": "gold",
|
||||
"aluminum": "aluminum",
|
||||
"aluminium": "aluminum",
|
||||
"nickel": "nickel",
|
||||
"cobalt": "cobalt",
|
||||
"rare earth": "rare_earth",
|
||||
}
|
||||
|
||||
# Minimum number of filing documents to consider inference meaningful
|
||||
_MIN_FILINGS_FOR_INFERENCE = 1
|
||||
|
||||
# Minimum total mentions to consider a region significant
|
||||
_MIN_REGION_MENTIONS = 1
|
||||
|
||||
# Minimum total mentions to consider a commodity significant
|
||||
_MIN_COMMODITY_MENTIONS = 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text scanning helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_regions_from_text(text: str) -> dict[str, int]:
|
||||
"""Extract region mentions from text, returning region_code -> count."""
|
||||
text_lower = text.lower()
|
||||
region_counts: dict[str, int] = defaultdict(int)
|
||||
|
||||
for keyword, code in _REGION_KEYWORDS.items():
|
||||
# Use word boundary matching for short keywords
|
||||
if len(keyword) <= 3:
|
||||
pattern = rf"\b{re.escape(keyword)}\b"
|
||||
matches = re.findall(pattern, text_lower)
|
||||
else:
|
||||
matches = re.findall(re.escape(keyword), text_lower)
|
||||
if matches:
|
||||
region_counts[code] += len(matches)
|
||||
|
||||
return dict(region_counts)
|
||||
|
||||
|
||||
def _extract_commodities_from_text(text: str) -> dict[str, int]:
|
||||
"""Extract commodity mentions from text, returning commodity_id -> count."""
|
||||
text_lower = text.lower()
|
||||
commodity_counts: dict[str, int] = defaultdict(int)
|
||||
|
||||
for keyword, commodity_id in _COMMODITY_KEYWORDS.items():
|
||||
if len(keyword) <= 4:
|
||||
pattern = rf"\b{re.escape(keyword)}\b"
|
||||
matches = re.findall(pattern, text_lower)
|
||||
else:
|
||||
matches = re.findall(re.escape(keyword), text_lower)
|
||||
if matches:
|
||||
commodity_counts[commodity_id] += len(matches)
|
||||
|
||||
return dict(commodity_counts)
|
||||
|
||||
|
||||
def _extract_supply_chain_regions(text: str) -> set[str]:
|
||||
"""Extract supply chain region mentions from text."""
|
||||
supply_keywords = [
|
||||
"supplier", "supply chain", "sourcing", "manufacturing",
|
||||
"factory", "plant", "warehouse", "distribution",
|
||||
"import", "export", "procurement",
|
||||
]
|
||||
text_lower = text.lower()
|
||||
|
||||
regions: set[str] = set()
|
||||
for keyword in supply_keywords:
|
||||
if keyword in text_lower:
|
||||
# Find regions mentioned near supply chain keywords
|
||||
# Look within a window around each occurrence
|
||||
for match in re.finditer(re.escape(keyword), text_lower):
|
||||
start = max(0, match.start() - 200)
|
||||
end = min(len(text_lower), match.end() + 200)
|
||||
window = text_lower[start:end]
|
||||
window_regions = _extract_regions_from_text(window)
|
||||
regions.update(window_regions.keys())
|
||||
|
||||
return regions
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Revenue mix estimation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _estimate_revenue_mix(region_counts: dict[str, int]) -> dict[str, float]:
|
||||
"""Estimate geographic revenue mix from region mention counts.
|
||||
|
||||
Uses mention frequency as a proxy for revenue distribution.
|
||||
Normalizes to sum to 1.0.
|
||||
"""
|
||||
if not region_counts:
|
||||
return {}
|
||||
|
||||
total = sum(region_counts.values())
|
||||
if total == 0:
|
||||
return {}
|
||||
|
||||
mix = {
|
||||
region: round(count / total, 4)
|
||||
for region, count in region_counts.items()
|
||||
if count >= _MIN_REGION_MENTIONS
|
||||
}
|
||||
|
||||
# Re-normalize after filtering
|
||||
mix_total = sum(mix.values())
|
||||
if mix_total > 0 and abs(mix_total - 1.0) > 0.001:
|
||||
mix = {r: round(v / mix_total, 4) for r, v in mix.items()}
|
||||
|
||||
return mix
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Confidence scoring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_inference_confidence(
|
||||
num_filings: int,
|
||||
num_regions: int,
|
||||
num_commodities: int,
|
||||
total_mentions: int,
|
||||
) -> float:
|
||||
"""Compute confidence score for the inferred profile.
|
||||
|
||||
Higher confidence when more filings are available and more
|
||||
geographic/commodity data points are found.
|
||||
"""
|
||||
# Base confidence from number of filings (more filings = more reliable)
|
||||
filing_factor = min(num_filings / 5.0, 1.0) # saturates at 5 filings
|
||||
|
||||
# Data richness factor
|
||||
data_points = num_regions + num_commodities
|
||||
richness_factor = min(data_points / 8.0, 1.0) # saturates at 8 data points
|
||||
|
||||
# Mention volume factor
|
||||
volume_factor = min(total_mentions / 20.0, 1.0) # saturates at 20 mentions
|
||||
|
||||
confidence = 0.4 * filing_factor + 0.35 * richness_factor + 0.25 * volume_factor
|
||||
return round(max(0.0, min(1.0, confidence)), 4)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main inference function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def infer_exposure_profile(
|
||||
document_intelligences: list[DocumentIntelligence],
|
||||
sector: str,
|
||||
industry: str,
|
||||
market_cap_bucket: str,
|
||||
) -> ExposureProfileSchema:
|
||||
"""Infer a baseline exposure profile from filing extractions.
|
||||
|
||||
Scans recent filing extractions for geographic revenue breakdowns,
|
||||
supplier mentions, and commodity references. Produces an
|
||||
ExposureProfile with source='inferred' and a confidence score
|
||||
reflecting data quality.
|
||||
|
||||
Falls back to sector-based default profile when insufficient
|
||||
filing data is available.
|
||||
|
||||
Args:
|
||||
document_intelligences: List of DocumentIntelligence from recent filings.
|
||||
sector: Company's GICS sector name.
|
||||
industry: Company's industry name.
|
||||
market_cap_bucket: One of large_cap, mid_cap, small_cap, micro_cap.
|
||||
|
||||
Returns:
|
||||
An ExposureProfileSchema with source='inferred'.
|
||||
|
||||
Requirements: 9.1, 9.2, 9.3
|
||||
"""
|
||||
# Filter to filing-type documents
|
||||
filings = [
|
||||
di for di in document_intelligences
|
||||
if di.document_type.value in ("filing", "transcript")
|
||||
]
|
||||
|
||||
if len(filings) < _MIN_FILINGS_FOR_INFERENCE:
|
||||
logger.info(
|
||||
"Insufficient filing data (%d filings) for inference, "
|
||||
"falling back to sector-based default profile",
|
||||
len(filings),
|
||||
)
|
||||
return build_default_profile(sector, industry, market_cap_bucket)
|
||||
|
||||
# Aggregate region and commodity mentions across all filings
|
||||
all_region_counts: dict[str, int] = defaultdict(int)
|
||||
all_commodity_counts: dict[str, int] = defaultdict(int)
|
||||
all_supply_regions: set[str] = set()
|
||||
|
||||
for filing in filings:
|
||||
# Scan summary text
|
||||
if filing.summary:
|
||||
regions = _extract_regions_from_text(filing.summary)
|
||||
for r, c in regions.items():
|
||||
all_region_counts[r] += c
|
||||
|
||||
commodities = _extract_commodities_from_text(filing.summary)
|
||||
for com, c in commodities.items():
|
||||
all_commodity_counts[com] += c
|
||||
|
||||
supply_regions = _extract_supply_chain_regions(filing.summary)
|
||||
all_supply_regions.update(supply_regions)
|
||||
|
||||
# Scan company impacts for geographic and commodity mentions
|
||||
for company in filing.companies:
|
||||
# Key facts and evidence spans contain geographic details
|
||||
for text in company.key_facts + company.evidence_spans:
|
||||
regions = _extract_regions_from_text(text)
|
||||
for r, c in regions.items():
|
||||
all_region_counts[r] += c
|
||||
|
||||
commodities = _extract_commodities_from_text(text)
|
||||
for com, c in commodities.items():
|
||||
all_commodity_counts[com] += c
|
||||
|
||||
supply_regions = _extract_supply_chain_regions(text)
|
||||
all_supply_regions.update(supply_regions)
|
||||
|
||||
# Scan macro themes for commodity/region hints
|
||||
for theme in filing.macro_themes:
|
||||
regions = _extract_regions_from_text(theme)
|
||||
for r, c in regions.items():
|
||||
all_region_counts[r] += c
|
||||
|
||||
commodities = _extract_commodities_from_text(theme)
|
||||
for com, c in commodities.items():
|
||||
all_commodity_counts[com] += c
|
||||
|
||||
# Check if we have enough data to infer
|
||||
total_mentions = sum(all_region_counts.values()) + sum(all_commodity_counts.values())
|
||||
has_regions = len(all_region_counts) > 0
|
||||
has_commodities = len(all_commodity_counts) > 0
|
||||
|
||||
if not has_regions and not has_commodities:
|
||||
logger.info(
|
||||
"No geographic or commodity data found in %d filings, "
|
||||
"falling back to sector-based default profile",
|
||||
len(filings),
|
||||
)
|
||||
return build_default_profile(sector, industry, market_cap_bucket)
|
||||
|
||||
# Build the inferred profile
|
||||
geographic_revenue_mix = _estimate_revenue_mix(dict(all_region_counts))
|
||||
|
||||
# Filter commodities by minimum mentions
|
||||
key_commodities = [
|
||||
com for com, count in all_commodity_counts.items()
|
||||
if count >= _MIN_COMMODITY_MENTIONS
|
||||
]
|
||||
|
||||
# Supply chain regions: combine extracted supply regions with geo regions
|
||||
supply_chain_regions = list(all_supply_regions | set(geographic_revenue_mix.keys()))
|
||||
|
||||
# Market position tier from market cap bucket
|
||||
from services.aggregation.interpolation import _CAP_TO_TIER
|
||||
tier_value = _CAP_TO_TIER.get(market_cap_bucket, MarketPositionTier.REGIONAL.value)
|
||||
|
||||
# Regulatory jurisdictions: top regions by revenue
|
||||
sorted_regions = sorted(
|
||||
geographic_revenue_mix.items(), key=lambda x: x[1], reverse=True,
|
||||
)
|
||||
regulatory_jurisdictions = [r for r, _ in sorted_regions[:3]]
|
||||
|
||||
# Export dependency: fraction of revenue outside the top region
|
||||
if geographic_revenue_mix:
|
||||
top_region_pct = max(geographic_revenue_mix.values())
|
||||
export_pct = round(1.0 - top_region_pct, 4)
|
||||
else:
|
||||
export_pct = 0.0
|
||||
|
||||
# Confidence score
|
||||
confidence = _compute_inference_confidence(
|
||||
num_filings=len(filings),
|
||||
num_regions=len(all_region_counts),
|
||||
num_commodities=len(all_commodity_counts),
|
||||
total_mentions=total_mentions,
|
||||
)
|
||||
|
||||
profile = ExposureProfileSchema(
|
||||
company_id="",
|
||||
geographic_revenue_mix=geographic_revenue_mix,
|
||||
supply_chain_regions=supply_chain_regions,
|
||||
key_input_commodities=key_commodities,
|
||||
regulatory_jurisdictions=regulatory_jurisdictions,
|
||||
market_position_tier=MarketPositionTier(tier_value),
|
||||
export_dependency_pct=max(0.0, min(1.0, export_pct)),
|
||||
source="inferred",
|
||||
confidence=confidence,
|
||||
version=1,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Inferred exposure profile: regions=%d, commodities=%d, "
|
||||
"supply_chain=%d, confidence=%.3f",
|
||||
len(geographic_revenue_mix),
|
||||
len(key_commodities),
|
||||
len(supply_chain_regions),
|
||||
confidence,
|
||||
)
|
||||
|
||||
return profile
|
||||
+234
-4
@@ -9,13 +9,21 @@ import asyncpg
|
||||
import redis.asyncio as aioredis
|
||||
from minio import Minio
|
||||
|
||||
from services.aggregation.interpolation import (
|
||||
build_default_profile,
|
||||
compute_macro_impact_with_sector,
|
||||
filter_low_confidence_events,
|
||||
persist_macro_impact_records,
|
||||
)
|
||||
from services.extractor.client import OllamaClient
|
||||
from services.extractor.event_classifier import classify_global_event
|
||||
from services.extractor.worker import persist_extraction
|
||||
from services.shared.config import load_config
|
||||
from services.shared.logging import inject_trace_context, setup_logging
|
||||
from services.shared.redis_keys import (
|
||||
QUEUE_AGGREGATION,
|
||||
QUEUE_EXTRACTION,
|
||||
QUEUE_MACRO_CLASSIFICATION,
|
||||
queue_key,
|
||||
)
|
||||
|
||||
@@ -28,6 +36,198 @@ async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]:
|
||||
return {row["ticker"]: str(row["id"]) for row in rows}
|
||||
|
||||
|
||||
async def _fetch_document_type(pool: asyncpg.Pool, document_id: str) -> str | None:
|
||||
"""Fetch the document_type for a document."""
|
||||
row = await pool.fetchrow(
|
||||
"SELECT document_type FROM documents WHERE id = $1::uuid",
|
||||
document_id,
|
||||
)
|
||||
return row["document_type"] if row else None
|
||||
|
||||
|
||||
async def _fetch_company_info(pool: asyncpg.Pool) -> list[dict]:
|
||||
"""Fetch company info needed for exposure profile loading and interpolation."""
|
||||
rows = await pool.fetch(
|
||||
"""SELECT id, ticker, sector, industry, market_cap_bucket
|
||||
FROM companies WHERE active = TRUE"""
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
async def _load_exposure_profile(pool: asyncpg.Pool, company_id: str, sector: str, industry: str, market_cap_bucket: str):
|
||||
"""Load exposure profile for a company: manual > inferred > default.
|
||||
|
||||
Requirements: 4.1
|
||||
"""
|
||||
from services.shared.schemas import ExposureProfileSchema, MarketPositionTier
|
||||
|
||||
# Try manual or inferred profile from DB
|
||||
row = await pool.fetchrow(
|
||||
"""SELECT company_id, geographic_revenue_mix, supply_chain_regions,
|
||||
key_input_commodities, regulatory_jurisdictions, market_position_tier,
|
||||
export_dependency_pct, source, confidence, version
|
||||
FROM exposure_profiles
|
||||
WHERE company_id = $1 AND active = TRUE
|
||||
ORDER BY version DESC LIMIT 1""",
|
||||
company_id,
|
||||
)
|
||||
if row:
|
||||
geo_mix = row["geographic_revenue_mix"]
|
||||
if isinstance(geo_mix, str):
|
||||
geo_mix = json.loads(geo_mix)
|
||||
tier_val = row["market_position_tier"]
|
||||
try:
|
||||
tier = MarketPositionTier(tier_val)
|
||||
except ValueError:
|
||||
tier = MarketPositionTier.REGIONAL
|
||||
return ExposureProfileSchema(
|
||||
company_id=str(row["company_id"]),
|
||||
geographic_revenue_mix=geo_mix or {},
|
||||
supply_chain_regions=list(row["supply_chain_regions"] or []),
|
||||
key_input_commodities=list(row["key_input_commodities"] or []),
|
||||
regulatory_jurisdictions=list(row["regulatory_jurisdictions"] or []),
|
||||
market_position_tier=tier,
|
||||
export_dependency_pct=float(row["export_dependency_pct"] or 0.0),
|
||||
source=row["source"] or "manual",
|
||||
confidence=float(row["confidence"] or 1.0),
|
||||
version=row["version"] or 1,
|
||||
)
|
||||
|
||||
# Fall back to default profile
|
||||
profile = build_default_profile(sector or "", industry or "", market_cap_bucket or "small_cap")
|
||||
profile.company_id = str(company_id)
|
||||
return profile
|
||||
|
||||
|
||||
async def _compute_and_persist_macro_impacts(
|
||||
pool: asyncpg.Pool,
|
||||
event,
|
||||
companies: list[dict],
|
||||
confidence_threshold: float = 0.4,
|
||||
) -> list[str]:
|
||||
"""Compute MacroImpactRecords for all tracked companies and persist non-zero ones.
|
||||
|
||||
Requirements: 4.1, 4.5
|
||||
"""
|
||||
# Filter low-confidence events
|
||||
filtered = filter_low_confidence_events([event], confidence_threshold)
|
||||
if not filtered:
|
||||
logger.info("Event %s excluded: confidence %.3f below threshold %.3f",
|
||||
event.event_id, event.confidence, confidence_threshold)
|
||||
return []
|
||||
|
||||
records = []
|
||||
for company in companies:
|
||||
company_id = str(company["id"])
|
||||
ticker = company["ticker"]
|
||||
sector = company.get("sector") or ""
|
||||
industry = company.get("industry") or ""
|
||||
market_cap_bucket = company.get("market_cap_bucket") or "small_cap"
|
||||
|
||||
profile = await _load_exposure_profile(pool, company_id, sector, industry, market_cap_bucket)
|
||||
|
||||
record = compute_macro_impact_with_sector(event, profile, company_sector=sector)
|
||||
record.ticker = ticker
|
||||
record.company_id = company_id
|
||||
|
||||
if record.macro_impact_score > 0.0:
|
||||
records.append(record)
|
||||
|
||||
if records:
|
||||
ids = await persist_macro_impact_records(pool, records)
|
||||
logger.info(
|
||||
"Persisted %d macro impact records for event %s",
|
||||
len(ids), event.event_id,
|
||||
)
|
||||
return [r.ticker for r in records]
|
||||
|
||||
return []
|
||||
|
||||
|
||||
# Track consecutive macro classification failures for alerting (Requirement 10.4)
|
||||
_macro_consecutive_failures = 0
|
||||
_MACRO_FAILURE_ALERT_THRESHOLD = 3
|
||||
|
||||
|
||||
async def _process_macro_classification(
|
||||
*,
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
ollama: OllamaClient,
|
||||
redis_client: aioredis.Redis,
|
||||
document_id: str,
|
||||
text: str,
|
||||
company_id_map: dict[str, str],
|
||||
confidence_threshold: float = 0.4,
|
||||
) -> None:
|
||||
"""Route a macro_event document to event classification, compute interpolation,
|
||||
and trigger aggregation for affected tickers.
|
||||
|
||||
Requirements: 2.1, 2.2, 2.3, 4.1, 4.5, 10.4
|
||||
"""
|
||||
global _macro_consecutive_failures
|
||||
agg_queue = queue_key(QUEUE_AGGREGATION)
|
||||
|
||||
try:
|
||||
event = await classify_global_event(
|
||||
normalized_text=text,
|
||||
document_id=document_id,
|
||||
ollama_client=ollama,
|
||||
pool=pool,
|
||||
minio_client=minio_client,
|
||||
)
|
||||
logger.info(
|
||||
"Classified macro event %s for doc %s: severity=%s types=%s",
|
||||
event.event_id, document_id, event.severity, event.event_types,
|
||||
)
|
||||
|
||||
# Reset failure counter on success
|
||||
_macro_consecutive_failures = 0
|
||||
|
||||
# Load all tracked companies and compute macro impacts
|
||||
companies = await _fetch_company_info(pool)
|
||||
affected_tickers = await _compute_and_persist_macro_impacts(
|
||||
pool, event, companies, confidence_threshold,
|
||||
)
|
||||
|
||||
# Trigger aggregation for affected tickers (those with non-zero impact)
|
||||
enqueued_tickers = set()
|
||||
for ticker in affected_tickers:
|
||||
if ticker not in enqueued_tickers:
|
||||
await redis_client.rpush(
|
||||
agg_queue,
|
||||
json.dumps(inject_trace_context({
|
||||
"ticker": ticker,
|
||||
"macro_event_id": event.event_id,
|
||||
})),
|
||||
)
|
||||
enqueued_tickers.add(ticker)
|
||||
|
||||
logger.info(
|
||||
"Enqueued aggregation jobs for %d affected tickers after macro event %s",
|
||||
len(enqueued_tickers), event.event_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
_macro_consecutive_failures += 1
|
||||
logger.error("Macro event classification failed for doc %s: %s", document_id, e)
|
||||
if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD:
|
||||
logger.critical(
|
||||
"ALERT: Sustained macro classification failures (%d consecutive). "
|
||||
"Continuing with company-only signals. Operator action required.",
|
||||
_macro_consecutive_failures,
|
||||
)
|
||||
except Exception:
|
||||
_macro_consecutive_failures += 1
|
||||
logger.exception("Unexpected error classifying macro event for doc %s", document_id)
|
||||
if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD:
|
||||
logger.critical(
|
||||
"ALERT: Sustained macro classification failures (%d consecutive). "
|
||||
"Continuing with company-only signals. Operator action required.",
|
||||
_macro_consecutive_failures,
|
||||
)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("extractor", level=config.log_level, json_output=config.json_logs)
|
||||
@@ -42,8 +242,10 @@ async def main() -> None:
|
||||
ollama = OllamaClient(config.ollama)
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_EXTRACTION)
|
||||
macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
|
||||
agg_queue = queue_key(QUEUE_AGGREGATION)
|
||||
logger.info("Extractor worker started, polling %s", queue)
|
||||
confidence_threshold = config.macro.macro_confidence_threshold
|
||||
logger.info("Extractor worker started, polling %s and %s", queue, macro_queue)
|
||||
|
||||
# Pre-load company ID map (refreshed periodically)
|
||||
company_id_map = await _build_company_id_map(pool)
|
||||
@@ -51,7 +253,13 @@ async def main() -> None:
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw = await redis_client.lpop(queue)
|
||||
# Check macro classification queue first (priority)
|
||||
raw = await redis_client.lpop(macro_queue)
|
||||
is_macro_job = raw is not None
|
||||
|
||||
if raw is None:
|
||||
raw = await redis_client.lpop(queue)
|
||||
|
||||
if raw is None:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
@@ -80,13 +288,35 @@ async def main() -> None:
|
||||
except Exception as e:
|
||||
logger.warning("Could not fetch normalized text for doc %s: %s", document_id, e)
|
||||
|
||||
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
|
||||
|
||||
# Refresh company map every 100 jobs
|
||||
refresh_counter += 1
|
||||
if refresh_counter % 100 == 0:
|
||||
company_id_map = await _build_company_id_map(pool)
|
||||
|
||||
# Route macro_event documents to event classification (Requirement 2.1)
|
||||
doc_type = None
|
||||
if is_macro_job:
|
||||
doc_type = "macro_event"
|
||||
else:
|
||||
doc_type = await _fetch_document_type(pool, document_id)
|
||||
|
||||
if doc_type == "macro_event":
|
||||
logger.info("Routing macro_event doc %s to event classifier", document_id)
|
||||
await _process_macro_classification(
|
||||
pool=pool,
|
||||
minio_client=minio_client,
|
||||
ollama=ollama,
|
||||
redis_client=redis_client,
|
||||
document_id=document_id,
|
||||
text=text,
|
||||
company_id_map=company_id_map,
|
||||
confidence_threshold=confidence_threshold,
|
||||
)
|
||||
continue
|
||||
|
||||
# Standard extraction pipeline for non-macro documents
|
||||
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
|
||||
|
||||
try:
|
||||
# Pass all tracked tickers so the model can identify any mentioned companies
|
||||
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
|
||||
|
||||
@@ -10,6 +10,7 @@ from minio import Minio
|
||||
from services.adapters.base import AdapterResult
|
||||
from services.adapters.broker_adapter import AlpacaBrokerAdapter, TradingMode
|
||||
from services.adapters.filings_adapter import SECEdgarAdapter
|
||||
from services.adapters.macro_news_adapter import MacroNewsAdapter
|
||||
from services.adapters.market_adapter import PolygonMarketAdapter
|
||||
from services.adapters.news_adapter import PolygonNewsAdapter
|
||||
from services.adapters.web_scrape_adapter import WebScrapeAdapter
|
||||
@@ -69,11 +70,14 @@ async def process_job(
|
||||
logger.warning("No adapter for source_type=%s", source_type)
|
||||
return
|
||||
|
||||
# Macro sources may not have a company_id
|
||||
company_id = job.get("company_id")
|
||||
|
||||
# Record ingestion run
|
||||
run_id = await pool.fetchval(
|
||||
"""INSERT INTO ingestion_runs (source_id, company_id, source_type, status)
|
||||
VALUES ($1, $2, $3, 'running') RETURNING id""",
|
||||
source_id, job["company_id"], source_type,
|
||||
source_id, company_id, source_type,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -159,7 +163,7 @@ async def process_job(
|
||||
|
||||
# Link duplicate documents to this company if not already linked
|
||||
company_id = job.get("company_id")
|
||||
if company_id and deduped_count:
|
||||
if company_id and deduped_count and source_type not in ("macro_news",):
|
||||
from services.shared.metadata import persist_document_company_mention
|
||||
for dup in dup_items:
|
||||
existing_id = dup.get("_dedupe_existing_id")
|
||||
@@ -234,6 +238,9 @@ async def main():
|
||||
mode=TradingMode.LIVE if cfg.broker.mode == "live" else TradingMode.PAPER,
|
||||
base_url=cfg.broker.base_url,
|
||||
),
|
||||
"macro_news": MacroNewsAdapter(
|
||||
api_key=cfg.market_data.api_key,
|
||||
),
|
||||
}
|
||||
|
||||
logger.info("Ingestion worker started")
|
||||
|
||||
@@ -124,6 +124,27 @@ TABLE_SCHEMAS: dict[str, pa.Schema] = {
|
||||
"model_performance": MODEL_PERFORMANCE_SCHEMA,
|
||||
}
|
||||
|
||||
# Lazily register schemas defined in worker.py to avoid circular imports.
|
||||
# These are added after the initial dict definition.
|
||||
def _register_worker_schemas() -> None:
|
||||
from services.lake_publisher.worker import (
|
||||
COMPETITOR_RELATIONSHIPS_SCHEMA,
|
||||
COMPETITIVE_SIGNALS_SCHEMA,
|
||||
GLOBAL_EVENTS_SCHEMA,
|
||||
MACRO_IMPACTS_SCHEMA,
|
||||
TREND_PROJECTIONS_SCHEMA,
|
||||
)
|
||||
TABLE_SCHEMAS["competitor_relationships"] = COMPETITOR_RELATIONSHIPS_SCHEMA
|
||||
TABLE_SCHEMAS["competitive_signals"] = COMPETITIVE_SIGNALS_SCHEMA
|
||||
TABLE_SCHEMAS["global_events"] = GLOBAL_EVENTS_SCHEMA
|
||||
TABLE_SCHEMAS["macro_impacts"] = MACRO_IMPACTS_SCHEMA
|
||||
TABLE_SCHEMAS["trend_projections"] = TREND_PROJECTIONS_SCHEMA
|
||||
|
||||
try:
|
||||
_register_worker_schemas()
|
||||
except ImportError:
|
||||
pass # worker.py not available in minimal test environments
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IcebergTableDef:
|
||||
|
||||
@@ -39,12 +39,17 @@ from services.lake_publisher.worker import (
|
||||
publish_document_extractions_batch,
|
||||
publish_document_fact,
|
||||
publish_documents_batch,
|
||||
publish_global_event_fact,
|
||||
publish_macro_impact_fact,
|
||||
publish_market_bar,
|
||||
publish_market_quote,
|
||||
publish_pnl_daily,
|
||||
publish_positions_daily_batch,
|
||||
publish_trade_fill,
|
||||
publish_trade_order,
|
||||
publish_trend_projection_fact,
|
||||
publish_competitor_relationship_fact,
|
||||
publish_competitive_signal_fact,
|
||||
)
|
||||
from services.shared.config import load_config
|
||||
from services.shared.db import get_minio, get_pg_pool, get_redis
|
||||
@@ -164,6 +169,57 @@ ORDER BY di.created_at
|
||||
LIMIT 500
|
||||
"""
|
||||
|
||||
_FETCH_GLOBAL_EVENT = """
|
||||
SELECT
|
||||
ge.id, ge.event_types, ge.severity, ge.affected_regions,
|
||||
ge.affected_sectors, ge.affected_commodities, ge.summary,
|
||||
ge.estimated_duration, ge.confidence, ge.source_document_id,
|
||||
ge.created_at
|
||||
FROM global_events ge
|
||||
WHERE ge.id = $1::uuid
|
||||
"""
|
||||
|
||||
_FETCH_MACRO_IMPACTS_FOR_EVENT = """
|
||||
SELECT
|
||||
mir.event_id, mir.company_id, mir.ticker,
|
||||
mir.macro_impact_score, mir.impact_direction,
|
||||
mir.contributing_factors, mir.confidence, mir.computed_at
|
||||
FROM macro_impact_records mir
|
||||
WHERE mir.event_id = $1::uuid
|
||||
"""
|
||||
|
||||
_FETCH_TREND_PROJECTION = """
|
||||
SELECT
|
||||
tp.id, tp.trend_window_id, tp.projected_direction,
|
||||
tp.projected_strength, tp.projected_confidence,
|
||||
tp.projection_horizon, tp.driving_factors,
|
||||
tp.macro_contribution_pct, tp.diverges_from_current,
|
||||
tp.computed_at,
|
||||
tw.ticker
|
||||
FROM trend_projections tp
|
||||
JOIN trend_windows tw ON tw.id = tp.trend_window_id
|
||||
WHERE tp.trend_window_id = $1::uuid
|
||||
"""
|
||||
|
||||
_FETCH_COMPETITOR_RELATIONSHIP = """
|
||||
SELECT
|
||||
cr.id, cr.company_a_id, cr.company_b_id,
|
||||
cr.relationship_type, cr.strength, cr.bidirectional,
|
||||
cr.source, cr.active, cr.created_at
|
||||
FROM competitor_relationships cr
|
||||
WHERE cr.id = $1::uuid
|
||||
"""
|
||||
|
||||
_FETCH_COMPETITIVE_SIGNALS_FOR_DOCUMENT = """
|
||||
SELECT
|
||||
csr.id, csr.source_document_id, csr.source_ticker,
|
||||
csr.target_ticker, csr.catalyst_type, csr.pattern_confidence,
|
||||
csr.signal_direction, csr.signal_strength,
|
||||
csr.relationship_strength, csr.computed_at
|
||||
FROM competitive_signal_records csr
|
||||
WHERE csr.source_document_id = $1::uuid
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job handlers — each transforms operational rows into lake facts
|
||||
@@ -510,6 +566,165 @@ async def publish_bulk_extractions_job(
|
||||
return [ref] if ref else []
|
||||
|
||||
|
||||
async def publish_global_event_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> str:
|
||||
"""Publish a global event fact from PostgreSQL to the lake."""
|
||||
row = await pool.fetchrow(_FETCH_GLOBAL_EVENT, entity_id)
|
||||
if row is None:
|
||||
logger.warning("Global event %s not found, skipping lake publish", entity_id)
|
||||
return ""
|
||||
|
||||
event_types = row["event_types"] or []
|
||||
affected_regions = row["affected_regions"] or []
|
||||
affected_sectors = row["affected_sectors"] or []
|
||||
affected_commodities = row["affected_commodities"] or []
|
||||
|
||||
return publish_global_event_fact(
|
||||
client=minio_client,
|
||||
event_id=str(row["id"]),
|
||||
event_types=list(event_types),
|
||||
severity=row["severity"] or "low",
|
||||
affected_regions=list(affected_regions),
|
||||
affected_sectors=list(affected_sectors),
|
||||
affected_commodities=list(affected_commodities),
|
||||
summary=row["summary"] or "",
|
||||
estimated_duration=row["estimated_duration"] or "short_term",
|
||||
confidence=float(row["confidence"] or 0.0),
|
||||
source_document_id=str(row["source_document_id"]) if row["source_document_id"] else "",
|
||||
created_at=row["created_at"],
|
||||
)
|
||||
|
||||
|
||||
async def publish_macro_impacts_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> list[str]:
|
||||
"""Publish macro impact facts for a global event from PostgreSQL to the lake."""
|
||||
rows = await pool.fetch(_FETCH_MACRO_IMPACTS_FOR_EVENT, entity_id)
|
||||
if not rows:
|
||||
logger.info("No macro impact records for event %s", entity_id)
|
||||
return []
|
||||
|
||||
refs: list[str] = []
|
||||
for row in rows:
|
||||
factors = row["contributing_factors"]
|
||||
if isinstance(factors, str):
|
||||
try:
|
||||
factors = json.loads(factors)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
factors = [factors] if factors else []
|
||||
elif factors is None:
|
||||
factors = []
|
||||
|
||||
ref = publish_macro_impact_fact(
|
||||
client=minio_client,
|
||||
event_id=str(row["event_id"]),
|
||||
company_id=str(row["company_id"]),
|
||||
ticker=row["ticker"],
|
||||
macro_impact_score=float(row["macro_impact_score"] or 0.0),
|
||||
impact_direction=row["impact_direction"] or "neutral",
|
||||
contributing_factors=list(factors),
|
||||
confidence=float(row["confidence"] or 0.0),
|
||||
computed_at=row["computed_at"],
|
||||
)
|
||||
refs.append(ref)
|
||||
return refs
|
||||
|
||||
|
||||
async def publish_trend_projection_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> str:
|
||||
"""Publish a trend projection fact from PostgreSQL to the lake."""
|
||||
row = await pool.fetchrow(_FETCH_TREND_PROJECTION, entity_id)
|
||||
if row is None:
|
||||
logger.warning("Trend projection for window %s not found", entity_id)
|
||||
return ""
|
||||
|
||||
factors = row["driving_factors"]
|
||||
if isinstance(factors, str):
|
||||
try:
|
||||
factors = json.loads(factors)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
factors = [factors] if factors else []
|
||||
elif factors is None:
|
||||
factors = []
|
||||
|
||||
return publish_trend_projection_fact(
|
||||
client=minio_client,
|
||||
trend_window_id=str(row["trend_window_id"]),
|
||||
ticker=row["ticker"] or "",
|
||||
projected_direction=row["projected_direction"] or "neutral",
|
||||
projected_strength=float(row["projected_strength"] or 0.0),
|
||||
projected_confidence=float(row["projected_confidence"] or 0.0),
|
||||
projection_horizon=row["projection_horizon"] or "7d",
|
||||
driving_factors=list(factors),
|
||||
macro_contribution_pct=float(row["macro_contribution_pct"] or 0.0),
|
||||
diverges_from_current=bool(row["diverges_from_current"]),
|
||||
computed_at=row["computed_at"],
|
||||
)
|
||||
|
||||
|
||||
async def publish_competitor_relationship_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> str:
|
||||
"""Publish a competitor relationship fact from PostgreSQL to the lake."""
|
||||
row = await pool.fetchrow(_FETCH_COMPETITOR_RELATIONSHIP, entity_id)
|
||||
if row is None:
|
||||
logger.warning("Competitor relationship %s not found, skipping lake publish", entity_id)
|
||||
return ""
|
||||
|
||||
return publish_competitor_relationship_fact(
|
||||
client=minio_client,
|
||||
relationship_id=str(row["id"]),
|
||||
company_a_id=str(row["company_a_id"]),
|
||||
company_b_id=str(row["company_b_id"]),
|
||||
relationship_type=row["relationship_type"],
|
||||
strength=float(row["strength"]),
|
||||
bidirectional=bool(row["bidirectional"]),
|
||||
source=row["source"],
|
||||
active=bool(row["active"]),
|
||||
created_at=row["created_at"],
|
||||
)
|
||||
|
||||
|
||||
async def publish_competitive_signals_job(
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
entity_id: str,
|
||||
) -> list[str]:
|
||||
"""Publish competitive signal facts for a document from PostgreSQL to the lake."""
|
||||
rows = await pool.fetch(_FETCH_COMPETITIVE_SIGNALS_FOR_DOCUMENT, entity_id)
|
||||
if not rows:
|
||||
logger.info("No competitive signals for document %s", entity_id)
|
||||
return []
|
||||
|
||||
refs: list[str] = []
|
||||
for row in rows:
|
||||
ref = publish_competitive_signal_fact(
|
||||
client=minio_client,
|
||||
signal_id=str(row["id"]),
|
||||
source_document_id=str(row["source_document_id"]),
|
||||
source_ticker=row["source_ticker"],
|
||||
target_ticker=row["target_ticker"],
|
||||
catalyst_type=row["catalyst_type"],
|
||||
pattern_confidence=float(row["pattern_confidence"]),
|
||||
signal_direction=row["signal_direction"],
|
||||
signal_strength=float(row["signal_strength"]),
|
||||
relationship_strength=float(row["relationship_strength"]),
|
||||
computed_at=row["computed_at"],
|
||||
)
|
||||
refs.append(ref)
|
||||
return refs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Job dispatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -525,6 +740,11 @@ JOB_TYPES = {
|
||||
"company_event",
|
||||
"bulk_documents",
|
||||
"bulk_extractions",
|
||||
"global_event",
|
||||
"macro_impact",
|
||||
"trend_projection",
|
||||
"competitor_relationship",
|
||||
"competitive_signal",
|
||||
}
|
||||
|
||||
|
||||
@@ -594,6 +814,26 @@ async def dispatch_job(
|
||||
refs = await publish_bulk_extractions_job(pool, minio_client, since)
|
||||
result["refs"] = refs
|
||||
|
||||
elif job_type == "global_event":
|
||||
ref = await publish_global_event_job(pool, minio_client, entity_id)
|
||||
result["refs"] = [ref] if ref else []
|
||||
|
||||
elif job_type == "macro_impact":
|
||||
refs = await publish_macro_impacts_job(pool, minio_client, entity_id)
|
||||
result["refs"] = refs
|
||||
|
||||
elif job_type == "trend_projection":
|
||||
ref = await publish_trend_projection_job(pool, minio_client, entity_id)
|
||||
result["refs"] = [ref] if ref else []
|
||||
|
||||
elif job_type == "competitor_relationship":
|
||||
ref = await publish_competitor_relationship_job(pool, minio_client, entity_id)
|
||||
result["refs"] = [ref] if ref else []
|
||||
|
||||
elif job_type == "competitive_signal":
|
||||
refs = await publish_competitive_signals_job(pool, minio_client, entity_id)
|
||||
result["refs"] = refs
|
||||
|
||||
else:
|
||||
result["error"] = f"Unknown job_type: {job_type}"
|
||||
logger.warning("Unknown lake publish job type: %s", job_type)
|
||||
|
||||
@@ -55,6 +55,11 @@ TABLE_PARTITIONS: dict[str, PartitionSpec] = {
|
||||
"pnl_daily": PartitionSpec("pnl_daily"),
|
||||
"prediction_vs_outcome": PartitionSpec("prediction_vs_outcome", extra_keys=("model_version",)),
|
||||
"model_performance": PartitionSpec("model_performance", extra_keys=("model_version",)),
|
||||
"global_events": PartitionSpec("global_events"),
|
||||
"macro_impacts": PartitionSpec("macro_impacts", extra_keys=("ticker",)),
|
||||
"trend_projections": PartitionSpec("trend_projections", extra_keys=("ticker",)),
|
||||
"competitor_relationships": PartitionSpec("competitor_relationships"),
|
||||
"competitive_signals": PartitionSpec("competitive_signals", extra_keys=("target_ticker",)),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1226,3 +1226,373 @@ def publish_prediction_vs_outcome_batch(
|
||||
) -> str:
|
||||
"""Publish a batch of prediction vs outcome rows as a single Parquet file."""
|
||||
return _publish_batch(client, "prediction_vs_outcome", rows, PREDICTION_VS_OUTCOME_SCHEMA, dt)
|
||||
|
||||
|
||||
# --- global_events fact table ---
|
||||
|
||||
GLOBAL_EVENTS_SCHEMA = pa.schema([
|
||||
("event_id", pa.string()),
|
||||
("event_types", pa.string()),
|
||||
("severity", pa.string()),
|
||||
("affected_regions", pa.string()),
|
||||
("affected_sectors", pa.string()),
|
||||
("affected_commodities", pa.string()),
|
||||
("summary", pa.string()),
|
||||
("estimated_duration", pa.string()),
|
||||
("confidence", pa.float64()),
|
||||
("source_document_id", pa.string()),
|
||||
("created_at", pa.timestamp("us", tz="UTC")),
|
||||
("dt", pa.date32()),
|
||||
])
|
||||
|
||||
|
||||
def publish_global_event_fact(
|
||||
client: Minio,
|
||||
event_id: str,
|
||||
event_types: list[str],
|
||||
severity: str,
|
||||
affected_regions: list[str],
|
||||
affected_sectors: list[str],
|
||||
affected_commodities: list[str],
|
||||
summary: str,
|
||||
estimated_duration: str,
|
||||
confidence: float,
|
||||
source_document_id: str,
|
||||
created_at: datetime,
|
||||
) -> str:
|
||||
"""Publish a single global event fact to MinIO.
|
||||
|
||||
Writes a Parquet file to:
|
||||
s3://stonks-lakehouse/warehouse/global_events/dt={date}/part-{uuid}.parquet
|
||||
|
||||
Returns the s3:// URI of the written object.
|
||||
|
||||
Requirements: 7.3, 12.6
|
||||
Design ref: Analytical Lake Datasets (lake.global_events)
|
||||
"""
|
||||
row: dict[str, object] = {
|
||||
"event_id": event_id,
|
||||
"event_types": ", ".join(event_types),
|
||||
"severity": severity,
|
||||
"affected_regions": ", ".join(affected_regions),
|
||||
"affected_sectors": ", ".join(affected_sectors),
|
||||
"affected_commodities": ", ".join(affected_commodities),
|
||||
"summary": summary,
|
||||
"estimated_duration": estimated_duration,
|
||||
"confidence": confidence,
|
||||
"source_document_id": source_document_id,
|
||||
"created_at": created_at,
|
||||
**partition_values(created_at),
|
||||
}
|
||||
table = pa.Table.from_pylist([row], schema=GLOBAL_EVENTS_SCHEMA)
|
||||
parquet_bytes = _write_parquet_bytes(table)
|
||||
|
||||
path = _partition_path("global_events", created_at)
|
||||
_put_lakehouse_object(client, "global_events", path, parquet_bytes)
|
||||
|
||||
ref = s3_uri(path)
|
||||
logger.info("Published global_event fact %s: %s", event_id, ref)
|
||||
return ref
|
||||
|
||||
|
||||
# --- macro_impacts fact table ---
|
||||
|
||||
MACRO_IMPACTS_SCHEMA = pa.schema([
|
||||
("event_id", pa.string()),
|
||||
("company_id", pa.string()),
|
||||
("ticker", pa.string()),
|
||||
("macro_impact_score", pa.float64()),
|
||||
("impact_direction", pa.string()),
|
||||
("contributing_factors", pa.string()),
|
||||
("confidence", pa.float64()),
|
||||
("computed_at", pa.timestamp("us", tz="UTC")),
|
||||
("dt", pa.date32()),
|
||||
])
|
||||
|
||||
|
||||
def publish_macro_impact_fact(
|
||||
client: Minio,
|
||||
event_id: str,
|
||||
company_id: str,
|
||||
ticker: str,
|
||||
macro_impact_score: float,
|
||||
impact_direction: str,
|
||||
contributing_factors: list[str],
|
||||
confidence: float,
|
||||
computed_at: datetime,
|
||||
) -> str:
|
||||
"""Publish a single macro impact fact to MinIO.
|
||||
|
||||
Writes a Parquet file to:
|
||||
s3://stonks-lakehouse/warehouse/macro_impacts/dt={date}/ticker={ticker}/part-{uuid}.parquet
|
||||
|
||||
Returns the s3:// URI of the written object.
|
||||
|
||||
Requirements: 7.3, 12.6
|
||||
Design ref: Analytical Lake Datasets (lake.macro_impacts)
|
||||
"""
|
||||
extra = {"ticker": ticker}
|
||||
row: dict[str, object] = {
|
||||
"event_id": event_id,
|
||||
"company_id": company_id,
|
||||
"ticker": ticker,
|
||||
"macro_impact_score": macro_impact_score,
|
||||
"impact_direction": impact_direction,
|
||||
"contributing_factors": ", ".join(contributing_factors),
|
||||
"confidence": confidence,
|
||||
"computed_at": computed_at,
|
||||
**partition_values(computed_at, extra),
|
||||
}
|
||||
table = pa.Table.from_pylist([row], schema=MACRO_IMPACTS_SCHEMA)
|
||||
parquet_bytes = _write_parquet_bytes(table)
|
||||
|
||||
path = _partition_path("macro_impacts", computed_at, extra_partitions=extra)
|
||||
_put_lakehouse_object(client, "macro_impacts", path, parquet_bytes)
|
||||
|
||||
ref = s3_uri(path)
|
||||
logger.info("Published macro_impact fact for %s/%s: %s", ticker, event_id, ref)
|
||||
return ref
|
||||
|
||||
|
||||
# --- trend_projections fact table ---
|
||||
|
||||
TREND_PROJECTIONS_SCHEMA = pa.schema([
|
||||
("trend_window_id", pa.string()),
|
||||
("ticker", pa.string()),
|
||||
("projected_direction", pa.string()),
|
||||
("projected_strength", pa.float64()),
|
||||
("projected_confidence", pa.float64()),
|
||||
("projection_horizon", pa.string()),
|
||||
("driving_factors", pa.string()),
|
||||
("macro_contribution_pct", pa.float64()),
|
||||
("diverges_from_current", pa.bool_()),
|
||||
("computed_at", pa.timestamp("us", tz="UTC")),
|
||||
("dt", pa.date32()),
|
||||
])
|
||||
|
||||
|
||||
def publish_trend_projection_fact(
|
||||
client: Minio,
|
||||
trend_window_id: str,
|
||||
ticker: str,
|
||||
projected_direction: str,
|
||||
projected_strength: float,
|
||||
projected_confidence: float,
|
||||
projection_horizon: str,
|
||||
driving_factors: list[str],
|
||||
macro_contribution_pct: float,
|
||||
diverges_from_current: bool,
|
||||
computed_at: datetime,
|
||||
) -> str:
|
||||
"""Publish a single trend projection fact to MinIO.
|
||||
|
||||
Writes a Parquet file to:
|
||||
s3://stonks-lakehouse/warehouse/trend_projections/dt={date}/ticker={ticker}/part-{uuid}.parquet
|
||||
|
||||
Returns the s3:// URI of the written object.
|
||||
|
||||
Requirements: 7.3, 12.6
|
||||
Design ref: Analytical Lake Datasets (lake.trend_projections)
|
||||
"""
|
||||
extra = {"ticker": ticker}
|
||||
row: dict[str, object] = {
|
||||
"trend_window_id": trend_window_id,
|
||||
"ticker": ticker,
|
||||
"projected_direction": projected_direction,
|
||||
"projected_strength": projected_strength,
|
||||
"projected_confidence": projected_confidence,
|
||||
"projection_horizon": projection_horizon,
|
||||
"driving_factors": ", ".join(driving_factors),
|
||||
"macro_contribution_pct": macro_contribution_pct,
|
||||
"diverges_from_current": diverges_from_current,
|
||||
"computed_at": computed_at,
|
||||
**partition_values(computed_at, extra),
|
||||
}
|
||||
table = pa.Table.from_pylist([row], schema=TREND_PROJECTIONS_SCHEMA)
|
||||
parquet_bytes = _write_parquet_bytes(table)
|
||||
|
||||
path = _partition_path("trend_projections", computed_at, extra_partitions=extra)
|
||||
_put_lakehouse_object(client, "trend_projections", path, parquet_bytes)
|
||||
|
||||
ref = s3_uri(path)
|
||||
logger.info("Published trend_projection fact for %s: %s", ticker, ref)
|
||||
return ref
|
||||
|
||||
|
||||
# --- Batch publishers for macro fact tables ---
|
||||
|
||||
def publish_global_events_batch(
|
||||
client: Minio,
|
||||
rows: list[dict[str, object]],
|
||||
dt: datetime,
|
||||
) -> str:
|
||||
"""Publish a batch of global event rows as a single Parquet file."""
|
||||
return _publish_batch(client, "global_events", rows, GLOBAL_EVENTS_SCHEMA, dt)
|
||||
|
||||
|
||||
def publish_macro_impacts_batch(
|
||||
client: Minio,
|
||||
rows: list[dict[str, object]],
|
||||
dt: datetime,
|
||||
ticker: str = "",
|
||||
) -> str:
|
||||
"""Publish a batch of macro impact rows as a single Parquet file."""
|
||||
extra = {"ticker": ticker} if ticker else None
|
||||
return _publish_batch(client, "macro_impacts", rows, MACRO_IMPACTS_SCHEMA, dt, extra)
|
||||
|
||||
|
||||
def publish_trend_projections_batch(
|
||||
client: Minio,
|
||||
rows: list[dict[str, object]],
|
||||
dt: datetime,
|
||||
ticker: str = "",
|
||||
) -> str:
|
||||
"""Publish a batch of trend projection rows as a single Parquet file."""
|
||||
extra = {"ticker": ticker} if ticker else None
|
||||
return _publish_batch(client, "trend_projections", rows, TREND_PROJECTIONS_SCHEMA, dt, extra)
|
||||
|
||||
|
||||
# --- competitor_relationships fact table ---
|
||||
|
||||
COMPETITOR_RELATIONSHIPS_SCHEMA = pa.schema([
|
||||
("id", pa.string()),
|
||||
("company_a_id", pa.string()),
|
||||
("company_b_id", pa.string()),
|
||||
("relationship_type", pa.string()),
|
||||
("strength", pa.float64()),
|
||||
("bidirectional", pa.bool_()),
|
||||
("source", pa.string()),
|
||||
("active", pa.bool_()),
|
||||
("created_at", pa.timestamp("us", tz="UTC")),
|
||||
("dt", pa.date32()),
|
||||
])
|
||||
|
||||
|
||||
def publish_competitor_relationship_fact(
|
||||
client: Minio,
|
||||
relationship_id: str,
|
||||
company_a_id: str,
|
||||
company_b_id: str,
|
||||
relationship_type: str,
|
||||
strength: float,
|
||||
bidirectional: bool,
|
||||
source: str,
|
||||
active: bool,
|
||||
created_at: datetime,
|
||||
) -> str:
|
||||
"""Publish a single competitor relationship fact to MinIO.
|
||||
|
||||
Writes a Parquet file to:
|
||||
s3://stonks-lakehouse/warehouse/competitor_relationships/dt={date}/part-{uuid}.parquet
|
||||
|
||||
Returns the s3:// URI of the written object.
|
||||
|
||||
Requirements: 7.3
|
||||
Design ref: Analytical Lake Datasets (lake.competitor_relationships)
|
||||
"""
|
||||
row: dict[str, object] = {
|
||||
"id": relationship_id,
|
||||
"company_a_id": company_a_id,
|
||||
"company_b_id": company_b_id,
|
||||
"relationship_type": relationship_type,
|
||||
"strength": strength,
|
||||
"bidirectional": bidirectional,
|
||||
"source": source,
|
||||
"active": active,
|
||||
"created_at": created_at,
|
||||
**partition_values(created_at),
|
||||
}
|
||||
table = pa.Table.from_pylist([row], schema=COMPETITOR_RELATIONSHIPS_SCHEMA)
|
||||
parquet_bytes = _write_parquet_bytes(table)
|
||||
|
||||
path = _partition_path("competitor_relationships", created_at)
|
||||
_put_lakehouse_object(client, "competitor_relationships", path, parquet_bytes)
|
||||
|
||||
ref = s3_uri(path)
|
||||
logger.info("Published competitor_relationship fact %s: %s", relationship_id, ref)
|
||||
return ref
|
||||
|
||||
|
||||
def publish_competitor_relationships_batch(
|
||||
client: Minio,
|
||||
rows: list[dict[str, object]],
|
||||
dt: datetime,
|
||||
) -> str:
|
||||
"""Publish a batch of competitor relationship rows as a single Parquet file."""
|
||||
return _publish_batch(client, "competitor_relationships", rows, COMPETITOR_RELATIONSHIPS_SCHEMA, dt)
|
||||
|
||||
|
||||
# --- competitive_signals fact table ---
|
||||
|
||||
COMPETITIVE_SIGNALS_SCHEMA = pa.schema([
|
||||
("id", pa.string()),
|
||||
("source_document_id", pa.string()),
|
||||
("source_ticker", pa.string()),
|
||||
("target_ticker", pa.string()),
|
||||
("catalyst_type", pa.string()),
|
||||
("pattern_confidence", pa.float64()),
|
||||
("signal_direction", pa.string()),
|
||||
("signal_strength", pa.float64()),
|
||||
("relationship_strength", pa.float64()),
|
||||
("computed_at", pa.timestamp("us", tz="UTC")),
|
||||
("dt", pa.date32()),
|
||||
])
|
||||
|
||||
|
||||
def publish_competitive_signal_fact(
|
||||
client: Minio,
|
||||
signal_id: str,
|
||||
source_document_id: str,
|
||||
source_ticker: str,
|
||||
target_ticker: str,
|
||||
catalyst_type: str,
|
||||
pattern_confidence: float,
|
||||
signal_direction: str,
|
||||
signal_strength: float,
|
||||
relationship_strength: float,
|
||||
computed_at: datetime,
|
||||
) -> str:
|
||||
"""Publish a single competitive signal fact to MinIO.
|
||||
|
||||
Writes a Parquet file to:
|
||||
s3://stonks-lakehouse/warehouse/competitive_signals/dt={date}/target_ticker={ticker}/part-{uuid}.parquet
|
||||
|
||||
Returns the s3:// URI of the written object.
|
||||
|
||||
Requirements: 7.4
|
||||
Design ref: Analytical Lake Datasets (lake.competitive_signals)
|
||||
"""
|
||||
extra = {"target_ticker": target_ticker}
|
||||
row: dict[str, object] = {
|
||||
"id": signal_id,
|
||||
"source_document_id": source_document_id,
|
||||
"source_ticker": source_ticker,
|
||||
"target_ticker": target_ticker,
|
||||
"catalyst_type": catalyst_type,
|
||||
"pattern_confidence": pattern_confidence,
|
||||
"signal_direction": signal_direction,
|
||||
"signal_strength": signal_strength,
|
||||
"relationship_strength": relationship_strength,
|
||||
"computed_at": computed_at,
|
||||
**partition_values(computed_at, extra),
|
||||
}
|
||||
table = pa.Table.from_pylist([row], schema=COMPETITIVE_SIGNALS_SCHEMA)
|
||||
parquet_bytes = _write_parquet_bytes(table)
|
||||
|
||||
path = _partition_path("competitive_signals", computed_at, extra_partitions=extra)
|
||||
_put_lakehouse_object(client, "competitive_signals", path, parquet_bytes)
|
||||
|
||||
ref = s3_uri(path)
|
||||
logger.info("Published competitive_signal fact for %s→%s: %s", source_ticker, target_ticker, ref)
|
||||
return ref
|
||||
|
||||
|
||||
def publish_competitive_signals_batch(
|
||||
client: Minio,
|
||||
rows: list[dict[str, object]],
|
||||
dt: datetime,
|
||||
target_ticker: str = "",
|
||||
) -> str:
|
||||
"""Publish a batch of competitive signal rows as a single Parquet file."""
|
||||
extra = {"target_ticker": target_ticker} if target_ticker else None
|
||||
return _publish_batch(client, "competitive_signals", rows, COMPETITIVE_SIGNALS_SCHEMA, dt, extra)
|
||||
|
||||
@@ -35,7 +35,7 @@ from services.shared.metrics import (
|
||||
PARSE_LOW_QUALITY_TOTAL,
|
||||
PARSE_QUALITY_SCORE,
|
||||
)
|
||||
from services.shared.redis_keys import QUEUE_EXTRACTION, QUEUE_PARSING, queue_key
|
||||
from services.shared.redis_keys import QUEUE_EXTRACTION, QUEUE_MACRO_CLASSIFICATION, QUEUE_PARSING, queue_key
|
||||
from services.shared.storage import upload_normalized_text, upload_parser_output
|
||||
|
||||
logger = logging.getLogger("parser_worker")
|
||||
@@ -210,7 +210,19 @@ async def process_job(
|
||||
|
||||
# Only enqueue for extraction if quality is acceptable
|
||||
if parsed.confidence != "low":
|
||||
await rds.rpush(queue_key(QUEUE_EXTRACTION), json.dumps(inject_trace_context({
|
||||
# Route macro_event documents to the macro classification queue
|
||||
# instead of the standard extraction queue (Requirement 2.1)
|
||||
doc_type_row = await pool.fetchrow(
|
||||
"SELECT document_type FROM documents WHERE id = $1::uuid", doc_id,
|
||||
)
|
||||
doc_type = doc_type_row["document_type"] if doc_type_row else None
|
||||
|
||||
if doc_type == "macro_event":
|
||||
target_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
|
||||
else:
|
||||
target_queue = queue_key(QUEUE_EXTRACTION)
|
||||
|
||||
await rds.rpush(target_queue, json.dumps(inject_trace_context({
|
||||
"document_id": doc_id,
|
||||
"ticker": ticker,
|
||||
"normalized_text": text[:32000],
|
||||
|
||||
@@ -32,6 +32,8 @@ class SuppressionReason(str, Enum):
|
||||
LOW_SOURCE_DIVERSITY = "low_source_diversity"
|
||||
HIGH_EXTRACTION_FAILURE_RATE = "high_extraction_failure_rate"
|
||||
INSUFFICIENT_VALID_DOCUMENTS = "insufficient_valid_documents"
|
||||
MACRO_ONLY_SIGNAL = "macro_only_signal"
|
||||
PATTERN_ONLY_SIGNAL = "pattern_only_signal"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -240,3 +242,116 @@ def evaluate_suppression(
|
||||
data_quality_score=quality_score,
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Macro-only suppression (Requirements: 10.3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MACRO_ONLY_CAVEAT = (
|
||||
"[Macro-only signal] This trend direction is driven solely by macro/geopolitical "
|
||||
"signals with no supporting company-specific evidence. Recommendation is "
|
||||
"informational only and should not be used for automated trading decisions."
|
||||
)
|
||||
|
||||
|
||||
def evaluate_macro_only_suppression(
|
||||
summary: TrendSummary,
|
||||
macro_signal_count: int,
|
||||
company_signal_count: int,
|
||||
) -> bool:
|
||||
"""Evaluate whether a recommendation should be suppressed due to macro-only signals.
|
||||
|
||||
When macro signals are the sole basis for a trend direction change
|
||||
(no supporting company-specific signals), the recommendation should
|
||||
be forced to informational mode with a macro-only caveat.
|
||||
|
||||
Args:
|
||||
summary: The trend summary to evaluate.
|
||||
macro_signal_count: Number of macro signals contributing to the trend.
|
||||
company_signal_count: Number of company-specific signals contributing.
|
||||
|
||||
Returns:
|
||||
True if the recommendation should be suppressed (macro-only), False otherwise.
|
||||
|
||||
Requirements: 10.3
|
||||
"""
|
||||
# No macro signals means no macro-only suppression
|
||||
if macro_signal_count <= 0:
|
||||
return False
|
||||
|
||||
# If there are company-specific signals, no suppression needed
|
||||
if company_signal_count > 0:
|
||||
return False
|
||||
|
||||
# Macro signals are the sole basis — suppress
|
||||
logger.info(
|
||||
"Macro-only suppression triggered for %s/%s: "
|
||||
"macro_signals=%d, company_signals=%d, direction=%s",
|
||||
summary.entity_id,
|
||||
summary.window.value,
|
||||
macro_signal_count,
|
||||
company_signal_count,
|
||||
summary.trend_direction.value,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pattern-only suppression (Requirements: 9.3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PATTERN_ONLY_CAVEAT = (
|
||||
"[Pattern-only signal] This trend direction is driven solely by historical "
|
||||
"pattern and competitive signals with no supporting company-specific or macro "
|
||||
"evidence. Recommendation is informational only."
|
||||
)
|
||||
|
||||
|
||||
def evaluate_pattern_only_suppression(
|
||||
summary: TrendSummary,
|
||||
pattern_signal_count: int,
|
||||
company_signal_count: int,
|
||||
macro_signal_count: int,
|
||||
) -> bool:
|
||||
"""Evaluate whether a recommendation should be suppressed due to pattern-only signals.
|
||||
|
||||
When pattern-based signals are the sole basis for a trend direction change
|
||||
(no supporting company-specific or macro signals), the recommendation should
|
||||
be forced to informational mode with a pattern-only caveat.
|
||||
|
||||
Args:
|
||||
summary: The trend summary to evaluate.
|
||||
pattern_signal_count: Number of pattern/competitive signals contributing.
|
||||
company_signal_count: Number of company-specific signals contributing.
|
||||
macro_signal_count: Number of macro signals contributing.
|
||||
|
||||
Returns:
|
||||
True if the recommendation should be suppressed (pattern-only), False otherwise.
|
||||
|
||||
Requirements: 9.3
|
||||
"""
|
||||
# No pattern signals means no pattern-only suppression
|
||||
if pattern_signal_count <= 0:
|
||||
return False
|
||||
|
||||
# If there are company-specific signals, no suppression needed
|
||||
if company_signal_count > 0:
|
||||
return False
|
||||
|
||||
# If there are macro signals, no suppression needed
|
||||
if macro_signal_count > 0:
|
||||
return False
|
||||
|
||||
# Pattern signals are the sole basis — suppress
|
||||
logger.info(
|
||||
"Pattern-only suppression triggered for %s/%s: "
|
||||
"pattern_signals=%d, company_signals=%d, macro_signals=%d, direction=%s",
|
||||
summary.entity_id,
|
||||
summary.window.value,
|
||||
pattern_signal_count,
|
||||
company_signal_count,
|
||||
macro_signal_count,
|
||||
summary.trend_direction.value,
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -31,6 +31,7 @@ from services.recommendation.thesis_llm import (
|
||||
THESIS_PROMPT_VERSION,
|
||||
rewrite_thesis_with_llm,
|
||||
)
|
||||
from services.aggregation.projection import TrendProjection
|
||||
from services.shared.config import OllamaConfig
|
||||
from services.shared.metrics import (
|
||||
RECOMMENDATION_CONFIDENCE,
|
||||
@@ -178,6 +179,63 @@ async def fetch_latest_trend(
|
||||
return _parse_trend_row(row)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch latest trend projection for a ticker + window
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_LATEST_PROJECTION_QUERY = """
|
||||
SELECT
|
||||
tp.projected_direction, tp.projected_strength, tp.projected_confidence,
|
||||
tp.projection_horizon, tp.driving_factors, tp.macro_contribution_pct,
|
||||
tp.diverges_from_current, tp.computed_at
|
||||
FROM trend_projections tp
|
||||
JOIN trend_windows tw ON tw.id = tp.trend_window_id
|
||||
WHERE tw.entity_id = $1 AND tw."window" = $2
|
||||
ORDER BY tp.computed_at DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
|
||||
async def fetch_latest_projection(
|
||||
pool: asyncpg.Pool,
|
||||
ticker: str,
|
||||
window: str,
|
||||
) -> TrendProjection | None:
|
||||
"""Fetch the most recent trend projection for a ticker and window.
|
||||
|
||||
Returns None if no projection exists. Low-confidence projections
|
||||
are returned with low_confidence=True so callers can decide whether
|
||||
to use them (Requirement 12.9).
|
||||
"""
|
||||
try:
|
||||
row = await pool.fetchrow(_LATEST_PROJECTION_QUERY, ticker, window)
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
driving_factors = row["driving_factors"]
|
||||
if isinstance(driving_factors, str):
|
||||
driving_factors = json.loads(driving_factors)
|
||||
|
||||
proj = TrendProjection(
|
||||
projected_direction=row["projected_direction"],
|
||||
projected_strength=float(row["projected_strength"]),
|
||||
projected_confidence=float(row["projected_confidence"]),
|
||||
projection_horizon=row["projection_horizon"],
|
||||
driving_factors=driving_factors or [],
|
||||
macro_contribution_pct=float(row["macro_contribution_pct"] or 0.0),
|
||||
diverges_from_current=bool(row["diverges_from_current"]),
|
||||
computed_at=row["computed_at"],
|
||||
low_confidence=float(row["projected_confidence"]) < 0.3,
|
||||
)
|
||||
return proj
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch projection for %s/%s — continuing without projection",
|
||||
ticker, window, exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build thesis from trend summary (deterministic, no LLM)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -186,11 +244,16 @@ async def fetch_latest_trend(
|
||||
def build_thesis(
|
||||
summary: TrendSummary,
|
||||
result: EligibilityResult,
|
||||
projection: TrendProjection | None = None,
|
||||
) -> str:
|
||||
"""Generate a deterministic thesis string from trend data.
|
||||
|
||||
This is the descriptive analysis portion (Requirement 7.2).
|
||||
The LLM wording layer is a separate optional task.
|
||||
|
||||
When a TrendProjection is provided and is not low-confidence,
|
||||
the thesis incorporates the projected direction and key driving
|
||||
factors (Requirement 12.8).
|
||||
"""
|
||||
direction = summary.trend_direction.value
|
||||
ticker = summary.entity_id
|
||||
@@ -218,6 +281,27 @@ def build_thesis(
|
||||
+ f"(contradiction score: {summary.contradiction_score:.2f})."
|
||||
)
|
||||
|
||||
# Trend projection (Requirement 12.8)
|
||||
if projection is not None and not projection.low_confidence:
|
||||
proj_dir = projection.projected_direction
|
||||
proj_str = projection.projected_strength
|
||||
parts.append(
|
||||
f"Forward projection ({projection.projection_horizon}): "
|
||||
f"{proj_dir} at strength {proj_str:.2f}."
|
||||
)
|
||||
# Include top driving factors
|
||||
non_divergence_factors = [
|
||||
f for f in projection.driving_factors
|
||||
if not f.startswith("DIVERGENCE:")
|
||||
]
|
||||
if non_divergence_factors:
|
||||
factors_str = "; ".join(non_divergence_factors[:2])
|
||||
parts.append(f"Key drivers: {factors_str}.")
|
||||
if projection.diverges_from_current:
|
||||
parts.append(
|
||||
f"Note: projection diverges from current {direction} trend."
|
||||
)
|
||||
|
||||
# Risks
|
||||
if summary.material_risks:
|
||||
risk_str = "; ".join(summary.material_risks[:2])
|
||||
@@ -290,6 +374,7 @@ def build_recommendation(
|
||||
reference_time: datetime | None = None,
|
||||
llm_thesis: str | None = None,
|
||||
suppression_result: SuppressionResult | None = None,
|
||||
projection: TrendProjection | None = None,
|
||||
) -> Recommendation:
|
||||
"""Assemble a Recommendation object from a trend summary and eligibility result.
|
||||
|
||||
@@ -302,6 +387,10 @@ def build_recommendation(
|
||||
|
||||
If ``suppression_result`` indicates suppression, a suppression note
|
||||
is appended to the thesis for audit visibility (Requirement 7.4).
|
||||
|
||||
If ``projection`` is provided and is not low-confidence, the thesis
|
||||
incorporates projected direction and driving factors (Requirement 12.8).
|
||||
The time_horizon may be refined based on the projection horizon.
|
||||
"""
|
||||
if reference_time is None:
|
||||
reference_time = datetime.now(timezone.utc)
|
||||
@@ -309,7 +398,7 @@ def build_recommendation(
|
||||
# Combine evidence refs — supporting first, then opposing
|
||||
evidence_refs = list(summary.top_supporting_evidence) + list(summary.top_opposing_evidence)
|
||||
|
||||
deterministic_thesis = build_thesis(summary, result)
|
||||
deterministic_thesis = build_thesis(summary, result, projection=projection)
|
||||
risk_class = classify_risk(summary, result)
|
||||
|
||||
# Use LLM-rewritten thesis if available, otherwise deterministic
|
||||
@@ -324,6 +413,13 @@ def build_recommendation(
|
||||
f"reasons={', '.join(reason_strs)})]"
|
||||
)
|
||||
|
||||
# Determine time_horizon — refine with projection horizon if available
|
||||
# (Requirement 12.8)
|
||||
time_horizon = result.time_horizon
|
||||
if projection is not None and not projection.low_confidence:
|
||||
# Append projection horizon context to time_horizon
|
||||
time_horizon = f"{result.time_horizon} (proj:{projection.projection_horizon})"
|
||||
|
||||
# Track whether the thesis was LLM-generated for audit
|
||||
if llm_thesis:
|
||||
provider = "ollama"
|
||||
@@ -339,7 +435,7 @@ def build_recommendation(
|
||||
action=result.action,
|
||||
mode=result.mode,
|
||||
confidence=summary.confidence,
|
||||
time_horizon=result.time_horizon,
|
||||
time_horizon=time_horizon,
|
||||
thesis=f"[risk:{risk_class}] {thesis_body}",
|
||||
invalidation_conditions=result.invalidation_conditions,
|
||||
position_sizing=PositionSizing(
|
||||
@@ -574,12 +670,13 @@ async def generate_recommendation(
|
||||
|
||||
Steps:
|
||||
1. Fetch the latest trend summary for the ticker + window.
|
||||
2. Evaluate data quality suppression (Requirement 7.4).
|
||||
3. Evaluate eligibility using deterministic rules.
|
||||
4. Build a Recommendation object with thesis and evidence.
|
||||
2. Fetch the latest trend projection (Requirement 12.8, 12.9).
|
||||
3. Evaluate data quality suppression (Requirement 7.4).
|
||||
4. Evaluate eligibility using deterministic rules.
|
||||
5. Build a Recommendation object with thesis and evidence.
|
||||
- If ``ollama_config`` is provided, the deterministic thesis is
|
||||
rewritten into analyst-quality prose via the LLM wording layer.
|
||||
5. Persist the recommendation and evidence citations.
|
||||
6. Persist the recommendation and evidence citations.
|
||||
|
||||
Returns the Recommendation, or None if no trend data exists.
|
||||
"""
|
||||
@@ -595,13 +692,23 @@ async def generate_recommendation(
|
||||
logger.info("No trend data for %s/%s — skipping recommendation", ticker, window)
|
||||
return None
|
||||
|
||||
# 2. Evaluate data quality suppression (Requirement 7.4)
|
||||
# 2. Fetch latest trend projection (Requirement 12.8, 12.9)
|
||||
projection = await fetch_latest_projection(pool, ticker, window)
|
||||
# Exclude low-confidence projections from influencing recommendation
|
||||
# eligibility (Requirement 12.9). The projection is still passed to
|
||||
# build_recommendation for informational display, but marked as
|
||||
# low_confidence so it won't affect thesis or time_horizon.
|
||||
effective_projection = projection
|
||||
if projection is not None and projection.low_confidence:
|
||||
effective_projection = projection # still passed, but build_thesis checks low_confidence
|
||||
|
||||
# 3. Evaluate data quality suppression (Requirement 7.4)
|
||||
quality_ctx = await fetch_data_quality_context(pool, ticker, window)
|
||||
suppression = evaluate_suppression(
|
||||
summary, quality_ctx=quality_ctx, config=sup_cfg, reference_time=reference_time,
|
||||
)
|
||||
|
||||
# 3. Evaluate eligibility
|
||||
# 4. Evaluate eligibility
|
||||
result = evaluate_eligibility(summary, cfg)
|
||||
|
||||
# Apply suppression: force mode to informational if suppressed
|
||||
@@ -616,10 +723,10 @@ async def generate_recommendation(
|
||||
invalidation_conditions=result.invalidation_conditions,
|
||||
)
|
||||
|
||||
# 4. Optional LLM thesis rewrite
|
||||
# 5. Optional LLM thesis rewrite
|
||||
llm_thesis: str | None = None
|
||||
if ollama_config is not None:
|
||||
deterministic_thesis = build_thesis(summary, result)
|
||||
deterministic_thesis = build_thesis(summary, result, projection=effective_projection)
|
||||
llm_thesis = await rewrite_thesis_with_llm(
|
||||
deterministic_thesis=deterministic_thesis,
|
||||
summary=summary,
|
||||
@@ -630,13 +737,14 @@ async def generate_recommendation(
|
||||
if llm_thesis == deterministic_thesis:
|
||||
llm_thesis = None
|
||||
|
||||
# 5. Build recommendation
|
||||
# 6. Build recommendation
|
||||
rec = build_recommendation(
|
||||
summary, result, reference_time, llm_thesis=llm_thesis,
|
||||
suppression_result=suppression,
|
||||
projection=effective_projection,
|
||||
)
|
||||
|
||||
# 6. Persist recommendation, evidence citations, and risk evaluation
|
||||
# 7. Persist recommendation, evidence citations, and risk evaluation
|
||||
rec_id = await persist_recommendation(
|
||||
pool,
|
||||
rec,
|
||||
@@ -645,7 +753,7 @@ async def generate_recommendation(
|
||||
eligibility_result=result,
|
||||
)
|
||||
|
||||
# 7. Publish prediction facts to analytical tables (Requirement 9.4)
|
||||
# 8. Publish prediction facts to analytical tables (Requirement 9.4)
|
||||
if minio_client is not None:
|
||||
try:
|
||||
lake_refs = publish_recommendation_facts(
|
||||
@@ -667,10 +775,11 @@ async def generate_recommendation(
|
||||
|
||||
logger.info(
|
||||
"Generated recommendation %s for %s: action=%s mode=%s confidence=%.3f "
|
||||
"eligible=%s suppressed=%s quality_score=%.3f llm_thesis=%s",
|
||||
"eligible=%s suppressed=%s quality_score=%.3f llm_thesis=%s projection=%s",
|
||||
rec_id, ticker, rec.action.value, rec.mode.value, rec.confidence,
|
||||
result.eligible, suppression.suppressed, suppression.data_quality_score,
|
||||
llm_thesis is not None,
|
||||
projection.projected_direction if projection else "none",
|
||||
)
|
||||
|
||||
# Prometheus metrics
|
||||
|
||||
@@ -50,6 +50,7 @@ DEFAULT_CADENCES: dict[str, int] = {
|
||||
"filings_api": 3600,
|
||||
"web_scrape": 1800,
|
||||
"broker": 30,
|
||||
"macro_news": 600,
|
||||
}
|
||||
|
||||
# Default rate limits per source type (requests per minute)
|
||||
@@ -59,6 +60,7 @@ DEFAULT_RATE_LIMITS: dict[str, int] = {
|
||||
"filings_api": 10,
|
||||
"web_scrape": 10,
|
||||
"broker": 60,
|
||||
"macro_news": 10,
|
||||
}
|
||||
|
||||
# How long to wait before retrying a failed source (seconds)
|
||||
@@ -141,9 +143,9 @@ def build_job_payload(
|
||||
"""Build the ingestion job payload for a source."""
|
||||
return {
|
||||
"source_id": str(source["source_id"]),
|
||||
"company_id": str(source["company_id"]),
|
||||
"ticker": source["ticker"],
|
||||
"legal_name": source["legal_name"],
|
||||
"company_id": str(source["company_id"]) if source.get("company_id") else None,
|
||||
"ticker": source.get("ticker") or "",
|
||||
"legal_name": source.get("legal_name") or "",
|
||||
"aliases": aliases,
|
||||
"source_type": source["source_type"],
|
||||
"source_name": source["source_name"],
|
||||
@@ -183,7 +185,7 @@ async def check_rate_limit(
|
||||
|
||||
|
||||
async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
|
||||
"""Fetch all active sources joined with their active companies."""
|
||||
"""Fetch all active company-specific sources joined with their active companies."""
|
||||
return await pool.fetch(
|
||||
"""SELECT s.id AS source_id,
|
||||
s.company_id,
|
||||
@@ -196,10 +198,33 @@ async def fetch_active_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
|
||||
FROM sources s
|
||||
JOIN companies c ON s.company_id = c.id
|
||||
WHERE s.active = TRUE AND c.active = TRUE
|
||||
AND s.source_type != 'macro_news'
|
||||
ORDER BY s.source_type, c.ticker"""
|
||||
)
|
||||
|
||||
|
||||
async def fetch_macro_sources(pool: asyncpg.Pool) -> list[asyncpg.Record]:
|
||||
"""Fetch all active macro news sources.
|
||||
|
||||
Macro sources are not company-specific — they have source_type='macro_news'
|
||||
and may have company_id NULL. They are scheduled independently from
|
||||
company-specific sources.
|
||||
|
||||
Requirements: 1.1
|
||||
"""
|
||||
return await pool.fetch(
|
||||
"""SELECT s.id AS source_id,
|
||||
s.company_id,
|
||||
s.source_type,
|
||||
s.source_name,
|
||||
s.config,
|
||||
s.credibility_score
|
||||
FROM sources s
|
||||
WHERE s.active = TRUE AND s.source_type = 'macro_news'
|
||||
ORDER BY s.source_name"""
|
||||
)
|
||||
|
||||
|
||||
async def fetch_aliases_for_company(pool: asyncpg.Pool, company_id: str) -> list[str]:
|
||||
"""Fetch all aliases for a company."""
|
||||
rows = await pool.fetch(
|
||||
@@ -287,9 +312,57 @@ async def schedule_cycle(pool: asyncpg.Pool, rds: aioredis.Redis) -> int:
|
||||
source_type, src["ticker"], src["source_name"],
|
||||
)
|
||||
|
||||
# --- Schedule macro news sources (Requirement 1.1) ---
|
||||
macro_sources = await fetch_macro_sources(pool)
|
||||
for src in macro_sources:
|
||||
source_id = src["source_id"]
|
||||
source_type = src["source_type"]
|
||||
source_config = _ensure_dict(src["config"])
|
||||
|
||||
last_run = await fetch_last_run(pool, source_id)
|
||||
|
||||
last_completed_at = None
|
||||
last_status = None
|
||||
retry_count = 0
|
||||
next_retry_at = None
|
||||
|
||||
if last_run:
|
||||
last_status = last_run["status"]
|
||||
last_completed_at = last_run["completed_at"] or last_run["started_at"]
|
||||
retry_count = last_run["retry_count"] or 0
|
||||
next_retry_at = last_run["next_retry_at"]
|
||||
|
||||
if not is_source_due(
|
||||
source_type=source_type,
|
||||
source_config=source_config,
|
||||
last_completed_at=last_completed_at,
|
||||
last_status=last_status,
|
||||
retry_count=retry_count,
|
||||
next_retry_at=next_retry_at,
|
||||
now=now,
|
||||
):
|
||||
skipped_not_due += 1
|
||||
continue
|
||||
|
||||
if not await check_rate_limit(rds, source_type, now):
|
||||
logger.warning(
|
||||
"Rate limit hit for macro_news, skipping %s",
|
||||
src["source_name"],
|
||||
)
|
||||
skipped_rate_limit += 1
|
||||
continue
|
||||
|
||||
job = build_job_payload(src, [], now)
|
||||
await rds.rpush(queue_key(QUEUE_INGESTION), json.dumps(job))
|
||||
enqueued += 1
|
||||
|
||||
logger.debug(
|
||||
"Enqueued macro_news job for %s", src["source_name"],
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Cycle complete: enqueued=%d skipped_not_due=%d skipped_rate_limit=%d total_sources=%d",
|
||||
enqueued, skipped_not_due, skipped_rate_limit, len(sources),
|
||||
enqueued, skipped_not_due, skipped_rate_limit, len(sources) + len(macro_sources),
|
||||
)
|
||||
return enqueued
|
||||
|
||||
|
||||
@@ -110,6 +110,19 @@ BUCKET_RETENTION_FIELDS: dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MacroConfig:
|
||||
"""Configuration for the macro news interpolation layer.
|
||||
|
||||
Requirements: 5.6, 10.1, 10.2, 12.9
|
||||
"""
|
||||
macro_signal_weight: float = 0.3 # relative weight of macro vs company signals
|
||||
macro_enabled: bool = True # runtime toggle state (default on)
|
||||
macro_confidence_threshold: float = 0.4 # minimum confidence for event inclusion
|
||||
macro_short_term_staleness_hours: int = 48 # hours after which short-term events get accelerated decay
|
||||
projection_confidence_threshold: float = 0.3 # minimum confidence for projections to influence recommendations
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlertingConfig:
|
||||
"""Thresholds for operational alerting rules.
|
||||
@@ -135,6 +148,26 @@ class AlertingConfig:
|
||||
check_interval_seconds: int = 120
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompetitiveConfig:
|
||||
"""Configuration for the competitive intelligence & historical pattern matching layer.
|
||||
|
||||
Requirements: 5.6, 6.1, 9.1, 9.2, 11.2, 11.3
|
||||
"""
|
||||
competitive_signal_weight: float = 0.2
|
||||
competitive_enabled: bool = True
|
||||
pattern_confidence_threshold: float = 0.3
|
||||
propagation_strength_threshold: float = 0.2
|
||||
routine_lookback_days: int = 180
|
||||
major_decision_lookback_days: int = 365
|
||||
major_decision_weight_multiplier: float = 1.3
|
||||
staleness_window_days: int = 180
|
||||
staleness_recent_days: int = 90
|
||||
staleness_decay_penalty: float = 0.5
|
||||
min_pattern_samples: int = 3
|
||||
propagation_failure_threshold: int = 5 # consecutive failures before operator alert
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
postgres: PostgresConfig = field(default_factory=PostgresConfig)
|
||||
@@ -146,6 +179,8 @@ class AppConfig:
|
||||
broker: BrokerConfig = field(default_factory=BrokerConfig)
|
||||
retention: RetentionConfig = field(default_factory=RetentionConfig)
|
||||
alerting: AlertingConfig = field(default_factory=AlertingConfig)
|
||||
macro: MacroConfig = field(default_factory=MacroConfig)
|
||||
competitive: CompetitiveConfig = field(default_factory=CompetitiveConfig)
|
||||
log_level: str = "INFO"
|
||||
json_logs: bool = True
|
||||
|
||||
@@ -222,6 +257,27 @@ def load_config() -> AppConfig:
|
||||
broker_error_window_hours=int(os.getenv("ALERT_BROKER_ERROR_WINDOW_HOURS", "1")),
|
||||
check_interval_seconds=int(os.getenv("ALERT_CHECK_INTERVAL_SECONDS", "120")),
|
||||
),
|
||||
macro=MacroConfig(
|
||||
macro_signal_weight=float(os.getenv("MACRO_SIGNAL_WEIGHT", "0.3")),
|
||||
macro_enabled=os.getenv("MACRO_ENABLED", "true").lower() == "true",
|
||||
macro_confidence_threshold=float(os.getenv("MACRO_CONFIDENCE_THRESHOLD", "0.4")),
|
||||
macro_short_term_staleness_hours=int(os.getenv("MACRO_SHORT_TERM_STALENESS_HOURS", "48")),
|
||||
projection_confidence_threshold=float(os.getenv("PROJECTION_CONFIDENCE_THRESHOLD", "0.3")),
|
||||
),
|
||||
competitive=CompetitiveConfig(
|
||||
competitive_signal_weight=float(os.getenv("COMPETITIVE_SIGNAL_WEIGHT", "0.2")),
|
||||
competitive_enabled=os.getenv("COMPETITIVE_ENABLED", "true").lower() == "true",
|
||||
pattern_confidence_threshold=float(os.getenv("COMPETITIVE_PATTERN_CONFIDENCE_THRESHOLD", "0.3")),
|
||||
propagation_strength_threshold=float(os.getenv("COMPETITIVE_PROPAGATION_STRENGTH_THRESHOLD", "0.2")),
|
||||
routine_lookback_days=int(os.getenv("COMPETITIVE_ROUTINE_LOOKBACK_DAYS", "180")),
|
||||
major_decision_lookback_days=int(os.getenv("COMPETITIVE_MAJOR_DECISION_LOOKBACK_DAYS", "365")),
|
||||
major_decision_weight_multiplier=float(os.getenv("COMPETITIVE_MAJOR_DECISION_WEIGHT_MULTIPLIER", "1.3")),
|
||||
staleness_window_days=int(os.getenv("COMPETITIVE_STALENESS_WINDOW_DAYS", "180")),
|
||||
staleness_recent_days=int(os.getenv("COMPETITIVE_STALENESS_RECENT_DAYS", "90")),
|
||||
staleness_decay_penalty=float(os.getenv("COMPETITIVE_STALENESS_DECAY_PENALTY", "0.5")),
|
||||
min_pattern_samples=int(os.getenv("COMPETITIVE_MIN_PATTERN_SAMPLES", "3")),
|
||||
propagation_failure_threshold=int(os.getenv("COMPETITIVE_PROPAGATION_FAILURE_THRESHOLD", "5")),
|
||||
),
|
||||
log_level=os.getenv("LOG_LEVEL", "INFO"),
|
||||
json_logs=os.getenv("JSON_LOGS", "true").lower() == "true",
|
||||
)
|
||||
|
||||
@@ -214,6 +214,7 @@ def _resolve_document_type(source_type: str) -> str:
|
||||
"news_api": "article",
|
||||
"filings_api": "filing",
|
||||
"web_scrape": "press_release",
|
||||
"macro_news": "macro_event",
|
||||
}
|
||||
return mapping.get(source_type, "article")
|
||||
|
||||
|
||||
@@ -64,3 +64,4 @@ QUEUE_RECOMMENDATION = "recommendation"
|
||||
QUEUE_LAKE_PUBLISH = "lake_publish"
|
||||
QUEUE_TRADE = "trade"
|
||||
QUEUE_BROKER = "broker_orders"
|
||||
QUEUE_MACRO_CLASSIFICATION = "macro_classification"
|
||||
|
||||
@@ -15,6 +15,7 @@ class DocumentType(str, Enum):
|
||||
FILING = "filing"
|
||||
TRANSCRIPT = "transcript"
|
||||
PRESS_RELEASE = "press_release"
|
||||
MACRO_EVENT = "macro_event"
|
||||
|
||||
|
||||
class SourceType(str, Enum):
|
||||
@@ -71,6 +72,37 @@ class TrendWindow(str, Enum):
|
||||
NINETY_DAY = "90d"
|
||||
|
||||
|
||||
class ImpactType(str, Enum):
|
||||
SUPPLY_DISRUPTION = "supply_disruption"
|
||||
DEMAND_SHIFT = "demand_shift"
|
||||
COST_INCREASE = "cost_increase"
|
||||
REGULATORY_PRESSURE = "regulatory_pressure"
|
||||
CURRENCY_IMPACT = "currency_impact"
|
||||
COMMODITY_SHOCK = "commodity_shock"
|
||||
TRADE_BARRIER = "trade_barrier"
|
||||
GEOPOLITICAL_RISK = "geopolitical_risk"
|
||||
|
||||
|
||||
class SeverityLevel(str, Enum):
|
||||
LOW = "low"
|
||||
MODERATE = "moderate"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class MarketPositionTier(str, Enum):
|
||||
GLOBAL_LEADER = "global_leader"
|
||||
MULTINATIONAL = "multinational"
|
||||
REGIONAL = "regional"
|
||||
DOMESTIC = "domestic"
|
||||
|
||||
|
||||
class EstimatedDuration(str, Enum):
|
||||
SHORT_TERM = "short_term"
|
||||
MEDIUM_TERM = "medium_term"
|
||||
LONG_TERM = "long_term"
|
||||
|
||||
|
||||
# --- Document Intelligence ---
|
||||
|
||||
class CompanyImpact(BaseModel):
|
||||
@@ -182,6 +214,63 @@ class Recommendation(BaseModel):
|
||||
generated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# --- Global News Interpolation ---
|
||||
|
||||
class GlobalEventSchema(BaseModel):
|
||||
event_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
event_types: List[ImpactType] = Field(default_factory=list)
|
||||
severity: SeverityLevel = SeverityLevel.LOW
|
||||
affected_regions: List[str] = Field(default_factory=list)
|
||||
affected_sectors: List[str] = Field(default_factory=list)
|
||||
affected_commodities: List[str] = Field(default_factory=list)
|
||||
summary: str = ""
|
||||
key_facts: List[str] = Field(default_factory=list)
|
||||
estimated_duration: EstimatedDuration = EstimatedDuration.SHORT_TERM
|
||||
confidence: float = Field(ge=0, le=1, default=0.5)
|
||||
source_document_id: str = ""
|
||||
model_metadata: ModelMetadata = Field(default_factory=ModelMetadata)
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class MacroImpactRecordSchema(BaseModel):
|
||||
event_id: str = ""
|
||||
company_id: str = ""
|
||||
ticker: str = ""
|
||||
macro_impact_score: float = Field(ge=0, le=1, default=0.0)
|
||||
impact_direction: str = "neutral"
|
||||
contributing_factors: List[str] = Field(default_factory=list)
|
||||
confidence: float = Field(ge=0, le=1, default=0.5)
|
||||
computed_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class ExposureProfileSchema(BaseModel):
|
||||
company_id: str = ""
|
||||
geographic_revenue_mix: dict[str, float] = Field(default_factory=dict)
|
||||
supply_chain_regions: List[str] = Field(default_factory=list)
|
||||
key_input_commodities: List[str] = Field(default_factory=list)
|
||||
regulatory_jurisdictions: List[str] = Field(default_factory=list)
|
||||
market_position_tier: MarketPositionTier = MarketPositionTier.REGIONAL
|
||||
export_dependency_pct: float = Field(ge=0, le=1, default=0.0)
|
||||
source: str = "manual"
|
||||
confidence: float = Field(ge=0, le=1, default=1.0)
|
||||
version: int = 1
|
||||
active: bool = True
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class TrendProjectionSchema(BaseModel):
|
||||
trend_window_id: str = ""
|
||||
projected_direction: TrendDirection = TrendDirection.NEUTRAL
|
||||
projected_strength: float = Field(ge=0, le=1, default=0.5)
|
||||
projected_confidence: float = Field(ge=0, le=1, default=0.5)
|
||||
projection_horizon: str = "7d"
|
||||
driving_factors: List[str] = Field(default_factory=list)
|
||||
macro_contribution_pct: float = Field(ge=0, le=1, default=0.0)
|
||||
diverges_from_current: bool = False
|
||||
computed_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# --- Document Metadata ---
|
||||
|
||||
class StorageRefs(BaseModel):
|
||||
@@ -204,3 +293,73 @@ class DocumentMetadata(BaseModel):
|
||||
language: str = "en"
|
||||
content_hash: str = ""
|
||||
storage_refs: StorageRefs = Field(default_factory=StorageRefs)
|
||||
|
||||
|
||||
# --- Competitive Intelligence & Historical Patterns ---
|
||||
|
||||
|
||||
class RelationshipType(str, Enum):
|
||||
DIRECT_RIVAL = "direct_rival"
|
||||
SAME_SECTOR = "same_sector"
|
||||
OVERLAPPING_PRODUCTS = "overlapping_products"
|
||||
SUPPLY_CHAIN_ADJACENT = "supply_chain_adjacent"
|
||||
|
||||
|
||||
class CatalystTier(str, Enum):
|
||||
MAJOR_CORPORATE_DECISION = "major_corporate_decision"
|
||||
ROUTINE_SIGNAL = "routine_signal"
|
||||
|
||||
|
||||
# Major corporate decision catalyst types (Req 11.1)
|
||||
MAJOR_DECISION_CATALYSTS: frozenset[str] = frozenset({
|
||||
"m_and_a",
|
||||
"legal",
|
||||
"restructuring",
|
||||
"leadership_change",
|
||||
"strategic_pivot",
|
||||
"buyback",
|
||||
"dividend_change",
|
||||
})
|
||||
|
||||
|
||||
class CompetitorRelationshipSchema(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
company_a_id: str = ""
|
||||
company_b_id: str = ""
|
||||
relationship_type: RelationshipType = RelationshipType.DIRECT_RIVAL
|
||||
strength: float = Field(ge=0, le=1, default=0.5)
|
||||
bidirectional: bool = True
|
||||
source: str = "manual"
|
||||
active: bool = True
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class CompetitiveSignalRecordSchema(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
source_document_id: str = ""
|
||||
source_ticker: str = ""
|
||||
target_ticker: str = ""
|
||||
catalyst_type: str = ""
|
||||
pattern_confidence: float = Field(ge=0, le=1, default=0.0)
|
||||
signal_direction: str = "neutral"
|
||||
signal_strength: float = Field(ge=0, le=1, default=0.0)
|
||||
relationship_strength: float = Field(ge=0, le=1, default=0.0)
|
||||
computed_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class HistoricalPatternSchema(BaseModel):
|
||||
source_ticker: str = ""
|
||||
target_ticker: str = ""
|
||||
catalyst_type: str = ""
|
||||
time_horizon: str = "7d"
|
||||
sample_count: int = 0
|
||||
bullish_pct: float = Field(ge=0, le=1, default=0.0)
|
||||
bearish_pct: float = Field(ge=0, le=1, default=0.0)
|
||||
avg_strength: float = Field(ge=0, le=1, default=0.0)
|
||||
avg_time_to_resolution: float = 0.0
|
||||
pattern_confidence: float = Field(ge=0, le=1, default=0.0)
|
||||
data_start: Optional[datetime] = None
|
||||
data_end: Optional[datetime] = None
|
||||
tier: CatalystTier = CatalystTier.ROUTINE_SIGNAL
|
||||
insufficient_data: bool = False
|
||||
|
||||
@@ -48,6 +48,7 @@ SOURCE_BUCKET_MAP: dict[str, str] = {
|
||||
"filings_api": "stonks-raw-filings",
|
||||
"web_scrape": "stonks-raw-news",
|
||||
"broker": "stonks-raw-market",
|
||||
"macro_news": "stonks-raw-news",
|
||||
}
|
||||
|
||||
# Map artifact type to content type and file extension
|
||||
@@ -75,10 +76,14 @@ def build_artifact_path(
|
||||
"""Build a MinIO object path following the design convention.
|
||||
|
||||
Pattern: {source_type}/{ticker}/{yyyy}/{mm}/{dd}/{document_id}/{artifact_name}.{ext}
|
||||
For macro_news sources, uses macro/ prefix instead of ticker:
|
||||
macro/{yyyy}/{mm}/{dd}/{document_id}/{artifact_name}.{ext}
|
||||
"""
|
||||
ts = timestamp or datetime.now(timezone.utc)
|
||||
# Macro sources use macro/ prefix instead of ticker (Requirement 1.1)
|
||||
path_prefix = "macro" if source_type == "macro_news" else f"{source_type}/{ticker}"
|
||||
return (
|
||||
f"{source_type}/{ticker}/"
|
||||
f"{path_prefix}/"
|
||||
f"{ts.year}/{ts.month:02d}/{ts.day:02d}/"
|
||||
f"{document_id}/{artifact_name}.{ext}"
|
||||
)
|
||||
|
||||
@@ -12,6 +12,9 @@ from pydantic import BaseModel, field_validator
|
||||
from services.shared.config import load_config
|
||||
from services.shared.db import get_pg_pool
|
||||
from services.shared.logging import setup_logging
|
||||
from services.symbol_registry.exposure import router as exposure_router
|
||||
from services.symbol_registry.competitors import router as competitors_router
|
||||
from services.symbol_registry.competitor_inference import router as inference_router
|
||||
|
||||
config = load_config()
|
||||
pool: Optional[asyncpg.Pool] = None
|
||||
@@ -36,6 +39,9 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
|
||||
app = FastAPI(title="Stonks Oracle - Symbol Registry", lifespan=lifespan)
|
||||
app.include_router(exposure_router)
|
||||
app.include_router(competitors_router)
|
||||
app.include_router(inference_router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
"""Competitor auto-inference engine for the Symbol Registry API.
|
||||
|
||||
Identifies candidate competitors by sector/industry match and
|
||||
document co-mention frequency, then upserts inferred relationships.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
|
||||
import asyncpg
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# --- Response Model ---
|
||||
|
||||
class CompetitorRelationship(BaseModel):
|
||||
"""Response model for a competitor relationship."""
|
||||
id: str
|
||||
company_a_id: str
|
||||
company_b_id: str
|
||||
relationship_type: str
|
||||
strength: float
|
||||
bidirectional: bool
|
||||
source: str
|
||||
active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
|
||||
"""Convert asyncpg Record to dict with UUID→str coercion."""
|
||||
d = dict(row)
|
||||
for k, v in d.items():
|
||||
if isinstance(v, uuid.UUID):
|
||||
d[k] = str(v)
|
||||
return d
|
||||
|
||||
|
||||
def _get_pool(request: Request) -> asyncpg.Pool:
|
||||
"""Get the database pool from the app module."""
|
||||
from services.symbol_registry.app import pool
|
||||
return pool
|
||||
|
||||
|
||||
async def infer_competitors(
|
||||
pool: asyncpg.Pool, company_id: str
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Infer competitor relationships based on sector/industry match and co-mentions.
|
||||
|
||||
1. Fetch target company's sector and industry.
|
||||
2. Find other active companies with the same sector AND industry.
|
||||
3. Count co-mentions in document_company_mentions for each candidate.
|
||||
4. Compute strength = 0.3 * sector_match + 0.7 * normalized_co_mention_count.
|
||||
5. Upsert relationships with source='inferred'.
|
||||
|
||||
Returns the list of upserted relationship rows.
|
||||
"""
|
||||
# Fetch target company
|
||||
target = await pool.fetchrow(
|
||||
"SELECT id, sector, industry FROM companies WHERE id = $1 AND active = TRUE",
|
||||
company_id,
|
||||
)
|
||||
if not target:
|
||||
raise HTTPException(404, "Company not found")
|
||||
|
||||
if target["sector"] is None or target["industry"] is None:
|
||||
raise HTTPException(
|
||||
400,
|
||||
"Company must have both sector and industry defined for auto-inference",
|
||||
)
|
||||
|
||||
sector = target["sector"]
|
||||
industry = target["industry"]
|
||||
|
||||
# Find candidates: other active companies with same sector AND industry
|
||||
candidates = await pool.fetch(
|
||||
"""SELECT id FROM companies
|
||||
WHERE sector = $1 AND industry = $2 AND active = TRUE AND id != $3""",
|
||||
sector, industry, company_id,
|
||||
)
|
||||
|
||||
if not candidates:
|
||||
return []
|
||||
|
||||
candidate_ids = [r["id"] for r in candidates]
|
||||
|
||||
# Count co-mentions for each candidate
|
||||
co_mention_rows = await pool.fetch(
|
||||
"""SELECT dcm2.company_id AS candidate_id, COUNT(DISTINCT dcm1.document_id) AS co_count
|
||||
FROM document_company_mentions dcm1
|
||||
JOIN document_company_mentions dcm2
|
||||
ON dcm1.document_id = dcm2.document_id
|
||||
WHERE dcm1.company_id = $1
|
||||
AND dcm2.company_id = ANY($2::uuid[])
|
||||
GROUP BY dcm2.company_id""",
|
||||
company_id, candidate_ids,
|
||||
)
|
||||
|
||||
co_mention_map: dict[Any, int] = {}
|
||||
for row in co_mention_rows:
|
||||
co_mention_map[row["candidate_id"]] = row["co_count"]
|
||||
|
||||
# Normalize co-mention counts
|
||||
max_count = max(co_mention_map.values()) if co_mention_map else 1
|
||||
if max_count == 0:
|
||||
max_count = 1
|
||||
|
||||
# Compute strength and upsert for each candidate
|
||||
results: list[dict[str, Any]] = []
|
||||
for cid in candidate_ids:
|
||||
co_count = co_mention_map.get(cid, 0)
|
||||
normalized = co_count / max_count
|
||||
# sector_match is always 1.0 since we filter by sector+industry
|
||||
strength = 0.3 * 1.0 + 0.7 * normalized
|
||||
|
||||
# Order IDs for the unique index: LEAST/GREATEST
|
||||
a_id = min(company_id, str(cid), key=lambda x: x)
|
||||
b_id = max(company_id, str(cid), key=lambda x: x)
|
||||
|
||||
row = await pool.fetchrow(
|
||||
"""INSERT INTO competitor_relationships
|
||||
(company_a_id, company_b_id, relationship_type, strength,
|
||||
bidirectional, source)
|
||||
VALUES ($1, $2, 'same_sector', $3, TRUE, 'inferred')
|
||||
ON CONFLICT (LEAST(company_a_id, company_b_id), GREATEST(company_a_id, company_b_id))
|
||||
WHERE active = TRUE
|
||||
DO UPDATE SET strength = EXCLUDED.strength, updated_at = NOW()
|
||||
RETURNING id, company_a_id, company_b_id, relationship_type, strength,
|
||||
bidirectional, source, active, created_at, updated_at""",
|
||||
a_id, b_id, strength,
|
||||
)
|
||||
results.append(_row_dict(row))
|
||||
|
||||
# Sort by strength descending before returning
|
||||
results.sort(key=lambda r: r["strength"], reverse=True)
|
||||
return results
|
||||
|
||||
|
||||
@router.post(
|
||||
"/companies/{company_id}/competitors/infer",
|
||||
response_model=List[CompetitorRelationship],
|
||||
)
|
||||
async def infer_competitors_endpoint(company_id: str, request: Request):
|
||||
"""Trigger auto-inference of competitor relationships for a company."""
|
||||
pool = _get_pool(request)
|
||||
return await infer_competitors(pool, company_id)
|
||||
@@ -0,0 +1,226 @@
|
||||
"""Competitor Relationship management endpoints for the Symbol Registry API."""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
|
||||
import asyncpg
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from services.shared.audit import record_audit_event
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# --- Valid values ---
|
||||
VALID_RELATIONSHIP_TYPES = {"direct_rival", "same_sector", "overlapping_products", "supply_chain_adjacent"}
|
||||
VALID_SOURCES = {"manual", "inferred"}
|
||||
|
||||
|
||||
# --- Request/Response Models ---
|
||||
|
||||
class CompetitorRelationshipCreate(BaseModel):
|
||||
"""Request body for creating a competitor relationship."""
|
||||
company_b_id: str
|
||||
relationship_type: str
|
||||
strength: float = Field(default=0.5, ge=0, le=1)
|
||||
bidirectional: bool = True
|
||||
source: str = "manual"
|
||||
|
||||
@field_validator("relationship_type")
|
||||
@classmethod
|
||||
def validate_relationship_type(cls, v: str) -> str:
|
||||
if v not in VALID_RELATIONSHIP_TYPES:
|
||||
raise ValueError(f"relationship_type must be one of {VALID_RELATIONSHIP_TYPES}")
|
||||
return v
|
||||
|
||||
@field_validator("source")
|
||||
@classmethod
|
||||
def validate_source(cls, v: str) -> str:
|
||||
if v not in VALID_SOURCES:
|
||||
raise ValueError(f"source must be one of {VALID_SOURCES}")
|
||||
return v
|
||||
|
||||
|
||||
class CompetitorRelationship(BaseModel):
|
||||
"""Response model for a competitor relationship."""
|
||||
id: str
|
||||
company_a_id: str
|
||||
company_b_id: str
|
||||
relationship_type: str
|
||||
strength: float
|
||||
bidirectional: bool
|
||||
source: str
|
||||
active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
|
||||
"""Convert asyncpg Record to dict with UUID→str coercion."""
|
||||
d = dict(row)
|
||||
for k, v in d.items():
|
||||
if isinstance(v, uuid.UUID):
|
||||
d[k] = str(v)
|
||||
return d
|
||||
|
||||
|
||||
def _get_pool(request: Request) -> asyncpg.Pool:
|
||||
"""Get the database pool from the app module."""
|
||||
from services.symbol_registry.app import pool
|
||||
return pool
|
||||
|
||||
|
||||
async def _company_exists(pool: asyncpg.Pool, company_id: str) -> bool:
|
||||
"""Check if a company exists."""
|
||||
return await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id) is not None
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
@router.post("/companies/{company_id}/competitors", response_model=CompetitorRelationship, status_code=201)
|
||||
async def create_competitor(company_id: str, body: CompetitorRelationshipCreate, request: Request):
|
||||
"""Create a competitor relationship for a company."""
|
||||
pool = _get_pool(request)
|
||||
|
||||
# Self-referencing check
|
||||
if company_id == body.company_b_id:
|
||||
raise HTTPException(400, "A company cannot be its own competitor")
|
||||
|
||||
# Check both companies exist
|
||||
if not await _company_exists(pool, company_id):
|
||||
raise HTTPException(404, "Company not found")
|
||||
if not await _company_exists(pool, body.company_b_id):
|
||||
raise HTTPException(404, "Competitor company not found")
|
||||
|
||||
try:
|
||||
row = await pool.fetchrow(
|
||||
"""INSERT INTO competitor_relationships
|
||||
(company_a_id, company_b_id, relationship_type, strength, bidirectional, source)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id, company_a_id, company_b_id, relationship_type, strength,
|
||||
bidirectional, source, active, created_at, updated_at""",
|
||||
company_id, body.company_b_id, body.relationship_type,
|
||||
body.strength, body.bidirectional, body.source,
|
||||
)
|
||||
except asyncpg.UniqueViolationError:
|
||||
raise HTTPException(409, "An active competitor relationship already exists between these companies")
|
||||
|
||||
result = _row_dict(row)
|
||||
|
||||
await record_audit_event(
|
||||
pool,
|
||||
event_type="competitor_relationship.created",
|
||||
entity_type="competitor_relationship",
|
||||
entity_id=result["id"],
|
||||
data={
|
||||
"company_a_id": company_id,
|
||||
"company_b_id": body.company_b_id,
|
||||
"relationship_type": body.relationship_type,
|
||||
"strength": body.strength,
|
||||
"bidirectional": body.bidirectional,
|
||||
"source": body.source,
|
||||
},
|
||||
actor="operator",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/companies/{company_id}/competitors", response_model=List[CompetitorRelationship])
|
||||
async def list_competitors(company_id: str, request: Request):
|
||||
"""List active competitor relationships for a company, ordered by strength descending."""
|
||||
pool = _get_pool(request)
|
||||
|
||||
if not await _company_exists(pool, company_id):
|
||||
raise HTTPException(404, "Company not found")
|
||||
|
||||
rows = await pool.fetch(
|
||||
"""SELECT id, company_a_id, company_b_id, relationship_type, strength,
|
||||
bidirectional, source, active, created_at, updated_at
|
||||
FROM competitor_relationships
|
||||
WHERE (company_a_id = $1 OR company_b_id = $1) AND active = TRUE
|
||||
ORDER BY strength DESC""",
|
||||
company_id,
|
||||
)
|
||||
return [_row_dict(r) for r in rows]
|
||||
|
||||
|
||||
@router.put("/companies/{company_id}/competitors/{relationship_id}", response_model=CompetitorRelationship)
|
||||
async def update_competitor(company_id: str, relationship_id: str, body: CompetitorRelationshipCreate, request: Request):
|
||||
"""Update a competitor relationship with audit event recording previous state."""
|
||||
pool = _get_pool(request)
|
||||
|
||||
# Fetch existing relationship
|
||||
existing = await pool.fetchrow(
|
||||
"""SELECT id, company_a_id, company_b_id, relationship_type, strength,
|
||||
bidirectional, source, active, created_at, updated_at
|
||||
FROM competitor_relationships
|
||||
WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2)""",
|
||||
relationship_id, company_id,
|
||||
)
|
||||
if not existing:
|
||||
raise HTTPException(404, "Competitor relationship not found")
|
||||
|
||||
previous_state = _row_dict(existing)
|
||||
|
||||
row = await pool.fetchrow(
|
||||
"""UPDATE competitor_relationships
|
||||
SET relationship_type = $2, strength = $3, bidirectional = $4, source = $5, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id, company_a_id, company_b_id, relationship_type, strength,
|
||||
bidirectional, source, active, created_at, updated_at""",
|
||||
relationship_id, body.relationship_type, body.strength, body.bidirectional, body.source,
|
||||
)
|
||||
|
||||
result = _row_dict(row)
|
||||
|
||||
await record_audit_event(
|
||||
pool,
|
||||
event_type="competitor_relationship.updated",
|
||||
entity_type="competitor_relationship",
|
||||
entity_id=result["id"],
|
||||
data={
|
||||
"previous_state": {
|
||||
"relationship_type": previous_state["relationship_type"],
|
||||
"strength": previous_state["strength"],
|
||||
"bidirectional": previous_state["bidirectional"],
|
||||
"source": previous_state["source"],
|
||||
},
|
||||
"new_state": {
|
||||
"relationship_type": body.relationship_type,
|
||||
"strength": body.strength,
|
||||
"bidirectional": body.bidirectional,
|
||||
"source": body.source,
|
||||
},
|
||||
},
|
||||
actor="operator",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/companies/{company_id}/competitors/{relationship_id}", status_code=200)
|
||||
async def delete_competitor(company_id: str, relationship_id: str, request: Request):
|
||||
"""Soft-delete a competitor relationship (set active=False), preserve row."""
|
||||
pool = _get_pool(request)
|
||||
|
||||
row = await pool.fetchrow(
|
||||
"""UPDATE competitor_relationships
|
||||
SET active = FALSE, updated_at = NOW()
|
||||
WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2) AND active = TRUE
|
||||
RETURNING id""",
|
||||
relationship_id, company_id,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(404, "Active competitor relationship not found")
|
||||
|
||||
await record_audit_event(
|
||||
pool,
|
||||
event_type="competitor_relationship.deleted",
|
||||
entity_type="competitor_relationship",
|
||||
entity_id=str(row["id"]),
|
||||
data={"company_id": company_id, "soft_deleted": True},
|
||||
actor="operator",
|
||||
)
|
||||
|
||||
return {"status": "deleted", "id": str(row["id"])}
|
||||
@@ -0,0 +1,183 @@
|
||||
"""Exposure Profile management endpoints for the Symbol Registry API."""
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
|
||||
import asyncpg
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# --- Valid values ---
|
||||
VALID_MARKET_POSITION_TIERS = {"global_leader", "multinational", "regional", "domestic"}
|
||||
VALID_SOURCES = {"manual", "inferred"}
|
||||
|
||||
|
||||
# --- Request/Response Models ---
|
||||
|
||||
class ExposureProfileCreate(BaseModel):
|
||||
"""Request body for creating/updating an exposure profile."""
|
||||
geographic_revenue_mix: dict[str, float] = Field(default_factory=dict)
|
||||
supply_chain_regions: List[str] = Field(default_factory=list)
|
||||
key_input_commodities: List[str] = Field(default_factory=list)
|
||||
regulatory_jurisdictions: List[str] = Field(default_factory=list)
|
||||
market_position_tier: str = "regional"
|
||||
export_dependency_pct: float = 0.0
|
||||
source: str = "manual"
|
||||
confidence: float = 1.0
|
||||
|
||||
@field_validator("market_position_tier")
|
||||
@classmethod
|
||||
def validate_tier(cls, v: str) -> str:
|
||||
if v not in VALID_MARKET_POSITION_TIERS:
|
||||
raise ValueError(f"market_position_tier must be one of {VALID_MARKET_POSITION_TIERS}")
|
||||
return v
|
||||
|
||||
@field_validator("source")
|
||||
@classmethod
|
||||
def validate_source(cls, v: str) -> str:
|
||||
if v not in VALID_SOURCES:
|
||||
raise ValueError(f"source must be one of {VALID_SOURCES}")
|
||||
return v
|
||||
|
||||
@field_validator("export_dependency_pct", "confidence")
|
||||
@classmethod
|
||||
def validate_pct(cls, v: float) -> float:
|
||||
if not 0.0 <= v <= 1.0:
|
||||
raise ValueError("Value must be between 0.0 and 1.0")
|
||||
return v
|
||||
|
||||
|
||||
class ExposureProfileResponse(BaseModel):
|
||||
"""Response model for an exposure profile."""
|
||||
id: str
|
||||
company_id: str
|
||||
geographic_revenue_mix: dict[str, float]
|
||||
supply_chain_regions: List[str]
|
||||
key_input_commodities: List[str]
|
||||
regulatory_jurisdictions: List[str]
|
||||
market_position_tier: str
|
||||
export_dependency_pct: float
|
||||
source: str
|
||||
confidence: float
|
||||
version: int
|
||||
active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
def _row_to_profile(row: asyncpg.Record) -> dict[str, Any]:
|
||||
"""Convert an asyncpg Record to a profile response dict."""
|
||||
d = dict(row)
|
||||
for k, v in d.items():
|
||||
if isinstance(v, uuid.UUID):
|
||||
d[k] = str(v)
|
||||
# geographic_revenue_mix is stored as JSONB string, parse if needed
|
||||
if isinstance(d.get("geographic_revenue_mix"), str):
|
||||
d["geographic_revenue_mix"] = json.loads(d["geographic_revenue_mix"])
|
||||
return d
|
||||
|
||||
|
||||
def _get_pool(request: Request) -> asyncpg.Pool:
|
||||
"""Get the database pool from the app module."""
|
||||
from services.symbol_registry.app import pool
|
||||
return pool
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
@router.get("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
|
||||
async def get_exposure_profile(company_id: str, request: Request):
|
||||
"""Get the current active exposure profile for a company."""
|
||||
pool = _get_pool(request)
|
||||
row = await pool.fetchrow(
|
||||
"""SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
|
||||
key_input_commodities, regulatory_jurisdictions, market_position_tier,
|
||||
export_dependency_pct, source, confidence, version, active,
|
||||
created_at, updated_at
|
||||
FROM exposure_profiles
|
||||
WHERE company_id = $1 AND active = TRUE
|
||||
ORDER BY version DESC
|
||||
LIMIT 1""",
|
||||
company_id,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(404, "No active exposure profile found for this company")
|
||||
return _row_to_profile(row)
|
||||
|
||||
|
||||
@router.put("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
|
||||
async def upsert_exposure_profile(company_id: str, body: ExposureProfileCreate, request: Request):
|
||||
"""Create or update an exposure profile. Archives the previous active version."""
|
||||
pool = _get_pool(request)
|
||||
|
||||
# Verify company exists
|
||||
exists = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id)
|
||||
if not exists:
|
||||
raise HTTPException(404, "Company not found")
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
# Fetch current active profile to get latest version
|
||||
current = await conn.fetchrow(
|
||||
"""SELECT version FROM exposure_profiles
|
||||
WHERE company_id = $1 AND active = TRUE
|
||||
ORDER BY version DESC LIMIT 1""",
|
||||
company_id,
|
||||
)
|
||||
|
||||
if current:
|
||||
new_version = current["version"] + 1
|
||||
# Archive the current active profile
|
||||
await conn.execute(
|
||||
"""UPDATE exposure_profiles
|
||||
SET active = FALSE, updated_at = NOW()
|
||||
WHERE company_id = $1 AND active = TRUE""",
|
||||
company_id,
|
||||
)
|
||||
else:
|
||||
new_version = 1
|
||||
|
||||
# Insert new profile
|
||||
row = await conn.fetchrow(
|
||||
"""INSERT INTO exposure_profiles
|
||||
(company_id, geographic_revenue_mix, supply_chain_regions,
|
||||
key_input_commodities, regulatory_jurisdictions, market_position_tier,
|
||||
export_dependency_pct, source, confidence, version, active)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, TRUE)
|
||||
RETURNING id, company_id, geographic_revenue_mix, supply_chain_regions,
|
||||
key_input_commodities, regulatory_jurisdictions, market_position_tier,
|
||||
export_dependency_pct, source, confidence, version, active,
|
||||
created_at, updated_at""",
|
||||
company_id,
|
||||
json.dumps(body.geographic_revenue_mix),
|
||||
body.supply_chain_regions,
|
||||
body.key_input_commodities,
|
||||
body.regulatory_jurisdictions,
|
||||
body.market_position_tier,
|
||||
body.export_dependency_pct,
|
||||
body.source,
|
||||
body.confidence,
|
||||
new_version,
|
||||
)
|
||||
|
||||
return _row_to_profile(row)
|
||||
|
||||
|
||||
@router.get("/companies/{company_id}/exposure/history", response_model=List[ExposureProfileResponse])
|
||||
async def get_exposure_history(company_id: str, request: Request):
|
||||
"""Get all exposure profile versions for a company, ordered by version descending."""
|
||||
pool = _get_pool(request)
|
||||
rows = await pool.fetch(
|
||||
"""SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
|
||||
key_input_commodities, regulatory_jurisdictions, market_position_tier,
|
||||
export_dependency_pct, source, confidence, version, active,
|
||||
created_at, updated_at
|
||||
FROM exposure_profiles
|
||||
WHERE company_id = $1
|
||||
ORDER BY version DESC""",
|
||||
company_id,
|
||||
)
|
||||
return [_row_to_profile(r) for r in rows]
|
||||
Reference in New Issue
Block a user