"""Risk engine - portfolio and account risk configuration and enforcement.

Defines the configuration and state models used to enforce guardrails on
trade execution: max position size, sector exposure, daily loss limits,
news-shock lockouts, and operator approval rules.

Also implements the hard-block evaluation logic that decides whether a
proposed order is allowed before it reaches the broker adapter.

Requirements: 8.1, 8.2, 8.3, 8.4, 8.5
Design: Section 4.8 - Risk Engine
"""

from __future__ import annotations

import uuid
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field


# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------


class TradingMode(str, Enum):
    """Execution environment separation (Requirement 8.1)."""

    PAPER = "paper"
    LIVE = "live"
    DISABLED = "disabled"


class RiskCheckResult(str, Enum):
    """Outcome of a single risk check."""

    PASS = "pass"
    FAIL = "fail"
    WARN = "warn"


# ---------------------------------------------------------------------------
# Portfolio-level risk configuration (Requirement 8.2, 8.4)
# ---------------------------------------------------------------------------


class PositionLimits(BaseModel):
    """Per-position size constraints."""

    max_position_pct: float = Field(
        default=0.05,
        ge=0,
        le=1,
        description="Maximum portfolio percentage for a single position",
    )
    max_position_value: float = Field(
        default=10_000.0,
        ge=0,
        description="Maximum dollar value for a single position",
    )
    max_shares_per_order: float = Field(
        default=1000.0,
        ge=0,
        description="Maximum shares in a single order",
    )


class SectorExposureLimits(BaseModel):
    """Sector-level concentration limits."""

    max_sector_pct: float = Field(
        default=0.25,
        ge=0,
        le=1,
        description="Maximum portfolio percentage exposed to one sector",
    )
    # NOTE(review): no check in this module reads max_sectors; confirm whether
    # it is enforced elsewhere.
    max_sectors: int = Field(
        default=10,
        ge=1,
        description="Maximum number of sectors with open positions",
    )


class DailyLossLimits(BaseModel):
    """Daily drawdown controls."""

    max_daily_loss_pct: float = Field(
        default=0.02,
        ge=0,
        le=1,
        description="Maximum portfolio loss percentage in a single day before halting",
    )
    max_daily_loss_value: float = Field(
        default=1_000.0,
        ge=0,
        description="Maximum dollar loss in a single day before halting",
    )
    max_daily_trades: int = Field(
        default=20,
        ge=0,
        description="Maximum number of trades per day",
    )


class NewsShockLockout(BaseModel):
    """News-shock lockout configuration.

    When a symbol has a high-impact news event, trading is paused for a
    configurable cooldown period.
    """

    enabled: bool = True
    # NOTE(review): the lockout check in this module only reads `enabled` and
    # the expiry stored in AccountRiskState.locked_symbols; the fields below
    # are presumably consumed by whatever populates that map - confirm.
    lockout_minutes: int = Field(
        default=60,
        ge=0,
        description="Minutes to lock out trading after a high-impact news event",
    )
    impact_threshold: float = Field(
        default=0.80,
        ge=0,
        le=1,
        description="Minimum impact_score from document intelligence to trigger lockout",
    )
    catalyst_types: list[str] = Field(
        default_factory=lambda: ["earnings", "legal", "m_and_a"],
        description="Catalyst types that trigger lockout when above threshold",
    )


class OperatorApproval(BaseModel):
    """Operator approval workflow for live trading (Requirement 8.2)."""

    require_approval_for_live: bool = Field(
        default=True,
        description="Whether live orders require operator approval",
    )
    auto_approve_paper: bool = Field(
        default=True,
        description="Whether paper orders are auto-approved",
    )
    approval_timeout_minutes: int = Field(
        default=30,
        ge=1,
        description="Minutes before a pending approval expires",
    )


class SymbolCooldown(BaseModel):
    """Per-symbol cooldown after a trade."""

    cooldown_minutes: int = Field(
        default=15,
        ge=0,
        description="Minutes to wait before trading the same symbol again",
    )
    # NOTE(review): not read by any check in this module; confirm where it is
    # enforced.
    max_open_positions_per_symbol: int = Field(
        default=1,
        ge=1,
        description="Maximum concurrent open positions for a single symbol",
    )


class PortfolioRiskConfig(BaseModel):
    """Complete portfolio-level risk configuration.

    This is the top-level config that governs all risk checks.
    Persisted in PostgreSQL and loaded at engine startup.
    """

    config_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str = "default"
    trading_mode: TradingMode = TradingMode.PAPER
    position_limits: PositionLimits = Field(default_factory=PositionLimits)
    sector_exposure: SectorExposureLimits = Field(default_factory=SectorExposureLimits)
    daily_loss: DailyLossLimits = Field(default_factory=DailyLossLimits)
    news_shock: NewsShockLockout = Field(default_factory=NewsShockLockout)
    operator_approval: OperatorApproval = Field(default_factory=OperatorApproval)
    symbol_cooldown: SymbolCooldown = Field(default_factory=SymbolCooldown)
    active: bool = True
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    def to_db_json(self) -> dict[str, Any]:
        """Serialize the full config to a JSON-compatible dict for DB storage."""
        return self.model_dump(mode="json")

    @classmethod
    def from_db_json(cls, data: dict[str, Any]) -> PortfolioRiskConfig:
        """Deserialize from a DB JSON column."""
        return cls.model_validate(data)


# ---------------------------------------------------------------------------
# Account risk state (runtime snapshot)
# ---------------------------------------------------------------------------


class AccountRiskState(BaseModel):
    """Runtime snapshot of an account's risk posture.

    Computed from broker positions, today's trades, and current P&L.
    Used by risk checks to evaluate whether a new order is allowed.
    """

    account_id: str = ""
    portfolio_value: float = 0.0
    cash: float = 0.0
    buying_power: float = 0.0
    daily_pnl: float = 0.0
    daily_trade_count: int = 0
    open_position_count: int = 0
    positions_by_symbol: dict[str, float] = Field(
        default_factory=dict,
        description="Map of ticker → current market value",
    )
    positions_by_sector: dict[str, float] = Field(
        default_factory=dict,
        description="Map of sector → total market value",
    )
    last_trade_times: dict[str, datetime] = Field(
        default_factory=dict,
        description="Map of ticker → last trade timestamp for cooldown checks",
    )
    locked_symbols: dict[str, datetime] = Field(
        default_factory=dict,
        description="Map of ticker → lockout expiry for news-shock lockouts",
    )
    snapshot_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


# ---------------------------------------------------------------------------
# Risk check output (Requirement 8.3 - full decision trace)
# ---------------------------------------------------------------------------


class RiskCheckDetail(BaseModel):
    """Result of a single risk check."""

    check_name: str
    result: RiskCheckResult
    message: str = ""
    threshold: float | None = None
    actual: float | None = None


class RiskEvaluation(BaseModel):
    """Complete risk evaluation for a proposed order.

    Captures every check performed so the full decision trace is
    reproducible (Requirement 8.3).
    """

    evaluation_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    recommendation_id: str | None = None
    ticker: str = ""
    eligible: bool = False
    allowed_mode: TradingMode = TradingMode.DISABLED
    checks: list[RiskCheckDetail] = Field(default_factory=list)
    rejection_reasons: list[str] = Field(default_factory=list)
    config_snapshot: PortfolioRiskConfig | None = None
    state_snapshot: AccountRiskState | None = None
    evaluated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    @property
    def passed(self) -> bool:
        """Return True when the order is eligible and has no rejection reasons."""
        return self.eligible and not self.rejection_reasons


# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------

DEFAULT_RISK_CONFIG = PortfolioRiskConfig()


# ---------------------------------------------------------------------------
# Proposed order (input to risk evaluation)
# ---------------------------------------------------------------------------


class ProposedOrder(BaseModel):
    """A proposed order to be evaluated by the risk engine before submission.

    This is the input to evaluate_order(). It carries enough context
    for every risk check to run without external lookups.
    """

    recommendation_id: str | None = None
    ticker: str
    sector: str = ""
    action: str = "buy"  # buy | sell
    # NOTE(review): quantity and estimated_value are not validated as
    # non-negative here; the checks below assume they are magnitudes - confirm
    # against callers.
    quantity: float = 0.0
    estimated_value: float = 0.0
    confidence: float = 0.0


# ---------------------------------------------------------------------------
# Individual risk checks (Requirement 8.4)
# ---------------------------------------------------------------------------


def _as_utc(moment: datetime) -> datetime:
    """Return *moment* as a timezone-aware datetime.

    Comparing a naive datetime with an aware one raises TypeError, which would
    abort the whole evaluation. Naive values are tagged as UTC, matching the
    UTC defaults used by the models in this module.
    NOTE(review): assumes naive timestamps from upstream are UTC - confirm.
    """
    if moment.tzinfo is None or moment.utcoffset() is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment


def _check_trading_mode(
    config: PortfolioRiskConfig,
) -> RiskCheckDetail:
    """Block all orders when trading is disabled.

    Returns a FAIL detail when ``config.trading_mode`` is DISABLED, otherwise
    a PASS detail naming the active mode.
    """
    if config.trading_mode == TradingMode.DISABLED:
        return RiskCheckDetail(
            check_name="trading_mode",
            result=RiskCheckResult.FAIL,
            message="Trading is disabled",
        )
    return RiskCheckDetail(
        check_name="trading_mode",
        result=RiskCheckResult.PASS,
        message=f"Trading mode: {config.trading_mode.value}",
    )


def _check_max_position_size(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> list[RiskCheckDetail]:
    """Enforce per-position size limits (value, percentage, shares).

    Returns three details, in order: ``max_position_value``,
    ``max_position_pct`` and ``max_shares_per_order``. Sell orders always pass
    the value and percentage checks because they reduce exposure; the
    per-order share cap applies to both buys and sells.
    """
    checks: list[RiskCheckDetail] = []
    limits = config.position_limits
    is_sell = order.action == "sell"

    # Projected position value after the order fills (never below zero).
    existing_value = state.positions_by_symbol.get(order.ticker, 0.0)
    if is_sell:
        new_total_value = max(existing_value - order.estimated_value, 0.0)
    else:
        new_total_value = existing_value + order.estimated_value

    # Sell orders always pass position value check — they reduce exposure
    if is_sell:
        value_result = RiskCheckResult.PASS
        value_verb = "within (sell reduces exposure)"
    elif new_total_value <= limits.max_position_value:
        value_result = RiskCheckResult.PASS
        value_verb = "within"
    else:
        value_result = RiskCheckResult.FAIL
        value_verb = "exceeds"

    checks.append(RiskCheckDetail(
        check_name="max_position_value",
        result=value_result,
        message=(
            f"Position value {new_total_value:.2f} "
            f"{value_verb} "
            f"limit {limits.max_position_value:.2f}"
        ),
        threshold=limits.max_position_value,
        actual=new_total_value,
    ))

    # Position as a fraction of the portfolio. With no known portfolio value,
    # any non-zero position is treated as 100% so the check fails closed.
    if state.portfolio_value > 0:
        position_pct = new_total_value / state.portfolio_value
    else:
        position_pct = 1.0 if new_total_value > 0 else 0.0

    # Sell orders that reduce concentration should always pass — blocking a
    # sell on an over-concentrated position prevents the user from fixing it.
    if is_sell:
        pct_result = RiskCheckResult.PASS
        pct_verb = "within (sell reduces exposure)"
    elif position_pct <= limits.max_position_pct:
        pct_result = RiskCheckResult.PASS
        pct_verb = "within"
    else:
        pct_result = RiskCheckResult.FAIL
        pct_verb = "exceeds"

    checks.append(RiskCheckDetail(
        check_name="max_position_pct",
        result=pct_result,
        message=(
            f"Position {position_pct:.4f} of portfolio "
            f"{pct_verb} "
            f"limit {limits.max_position_pct:.4f}"
        ),
        threshold=limits.max_position_pct,
        actual=position_pct,
    ))

    # Per-order share cap.
    shares_ok = order.quantity <= limits.max_shares_per_order
    checks.append(RiskCheckDetail(
        check_name="max_shares_per_order",
        result=RiskCheckResult.PASS if shares_ok else RiskCheckResult.FAIL,
        message=(
            f"Order quantity {order.quantity:.0f} "
            f"{'within' if shares_ok else 'exceeds'} "
            f"limit {limits.max_shares_per_order:.0f}"
        ),
        threshold=limits.max_shares_per_order,
        actual=order.quantity,
    ))

    return checks


def _check_sector_exposure(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> RiskCheckDetail:
    """Enforce sector concentration limits.

    Returns a WARN detail when the order carries no sector. Sell orders
    subtract their value from the sector total and always pass, mirroring the
    position-size checks: blocking a sell in an over-concentrated sector would
    prevent the exposure from being reduced. Buys fail when the projected
    sector share exceeds ``max_sector_pct``.
    """
    limits = config.sector_exposure

    if not order.sector:
        return RiskCheckDetail(
            check_name="sector_exposure",
            result=RiskCheckResult.WARN,
            message="No sector provided on order; skipping sector check",
        )

    is_sell = order.action == "sell"
    existing_sector_value = state.positions_by_sector.get(order.sector, 0.0)
    if is_sell:
        new_sector_value = max(existing_sector_value - order.estimated_value, 0.0)
    else:
        new_sector_value = existing_sector_value + order.estimated_value

    # With no known portfolio value, any non-zero exposure counts as 100%.
    if state.portfolio_value > 0:
        sector_pct = new_sector_value / state.portfolio_value
    else:
        sector_pct = 1.0 if new_sector_value > 0 else 0.0

    if is_sell:
        sector_result = RiskCheckResult.PASS
        sector_verb = "within (sell reduces exposure)"
    elif sector_pct <= limits.max_sector_pct:
        sector_result = RiskCheckResult.PASS
        sector_verb = "within"
    else:
        sector_result = RiskCheckResult.FAIL
        sector_verb = "exceeds"

    return RiskCheckDetail(
        check_name="sector_exposure",
        result=sector_result,
        message=(
            f"Sector '{order.sector}' exposure {sector_pct:.4f} "
            f"{sector_verb} "
            f"limit {limits.max_sector_pct:.4f}"
        ),
        threshold=limits.max_sector_pct,
        actual=sector_pct,
    )


def _check_daily_loss(
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> list[RiskCheckDetail]:
    """Enforce daily loss and trade count limits.

    Returns three details, in order: ``daily_loss_pct``, ``daily_loss_value``
    and ``daily_trade_count``. Only negative ``daily_pnl`` counts as a loss;
    gains are treated as zero loss.
    """
    checks: list[RiskCheckDetail] = []
    limits = config.daily_loss

    # Magnitude of today's loss (0.0 when the day is flat or positive).
    abs_loss = abs(min(state.daily_pnl, 0.0))

    # Daily loss percentage. With no known portfolio value, any loss counts as
    # 100% so the check fails closed, consistent with the exposure checks.
    if state.portfolio_value > 0:
        loss_pct = abs_loss / state.portfolio_value
    else:
        loss_pct = 1.0 if abs_loss > 0 else 0.0

    pct_ok = loss_pct <= limits.max_daily_loss_pct
    checks.append(RiskCheckDetail(
        check_name="daily_loss_pct",
        result=RiskCheckResult.PASS if pct_ok else RiskCheckResult.FAIL,
        message=(
            f"Daily loss {loss_pct:.4f} "
            f"{'within' if pct_ok else 'exceeds'} "
            f"limit {limits.max_daily_loss_pct:.4f}"
        ),
        threshold=limits.max_daily_loss_pct,
        actual=loss_pct,
    ))

    # Daily loss absolute value
    value_ok = abs_loss <= limits.max_daily_loss_value
    checks.append(RiskCheckDetail(
        check_name="daily_loss_value",
        result=RiskCheckResult.PASS if value_ok else RiskCheckResult.FAIL,
        message=(
            f"Daily loss ${abs_loss:.2f} "
            f"{'within' if value_ok else 'exceeds'} "
            f"limit ${limits.max_daily_loss_value:.2f}"
        ),
        threshold=limits.max_daily_loss_value,
        actual=abs_loss,
    ))

    # Daily trade count. Strictly less-than: the count covers trades already
    # made, so reaching the limit blocks the next order.
    count_ok = state.daily_trade_count < limits.max_daily_trades
    checks.append(RiskCheckDetail(
        check_name="daily_trade_count",
        result=RiskCheckResult.PASS if count_ok else RiskCheckResult.FAIL,
        message=(
            f"Daily trades {state.daily_trade_count} "
            f"{'within' if count_ok else 'at/exceeds'} "
            f"limit {limits.max_daily_trades}"
        ),
        threshold=float(limits.max_daily_trades),
        actual=float(state.daily_trade_count),
    ))

    return checks


def _check_news_shock_lockout(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
    now: datetime | None = None,
) -> RiskCheckDetail:
    """Block trading on symbols under news-shock lockout.

    Passes when the lockout feature is disabled or when the ticker has no
    expiry in ``state.locked_symbols`` that is still in the future. ``now``
    defaults to the current UTC time.
    """
    lockout_cfg = config.news_shock

    if not lockout_cfg.enabled:
        return RiskCheckDetail(
            check_name="news_shock_lockout",
            result=RiskCheckResult.PASS,
            message="News-shock lockout is disabled",
        )

    now = _as_utc(now) if now is not None else datetime.now(timezone.utc)
    lockout_expiry = state.locked_symbols.get(order.ticker)
    if lockout_expiry is not None:
        # Normalise so a naive stored expiry cannot raise on comparison.
        lockout_expiry = _as_utc(lockout_expiry)
        if now < lockout_expiry:
            remaining = lockout_expiry - now
            return RiskCheckDetail(
                check_name="news_shock_lockout",
                result=RiskCheckResult.FAIL,
                message=(
                    f"Symbol {order.ticker} locked out until "
                    f"{lockout_expiry.isoformat()} "
                    f"({remaining.total_seconds():.0f}s remaining)"
                ),
            )

    return RiskCheckDetail(
        check_name="news_shock_lockout",
        result=RiskCheckResult.PASS,
        message=f"No active lockout for {order.ticker}",
    )


def _check_symbol_cooldown(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
    now: datetime | None = None,
) -> RiskCheckDetail:
    """Enforce per-symbol cooldown between trades.

    Fails when the ticker's entry in ``state.last_trade_times`` plus
    ``cooldown_minutes`` is still in the future. ``now`` defaults to the
    current UTC time.
    """
    cooldown_cfg = config.symbol_cooldown
    now = _as_utc(now) if now is not None else datetime.now(timezone.utc)

    last_trade = state.last_trade_times.get(order.ticker)
    if last_trade is not None:
        # Normalise so a naive stored timestamp cannot raise on comparison.
        cooldown_end = _as_utc(last_trade) + timedelta(
            minutes=cooldown_cfg.cooldown_minutes
        )
        if now < cooldown_end:
            remaining = cooldown_end - now
            return RiskCheckDetail(
                check_name="symbol_cooldown",
                result=RiskCheckResult.FAIL,
                message=(
                    f"Symbol {order.ticker} in cooldown until "
                    f"{cooldown_end.isoformat()} "
                    f"({remaining.total_seconds():.0f}s remaining)"
                ),
            )

    return RiskCheckDetail(
        check_name="symbol_cooldown",
        result=RiskCheckResult.PASS,
        message=f"No active cooldown for {order.ticker}",
    )


# ---------------------------------------------------------------------------
# Main evaluation entry point (Requirements 8.3, 8.4, 8.5)
# ---------------------------------------------------------------------------


def evaluate_order(
    order: ProposedOrder,
    config: PortfolioRiskConfig = DEFAULT_RISK_CONFIG,
    state: AccountRiskState | None = None,
    now: datetime | None = None,
) -> RiskEvaluation:
    """Evaluate a proposed order against all risk controls.

    Runs every hard-block check and returns a RiskEvaluation capturing
    the full decision trace (Requirement 8.3). If any check fails, the
    order is rejected before broker submission (Requirement 8.4).

    The engine fails closed: if state is missing or ambiguous, the order
    is rejected (Requirement 8.5). When ``state`` is None an explicit failing
    ``account_state`` check is recorded; the remaining checks still run
    against an empty AccountRiskState so the trace stays complete.

    Args:
        order: The order to evaluate.
        config: Risk configuration; defaults to DEFAULT_RISK_CONFIG.
        state: Current account risk snapshot, or None if unavailable.
        now: Evaluation time; defaults to the current UTC time.

    Returns:
        A RiskEvaluation whose ``eligible`` flag is True only when no check
        failed, with ``allowed_mode`` set to the configured trading mode on
        success and DISABLED otherwise.
    """
    state_missing = state is None
    if state is None:
        state = AccountRiskState()
    now = _as_utc(now) if now is not None else datetime.now(timezone.utc)

    all_checks: list[RiskCheckDetail] = []

    # 0. Fail closed on missing account state. An empty state alone would let
    # sells and zero-value orders through every check below.
    if state_missing:
        all_checks.append(RiskCheckDetail(
            check_name="account_state",
            result=RiskCheckResult.FAIL,
            message="Account risk state is missing; failing closed",
        ))

    # 1. Trading mode gate
    all_checks.append(_check_trading_mode(config))

    # 2. Position size limits
    all_checks.extend(_check_max_position_size(order, config, state))

    # 3. Sector exposure
    all_checks.append(_check_sector_exposure(order, config, state))

    # 4. Daily loss limits
    all_checks.extend(_check_daily_loss(config, state))

    # 5. News-shock lockout
    all_checks.append(_check_news_shock_lockout(order, config, state, now))

    # 6. Symbol cooldown
    all_checks.append(_check_symbol_cooldown(order, config, state, now))

    # Every FAIL contributes its message, in the order the checks ran.
    rejection_reasons = [
        check.message
        for check in all_checks
        if check.result == RiskCheckResult.FAIL
    ]

    # Determine eligibility and allowed mode
    eligible = not rejection_reasons
    allowed_mode = config.trading_mode if eligible else TradingMode.DISABLED

    return RiskEvaluation(
        recommendation_id=order.recommendation_id,
        ticker=order.ticker,
        eligible=eligible,
        allowed_mode=allowed_mode,
        checks=all_checks,
        rejection_reasons=rejection_reasons,
        # Deep copies keep the recorded trace independent of later mutation of
        # the caller's (possibly shared default) config and state objects.
        config_snapshot=config.model_copy(deep=True),
        state_snapshot=state.model_copy(deep=True),
        evaluated_at=now,
    )