"""Signal cluster classification and within-cluster correlation penalty. Groups signals into four clusters — momentum, structure, volatility, fundamentals — and applies exponential decay within each cluster to prevent likelihood ratio stacking inflation in the Bayesian pipeline. Within a cluster the strongest signal (by ``|log_lr|``) contributes at full weight; subsequent signals contribute at ``0.5^(n-1)`` decay. Signals in different clusters are treated as independent (no penalty). Single-signal clusters receive no penalty. Requirements: 7.1, 7.2, 7.3, 7.4 """ from __future__ import annotations import logging from collections import defaultdict from enum import Enum from services.signal_engine.models import LikelihoodRatio logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Signal cluster enum # --------------------------------------------------------------------------- class SignalCluster(str, Enum): """Correlation cluster for grouping related signals.""" MOMENTUM = "momentum" # MA stack, RSI STRUCTURE = "structure" # Fibonacci, Elliott Wave, Cup & Handle VOLATILITY = "volatility" # ATR-based, Bollinger-derived FUNDAMENTALS = "fundamentals" # valuation, earnings, macro # --------------------------------------------------------------------------- # Signal type → cluster mapping # --------------------------------------------------------------------------- _SIGNAL_CLUSTER_MAP: dict[str, SignalCluster] = { # Momentum "ma_stack": SignalCluster.MOMENTUM, "rsi": SignalCluster.MOMENTUM, # Structure "fibonacci": SignalCluster.STRUCTURE, "elliott_wave": SignalCluster.STRUCTURE, "cup_handle": SignalCluster.STRUCTURE, # Volatility "atr": SignalCluster.VOLATILITY, "bollinger": SignalCluster.VOLATILITY, # Fundamentals "valuation": SignalCluster.FUNDAMENTALS, "earnings": SignalCluster.FUNDAMENTALS, "macro": SignalCluster.FUNDAMENTALS, } # Decay factor applied to successive signals within the same cluster. 
_WITHIN_CLUSTER_DECAY = 0.5  # rank 0 → 1.0, rank 1 → 0.5, rank 2 → 0.25, ...


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def classify_signal(signal_type: str) -> SignalCluster:
    """Map a signal type string to its correlation cluster.

    Args:
        signal_type: Signal identifier, looked up verbatim (exact,
            case-sensitive match) in the module's signal → cluster table.

    Returns:
        The cluster registered for ``signal_type``.  Unknown signal types fall
        back to :attr:`SignalCluster.FUNDAMENTALS` (and a warning is logged)
        so that unrecognised signals still participate in the penalty system
        rather than silently bypassing it.
    """
    cluster = _SIGNAL_CLUSTER_MAP.get(signal_type)
    if cluster is not None:
        return cluster

    logger.warning(
        "Unknown signal type %r — defaulting to FUNDAMENTALS cluster",
        signal_type,
    )
    return SignalCluster.FUNDAMENTALS


def apply_correlation_penalty(
    likelihood_ratios: list[LikelihoodRatio],
    *,
    decay: float = _WITHIN_CLUSTER_DECAY,
) -> list[LikelihoodRatio]:
    """Apply within-cluster decay penalty to correlated signals.

    Algorithm:

    1. Group LRs by their ``cluster`` attribute.
    2. Within each cluster, rank by ``abs(log_lr)`` descending (strongest
       first).  The sort is stable, so ties keep their input order.
    3. The strongest signal keeps its full ``log_lr`` as ``penalized_log_lr``.
    4. The *n*-th signal (1-indexed) receives
       ``penalized_log_lr = log_lr * decay^(n-1)``.
    5. Single-signal clusters are untouched (``penalized_log_lr = log_lr``).
    6. Cross-cluster signals are independent — no penalty is applied across
       clusters.

    Args:
        likelihood_ratios: Likelihood ratios to penalise.  May be empty.
        decay: Multiplicative factor applied per rank within a cluster.
            Defaults to the module-level decay of ``0.5``.  Must lie in
            ``[0, 1]``.

    Returns:
        A **new** list of :class:`LikelihoodRatio` instances, in the same
        order as the input, with updated ``penalized_log_lr`` values.  The
        original objects are not mutated.

    Raises:
        ValueError: If ``decay`` is outside ``[0, 1]`` (this includes NaN).

    NOTE(review): each result is rebuilt from ``signal_type``, ``cluster``,
    ``lr``, ``log_lr``, ``hit_rate`` and ``strength`` only; any other
    attribute on the input objects is not carried over.
    """
    # A factor above 1 would amplify correlated signals instead of damping
    # them, and a negative one would flip signs on odd ranks.
    if not 0.0 <= decay <= 1.0:
        raise ValueError(f"decay must be within [0, 1], got {decay!r}")

    if not likelihood_ratios:
        return []

    # Track positions rather than objects so the output can be rebuilt in the
    # caller's original order.
    indices_by_cluster: dict[str, list[int]] = defaultdict(list)
    for idx, lr in enumerate(likelihood_ratios):
        indices_by_cluster[lr.cluster].append(idx)

    # Every index belongs to exactly one cluster, so each slot below is
    # overwritten exactly once.
    penalized = [0.0] * len(likelihood_ratios)
    for indices in indices_by_cluster.values():
        indices.sort(
            key=lambda i: abs(likelihood_ratios[i].log_lr),
            reverse=True,
        )
        for rank, idx in enumerate(indices):
            # rank 0 → decay**0 == 1, i.e. the strongest signal is unpenalised.
            penalized[idx] = likelihood_ratios[idx].log_lr * decay**rank

    return [
        LikelihoodRatio(
            signal_type=lr.signal_type,
            cluster=lr.cluster,
            lr=lr.lr,
            log_lr=lr.log_lr,
            penalized_log_lr=penalized_log_lr,
            hit_rate=lr.hit_rate,
            strength=lr.strength,
        )
        for lr, penalized_log_lr in zip(likelihood_ratios, penalized)
    ]