"""Sector and market-level rollup aggregation.

Aggregates company-level trend summaries into sector and market-level
summaries, enabling top-down views of sentiment and risk across the
portfolio.

Requirements: 6.1, 6.2, 6.3, 6.4, 6.5
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone

import asyncpg

from services.shared.schemas import (
    DisagreementDetail,
    TrendDirection,
    TrendSummary,
    TrendWindow,
)

logger = logging.getLogger(__name__)


@dataclass
class CompanyTrendRow:
    """A company-level trend summary fetched from the DB for rollup."""

    entity_id: str  # ticker
    sector: str
    window: str
    trend_direction: str
    trend_strength: float
    confidence: float
    contradiction_score: float
    dominant_catalysts: list[str]
    material_risks: list[str]
    top_supporting_evidence: list[str]
    top_opposing_evidence: list[str]


@dataclass
class SectorMacroImpact:
    """Aggregated macro impact data for a single sector.

    Used to incorporate macro signals into sector and market rollups.

    Requirements: 6.1, 6.2, 6.3
    """

    sector: str
    # Sum of macro_impact_score over every impact record in the sector.
    total_impact: float
    # Mean macro_impact_score per impact record.
    avg_impact: float
    # Number of distinct companies with at least one impact record.
    company_count: int
    # Mean of +1 (positive) / -1 (negative) over directional records;
    # 0.0 when no record was positive or negative.
    net_direction: float
    # Contributing event IDs, sorted.
    event_ids: list[str] = field(default_factory=list)


# A sector holding more than this fraction of total macro impact is treated
# as disproportionately affected (Requirement 6.3).
SECTOR_CONCENTRATION_THRESHOLD = 0.60


# ---------------------------------------------------------------------------
# Fetch sector-level macro impact aggregates
# ---------------------------------------------------------------------------

_SECTOR_MACRO_IMPACT_QUERY = """
SELECT
    c.sector,
    mir.company_id,
    mir.event_id,
    mir.macro_impact_score,
    mir.impact_direction
FROM macro_impact_records mir
JOIN companies c ON c.id = mir.company_id AND c.active = TRUE
WHERE mir.computed_at >= $1
  AND mir.computed_at <= $2
ORDER BY c.sector, mir.macro_impact_score DESC
"""


async def fetch_sector_macro_impacts(
    pool: asyncpg.Pool,
    window_start: datetime,
    window_end: datetime,
) -> dict[str, SectorMacroImpact]:
    """Fetch macro impact records in a time range and aggregate them by sector.

    Only records of active companies with ``computed_at`` inside
    ``[window_start, window_end]`` (both inclusive) are considered. Companies
    whose sector is NULL/empty are grouped under ``"Unknown"``.

    Args:
        pool: asyncpg connection pool.
        window_start: Inclusive lower bound on ``computed_at``.
        window_end: Inclusive upper bound on ``computed_at``.

    Returns:
        Mapping of sector name to its ``SectorMacroImpact``. Sectors without
        any impact record are absent from the mapping.
    """
    rows = await pool.fetch(_SECTOR_MACRO_IMPACT_QUERY, window_start, window_end)

    # Mixed/neutral (and unrecognised) directions contribute nothing and are
    # excluded from the net-direction mean.
    direction_map = {"positive": 1.0, "negative": -1.0, "mixed": 0.0, "neutral": 0.0}

    sector_data: dict[str, dict] = {}
    for row in rows:
        sector = str(row["sector"]) if row["sector"] else "Unknown"
        score = float(row["macro_impact_score"] or 0.0)
        direction = row["impact_direction"] or "neutral"

        if sector not in sector_data:
            sector_data[sector] = {
                "total": 0.0,
                "record_count": 0,
                "dir_sum": 0.0,
                "dir_count": 0,
                "event_ids": set(),
                "company_ids": set(),
            }
        d = sector_data[sector]
        d["total"] += score
        d["record_count"] += 1
        # A company can have several impact records (e.g. one per event), so
        # companies are counted by distinct id rather than by row.
        d["company_ids"].add(str(row["company_id"]))
        d["event_ids"].add(str(row["event_id"]))
        dir_val = direction_map.get(direction, 0.0)
        if dir_val != 0.0:
            d["dir_sum"] += dir_val
            d["dir_count"] += 1

    result: dict[str, SectorMacroImpact] = {}
    for sector, d in sector_data.items():
        record_count = d["record_count"]
        result[sector] = SectorMacroImpact(
            sector=sector,
            total_impact=d["total"],
            avg_impact=d["total"] / record_count if record_count > 0 else 0.0,
            company_count=len(d["company_ids"]),
            net_direction=d["dir_sum"] / d["dir_count"] if d["dir_count"] > 0 else 0.0,
            event_ids=sorted(d["event_ids"]),
        )
    return result


# ---------------------------------------------------------------------------
# Sector macro concentration helper (Requirement 6.3)
# ---------------------------------------------------------------------------


def compute_sector_macro_concentration(
    sector_impacts: dict[str, SectorMacroImpact],
) -> list[tuple[str, float]]:
    """Compute the fraction of total macro impact concentrated in each sector.

    Args:
        sector_impacts: Per-sector aggregates, e.g. from
            ``fetch_sector_macro_impacts``.

    Returns:
        ``(sector, fraction)`` tuples sorted by fraction, descending. Returns
        an empty list when the total impact is not positive. Sectors with a
        fraction above ``SECTOR_CONCENTRATION_THRESHOLD`` are considered
        disproportionately affected.
    """
    total = sum(si.total_impact for si in sector_impacts.values())
    if total <= 0.0:
        return []
    fractions = [
        (sector, si.total_impact / total) for sector, si in sector_impacts.items()
    ]
    fractions.sort(key=lambda x: x[1], reverse=True)
    return fractions


# ---------------------------------------------------------------------------
# Fetch latest company trends for a given window
# ---------------------------------------------------------------------------

# DISTINCT ON + ORDER BY generated_at DESC keeps only the newest row per ticker.
# "window" is quoted because WINDOW is a reserved keyword in PostgreSQL.
_LATEST_COMPANY_TRENDS_QUERY = """
SELECT DISTINCT ON (tw.entity_id)
    tw.entity_id,
    c.sector,
    tw."window",
    tw.trend_direction,
    tw.trend_strength,
    tw.confidence,
    tw.contradiction_score,
    tw.dominant_catalysts,
    tw.material_risks,
    tw.top_supporting_evidence,
    tw.top_opposing_evidence
FROM trend_windows tw
JOIN companies c ON c.ticker = tw.entity_id AND c.active = TRUE
WHERE tw.entity_type = 'company'
  AND tw."window" = $1
  AND tw.generated_at >= $2
ORDER BY tw.entity_id, tw.generated_at DESC
"""


def _parse_jsonb_list(val: object) -> list[str]:
    """Parse a JSONB column value that should hold a list of strings.

    Accepts an already-decoded ``list`` or a JSON-encoded ``str``. asyncpg
    returns JSONB as text unless a type codec is registered. Every element
    is coerced with ``str``.

    Returns ``[]`` for None, non-list JSON, malformed JSON, or any other
    type. Malformed JSON is logged rather than raised, so one bad row does
    not abort a whole rollup.
    """
    if isinstance(val, list):
        return [str(v) for v in val]
    if isinstance(val, str):
        try:
            parsed = json.loads(val)
        except json.JSONDecodeError:
            logger.warning("Ignoring malformed JSONB list value: %.100r", val)
            return []
        if isinstance(parsed, list):
            return [str(v) for v in parsed]
    return []


def _parse_company_trend_row(row: object) -> CompanyTrendRow:
    """Convert an asyncpg Record (or any mapping-like row) to a CompanyTrendRow.

    Missing/NULL text columns become ``""`` (sector becomes ``"Unknown"``),
    and missing/NULL numeric columns become ``0.0``.

    Raises:
        TypeError: if ``row`` does not support ``__getitem__``.
    """
    # asyncpg Records support item access but aren't typed; fetch the bound
    # __getitem__ once and use it as an accessor.
    get = getattr(row, "__getitem__", None)
    if get is None:
        raise TypeError(f"Expected a mapping-like row, got {type(row)}")

    def _str(key: str, default: str = "") -> str:
        val = get(key)
        return str(val) if val is not None else default

    def _float(key: str) -> float:
        val = get(key)
        return float(val) if val is not None else 0.0

    return CompanyTrendRow(
        entity_id=_str("entity_id"),
        sector=_str("sector", "Unknown") or "Unknown",
        window=_str("window"),
        trend_direction=_str("trend_direction"),
        trend_strength=_float("trend_strength"),
        confidence=_float("confidence"),
        contradiction_score=_float("contradiction_score"),
        dominant_catalysts=_parse_jsonb_list(get("dominant_catalysts")),
        material_risks=_parse_jsonb_list(get("material_risks")),
        top_supporting_evidence=_parse_jsonb_list(get("top_supporting_evidence")),
        top_opposing_evidence=_parse_jsonb_list(get("top_opposing_evidence")),
    )


async def fetch_latest_company_trends(
    pool: asyncpg.Pool,
    window: str,
    since: datetime,
) -> list[CompanyTrendRow]:
    """Fetch the most recent company-level trend for each active ticker.

    Args:
        pool: asyncpg connection pool.
        window: Trend window value to match (e.g. a ``TrendWindow`` value).
        since: Only trends with ``generated_at >= since`` are considered.
    """
    rows = await pool.fetch(_LATEST_COMPANY_TRENDS_QUERY, window, since)
    return [_parse_company_trend_row(r) for r in rows]


# ---------------------------------------------------------------------------
# Pure rollup logic
# ---------------------------------------------------------------------------

# Direction mapping for numeric aggregation
_DIRECTION_VALUES = {
    TrendDirection.BULLISH.value: 1.0,
    TrendDirection.BEARISH.value: -1.0,
    TrendDirection.MIXED.value: 0.0,
    TrendDirection.NEUTRAL.value: 0.0,
}

BULLISH_THRESHOLD = 0.15
BEARISH_THRESHOLD = -0.15


@dataclass
class _MacroAdjustment:
    """Additive macro corrections applied on top of a company-only rollup."""

    strength: float = 0.0  # added to the confidence-weighted mean strength
    confidence: float = 0.0  # added to the mean company confidence
    direction_nudge: float = 0.0  # added to the averaged direction value
    catalysts: list[str] = field(default_factory=list)  # prepended to catalysts
    risks: list[str] = field(default_factory=list)  # prepended to risks


def _neutral_summary(
    entity_type: str,
    entity_id: str,
    window: str,
    reference_time: datetime,
) -> TrendSummary:
    """Build the zero-strength, zero-confidence NEUTRAL summary used when there is no signal."""
    return TrendSummary(
        entity_type=entity_type,
        entity_id=entity_id,
        window=TrendWindow(window),
        trend_direction=TrendDirection.NEUTRAL,
        trend_strength=0.0,
        confidence=0.0,
        generated_at=reference_time,
    )


def _compute_macro_adjustment(
    macro_impacts: dict[str, SectorMacroImpact] | None,
    entity_type: str,
    entity_id: str,
    trend_count: int,
) -> _MacroAdjustment:
    """Derive macro-driven corrections for a sector or market rollup.

    Returns an all-zero adjustment when ``macro_impacts`` is None/empty, when
    ``entity_type`` is neither ``"sector"`` nor ``"market"``, or when there
    is no positive macro impact to apply.
    """
    adj = _MacroAdjustment()
    if not macro_impacts:
        return adj

    if entity_type == "sector":
        sector_macro = macro_impacts.get(entity_id)
        if sector_macro and sector_macro.total_impact > 0:
            # Ratio of macro-affected companies to trending companies, capped at 1.
            breadth = min(sector_macro.company_count / max(trend_count, 1), 1.0)
            adj.strength = sector_macro.avg_impact * breadth * 0.3
            adj.confidence = sector_macro.avg_impact * breadth * 0.1
            adj.direction_nudge = sector_macro.net_direction * adj.strength * 0.5

    elif entity_type == "market":
        total_macro = sum(si.total_impact for si in macro_impacts.values())
        if total_macro > 0:
            total_companies = sum(si.company_count for si in macro_impacts.values())
            breadth = min(total_companies / max(trend_count, 1), 1.0)
            # Company-weighted mean of per-sector average impact, so the market
            # adjustment is on the same per-record scale as the sector one.
            # (Averaging sector *totals* grows with the number of impact
            # records and saturates strength, confidence and direction.)
            avg_macro = sum(
                si.avg_impact * si.company_count for si in macro_impacts.values()
            ) / max(total_companies, 1)
            adj.strength = avg_macro * breadth * 0.3
            adj.confidence = avg_macro * breadth * 0.1

            # Impact-weighted net direction across sectors.
            dir_sum = sum(
                si.net_direction * si.total_impact for si in macro_impacts.values()
            )
            adj.direction_nudge = (dir_sum / total_macro) * adj.strength * 0.5

            # Surface disproportionately affected sectors (Requirement 6.3).
            for sector, fraction in compute_sector_macro_concentration(macro_impacts):
                if fraction <= SECTOR_CONCENTRATION_THRESHOLD:
                    break  # fractions are sorted descending; none later qualify
                label = f"Macro: {sector} ({fraction:.0%} of macro impact)"
                if macro_impacts[sector].net_direction < 0:
                    adj.risks.append(label)
                else:
                    adj.catalysts.append(label)

    return adj


def rollup_trends(
    trends: list[CompanyTrendRow],
    entity_type: str,
    entity_id: str,
    window: str,
    reference_time: datetime,
    macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> TrendSummary:
    """Aggregate a list of company-level trends into a single rollup summary.

    Each company trend is weighted by its confidence to produce a
    confidence-weighted average of direction, strength, and contradiction.
    The rollup confidence is the mean company confidence.

    When ``macro_impacts`` is provided:

    - Sector rollups (``entity_type == "sector"``) add the sector's macro
      signal to strength and confidence. The signal is scaled by the ratio
      of macro-affected companies to trending companies.
    - Market rollups (``entity_type == "market"``) add the cross-sector macro
      signal. Sectors holding more than ``SECTOR_CONCENTRATION_THRESHOLD`` of
      total macro impact are surfaced in ``material_risks`` (net negative)
      or ``dominant_catalysts`` (otherwise).

    When ``macro_impacts`` is None or empty the output is identical to the
    company-only rollup.

    Args:
        trends: Company trends to aggregate.
        entity_type: ``"sector"`` or ``"market"`` (other values skip macro).
        entity_id: Sector name, or the market identifier.
        window: Trend window value; converted with ``TrendWindow(window)``.
        reference_time: Stamped as ``generated_at`` on the result.
        macro_impacts: Optional per-sector macro aggregates.

    Returns:
        A ``TrendSummary``. It is NEUTRAL with zero strength and confidence
        when ``trends`` is empty or every company confidence is zero.
        Strength, confidence and contradiction are clamped to ``[0, 1]``.
        Catalysts and risks are capped at 5 each, and evidence lists at 10
        unique items each.
    """
    if not trends:
        return _neutral_summary(entity_type, entity_id, window, reference_time)

    total_weight = 0.0
    weighted_direction = 0.0
    weighted_strength = 0.0
    weighted_contradiction = 0.0
    catalyst_weights: dict[str, float] = {}
    # Risks are keyed case-insensitively. Each keeps the largest confidence
    # among the companies that reported it.
    risk_weights: dict[str, float] = {}
    all_supporting: list[str] = []
    all_opposing: list[str] = []

    for t in trends:
        w = t.confidence
        total_weight += w
        weighted_direction += w * _DIRECTION_VALUES.get(t.trend_direction, 0.0)
        weighted_strength += w * t.trend_strength
        weighted_contradiction += w * t.contradiction_score
        for cat in t.dominant_catalysts:
            catalyst_weights[cat] = catalyst_weights.get(cat, 0.0) + w
        for risk in t.material_risks:
            norm = risk.strip().lower()
            risk_weights[norm] = max(risk_weights.get(norm, w), w)
        all_supporting.extend(t.top_supporting_evidence)
        all_opposing.extend(t.top_opposing_evidence)

    if total_weight == 0.0:
        return _neutral_summary(entity_type, entity_id, window, reference_time)

    avg_direction = weighted_direction / total_weight
    avg_strength = weighted_strength / total_weight
    avg_contradiction = weighted_contradiction / total_weight
    avg_confidence = total_weight / len(trends)

    # --- Incorporate macro impact signals when available ---
    macro = _compute_macro_adjustment(macro_impacts, entity_type, entity_id, len(trends))
    avg_direction += macro.direction_nudge
    adj_strength = avg_strength + macro.strength
    adj_confidence = avg_confidence + macro.confidence

    direction = _derive_rollup_direction(avg_direction, avg_contradiction)

    # Top catalysts/risks by weight; macro entries are prepended, then capped at 5.
    top_catalysts = [
        c for c, _ in sorted(catalyst_weights.items(), key=lambda x: x[1], reverse=True)[:5]
    ]
    catalysts = (macro.catalysts + top_catalysts)[:5]
    top_risks = [
        r for r, _ in sorted(risk_weights.items(), key=lambda x: x[1], reverse=True)[:5]
    ]
    risks = (macro.risks + top_risks)[:5]

    disagreement = _build_rollup_disagreement(trends, entity_id)

    return TrendSummary(
        entity_type=entity_type,
        entity_id=entity_id,
        window=TrendWindow(window),
        trend_direction=direction,
        trend_strength=round(min(abs(adj_strength), 1.0), 4),
        confidence=round(max(0.0, min(adj_confidence, 1.0)), 4),
        # dict.fromkeys de-duplicates while preserving first-seen order.
        top_supporting_evidence=list(dict.fromkeys(all_supporting))[:10],
        top_opposing_evidence=list(dict.fromkeys(all_opposing))[:10],
        dominant_catalysts=catalysts,
        material_risks=risks,
        contradiction_score=round(max(0.0, min(avg_contradiction, 1.0)), 4),
        disagreement_details=disagreement,
        generated_at=reference_time,
    )


def _derive_rollup_direction(
    avg_direction: float,
    avg_contradiction: float,
) -> TrendDirection:
    """Map an averaged direction value to a TrendDirection.

    High contradiction (> 0.10) with a weak net direction (|d| < 0.3) yields
    MIXED. Otherwise the bullish/bearish thresholds apply, falling back to
    NEUTRAL.
    """
    if avg_contradiction > 0.10 and abs(avg_direction) < 0.3:
        return TrendDirection.MIXED
    if avg_direction >= BULLISH_THRESHOLD:
        return TrendDirection.BULLISH
    if avg_direction <= BEARISH_THRESHOLD:
        return TrendDirection.BEARISH
    return TrendDirection.NEUTRAL


def _build_rollup_disagreement(
    trends: list[CompanyTrendRow],
    entity_id: str,
) -> list[DisagreementDetail]:
    """Describe bullish-vs-bearish company disagreement within a rollup.

    Returns a single ``DisagreementDetail`` listing bullish and bearish
    tickers with their summed confidences. Returns an empty list unless at
    least one company is bullish AND at least one is bearish.
    """
    bullish_ids: list[str] = []
    bearish_ids: list[str] = []
    bullish_weight = 0.0
    bearish_weight = 0.0

    for t in trends:
        if t.trend_direction == TrendDirection.BULLISH.value:
            bullish_ids.append(t.entity_id)
            bullish_weight += t.confidence
        elif t.trend_direction == TrendDirection.BEARISH.value:
            bearish_ids.append(t.entity_id)
            bearish_weight += t.confidence

    if not bullish_ids or not bearish_ids:
        return []

    return [
        DisagreementDetail(
            dimension="company_direction",
            positive_doc_ids=bullish_ids,
            negative_doc_ids=bearish_ids,
            positive_weight=round(bullish_weight, 4),
            negative_weight=round(bearish_weight, 4),
            description=(
                f"{entity_id}: {len(bullish_ids)} bullish vs "
                f"{len(bearish_ids)} bearish companies"
            ),
        )
    ]
# ---------------------------------------------------------------------------
# Persist rollup (reuses the same trend_windows table)
# ---------------------------------------------------------------------------

# Despite its name this is a plain INSERT with no ON CONFLICT clause, so every
# call appends a new trend_windows row. "window" is quoted because WINDOW is
# a reserved keyword in PostgreSQL and is rejected as a bare column name.
_UPSERT_TREND = """
INSERT INTO trend_windows (
    entity_type, entity_id, "window", trend_direction, trend_strength,
    confidence, top_supporting_evidence, top_opposing_evidence,
    dominant_catalysts, material_risks, contradiction_score,
    disagreement_details, market_context, generated_at
) VALUES (
    $1, $2, $3, $4, $5, $6, $7::jsonb, $8::jsonb, $9::jsonb, $10::jsonb,
    $11, $12::jsonb, $13::jsonb, $14
)
RETURNING id
"""


async def persist_rollup(
    pool: asyncpg.Pool,
    summary: TrendSummary,
) -> str:
    """Insert a rollup trend summary and return the new row's id as a string.

    List and disagreement fields are JSON-encoded for the ``jsonb`` columns.
    ``market_context`` is always stored as an empty JSON object.

    Raises:
        RuntimeError: if the INSERT unexpectedly returns no row.
    """
    row = await pool.fetchrow(
        _UPSERT_TREND,
        summary.entity_type,
        summary.entity_id,
        summary.window.value,
        summary.trend_direction.value,
        summary.trend_strength,
        summary.confidence,
        json.dumps(summary.top_supporting_evidence),
        json.dumps(summary.top_opposing_evidence),
        json.dumps(summary.dominant_catalysts),
        json.dumps(summary.material_risks),
        summary.contradiction_score,
        json.dumps([d.model_dump() for d in summary.disagreement_details]),
        json.dumps({}),
        summary.generated_at,
    )
    if row is None:
        raise RuntimeError(
            f"Rollup insert for {summary.entity_type}/{summary.entity_id} returned no id"
        )
    return str(row["id"])


# ---------------------------------------------------------------------------
# High-level rollup entry points
# ---------------------------------------------------------------------------


def _resolve_time_bounds(
    window: str,
    reference_time: datetime | None,
    since: datetime | None,
) -> tuple[datetime, datetime]:
    """Fill in defaults: ``reference_time`` = now (UTC), ``since`` = reference_time minus the window lookback."""
    if reference_time is None:
        reference_time = datetime.now(timezone.utc)
    if since is None:
        since = reference_time - _window_lookback(window)
    return reference_time, since


async def _persist_sector_rollup(
    pool: asyncpg.Pool,
    summary: TrendSummary,
    sector: str,
    window: str,
    company_count: int,
) -> str:
    """Persist a sector rollup, log it, and return its id."""
    rollup_id = await persist_rollup(pool, summary)
    logger.info(
        "Persisted sector rollup %s for %s/%s: direction=%s strength=%.3f companies=%d",
        rollup_id,
        sector,
        window,
        summary.trend_direction.value,
        summary.trend_strength,
        company_count,
    )
    return rollup_id


async def aggregate_sector(
    pool: asyncpg.Pool,
    sector: str,
    window: str,
    reference_time: datetime | None = None,
    since: datetime | None = None,
    macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> TrendSummary:
    """Compute and persist a sector-level rollup for one window.

    Fetches the latest company trends, keeps those in ``sector``, and rolls
    them up into a single sector summary. When ``macro_impacts`` is None,
    they are fetched for ``[since, reference_time]``. The summary is persisted
    only when the sector has at least one company trend. The computed summary
    is returned either way.
    """
    reference_time, since = _resolve_time_bounds(window, reference_time, since)

    all_trends = await fetch_latest_company_trends(pool, window, since)
    sector_trends = [t for t in all_trends if t.sector == sector]

    if macro_impacts is None:
        macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)

    summary = rollup_trends(
        sector_trends,
        "sector",
        sector,
        window,
        reference_time,
        macro_impacts=macro_impacts,
    )

    if sector_trends:
        await _persist_sector_rollup(pool, summary, sector, window, len(sector_trends))

    return summary


async def aggregate_market(
    pool: asyncpg.Pool,
    window: str,
    reference_time: datetime | None = None,
    since: datetime | None = None,
    macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> TrendSummary:
    """Compute and persist a market-wide rollup for one window.

    Aggregates all company trends regardless of sector under entity id
    ``"all"``. Macro signals are aggregated across sectors, and sectors
    holding a disproportionate share of macro impact are surfaced in
    ``material_risks`` or ``dominant_catalysts``. ``macro_impacts`` is fetched
    when None. The summary is persisted only when at least one company trend
    exists.
    """
    reference_time, since = _resolve_time_bounds(window, reference_time, since)

    all_trends = await fetch_latest_company_trends(pool, window, since)

    if macro_impacts is None:
        macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)

    summary = rollup_trends(
        all_trends,
        "market",
        "all",
        window,
        reference_time,
        macro_impacts=macro_impacts,
    )

    if all_trends:
        rollup_id = await persist_rollup(pool, summary)
        logger.info(
            "Persisted market rollup %s for %s: direction=%s strength=%.3f companies=%d",
            rollup_id,
            window,
            summary.trend_direction.value,
            summary.trend_strength,
            len(all_trends),
        )

    return summary


async def aggregate_all_sectors(
    pool: asyncpg.Pool,
    window: str,
    reference_time: datetime | None = None,
    since: datetime | None = None,
    macro_impacts: dict[str, SectorMacroImpact] | None = None,
) -> list[TrendSummary]:
    """Compute, persist and return a rollup for every sector with company trends.

    Company trends and (when not supplied) macro impacts are fetched once and
    shared across all sectors. Summaries are returned in first-seen sector
    order.
    """
    reference_time, since = _resolve_time_bounds(window, reference_time, since)

    all_trends = await fetch_latest_company_trends(pool, window, since)

    if macro_impacts is None:
        macro_impacts = await fetch_sector_macro_impacts(pool, since, reference_time)

    sectors: dict[str, list[CompanyTrendRow]] = {}
    for t in all_trends:
        sectors.setdefault(t.sector, []).append(t)

    summaries: list[TrendSummary] = []
    for sector, trends in sectors.items():
        summary = rollup_trends(
            trends,
            "sector",
            sector,
            window,
            reference_time,
            macro_impacts=macro_impacts,
        )
        # Every group is non-empty by construction, so always persist.
        await _persist_sector_rollup(pool, summary, sector, window, len(trends))
        summaries.append(summary)

    return summaries


def _window_lookback(window: str) -> timedelta:
    """Return how far back to look for recent company trends for ``window``.

    Each lookback is slightly longer than the window itself. Unknown windows
    fall back to 8 days.
    """
    mapping = {
        TrendWindow.INTRADAY.value: timedelta(hours=24),
        TrendWindow.ONE_DAY.value: timedelta(days=2),
        TrendWindow.SEVEN_DAY.value: timedelta(days=8),
        TrendWindow.THIRTY_DAY.value: timedelta(days=35),
        TrendWindow.NINETY_DAY.value: timedelta(days=95),
    }
    return mapping.get(window, timedelta(days=8))