"""Risk engine - portfolio and account risk configuration and enforcement. Defines the configuration and state models used to enforce guardrails on trade execution: max position size, sector exposure, daily loss limits, news-shock lockouts, and operator approval rules. Also implements the hard-block evaluation logic that decides whether a proposed order is allowed before it reaches the broker adapter. Requirements: 8.1, 8.2, 8.3, 8.4, 8.5 Design: Section 4.8 - Risk Engine """ from __future__ import annotations import math import uuid from datetime import datetime, timedelta, timezone from enum import Enum from typing import Any from pydantic import BaseModel, Field # --------------------------------------------------------------------------- # Enums # --------------------------------------------------------------------------- class TradingMode(str, Enum): """Execution environment separation (Requirement 8.1).""" PAPER = "paper" LIVE = "live" DISABLED = "disabled" class RiskCheckResult(str, Enum): """Outcome of a single risk check.""" PASS = "pass" FAIL = "fail" WARN = "warn" # --------------------------------------------------------------------------- # Portfolio-level risk configuration (Requirement 8.2, 8.4) # --------------------------------------------------------------------------- class PositionLimits(BaseModel): """Per-position size constraints.""" max_position_pct: float = Field( default=0.05, ge=0, le=1, description="Maximum portfolio percentage for a single position", ) max_position_value: float = Field( default=10_000.0, ge=0, description="Maximum dollar value for a single position", ) max_shares_per_order: float = Field( default=1000.0, ge=0, description="Maximum shares in a single order", ) class SectorExposureLimits(BaseModel): """Sector-level concentration limits.""" max_sector_pct: float = Field( default=0.25, ge=0, le=1, description="Maximum portfolio percentage exposed to one sector", ) max_sectors: int = Field( default=10, ge=1, 
description="Maximum number of sectors with open positions", ) class DailyLossLimits(BaseModel): """Daily drawdown controls.""" max_daily_loss_pct: float = Field( default=0.02, ge=0, le=1, description="Maximum portfolio loss percentage in a single day before halting", ) max_daily_loss_value: float = Field( default=1_000.0, ge=0, description="Maximum dollar loss in a single day before halting", ) max_daily_trades: int = Field( default=20, ge=0, description="Maximum number of trades per day", ) class NewsShockLockout(BaseModel): """News-shock lockout configuration. When a symbol has a high-impact news event, trading is paused for a configurable cooldown period. """ enabled: bool = True lockout_minutes: int = Field( default=60, ge=0, description="Minutes to lock out trading after a high-impact news event", ) impact_threshold: float = Field( default=0.80, ge=0, le=1, description="Minimum impact_score from document intelligence to trigger lockout", ) catalyst_types: list[str] = Field( default_factory=lambda: ["earnings", "legal", "m_and_a"], description="Catalyst types that trigger lockout when above threshold", ) class OperatorApproval(BaseModel): """Operator approval workflow for live trading (Requirement 8.2).""" require_approval_for_live: bool = Field( default=True, description="Whether live orders require operator approval", ) auto_approve_paper: bool = Field( default=True, description="Whether paper orders are auto-approved", ) approval_timeout_minutes: int = Field( default=30, ge=1, description="Minutes before a pending approval expires", ) class SymbolCooldown(BaseModel): """Per-symbol cooldown after a trade.""" cooldown_minutes: int = Field( default=15, ge=0, description="Minutes to wait before trading the same symbol again", ) max_open_positions_per_symbol: int = Field( default=1, ge=1, description="Maximum concurrent open positions for a single symbol", ) class PortfolioRiskConfig(BaseModel): """Complete portfolio-level risk configuration. 
This is the top-level config that governs all risk checks. Persisted in PostgreSQL and loaded at engine startup. """ config_id: str = Field(default_factory=lambda: str(uuid.uuid4())) name: str = "default" trading_mode: TradingMode = TradingMode.PAPER position_limits: PositionLimits = Field(default_factory=PositionLimits) sector_exposure: SectorExposureLimits = Field(default_factory=SectorExposureLimits) daily_loss: DailyLossLimits = Field(default_factory=DailyLossLimits) news_shock: NewsShockLockout = Field(default_factory=NewsShockLockout) operator_approval: OperatorApproval = Field(default_factory=OperatorApproval) symbol_cooldown: SymbolCooldown = Field(default_factory=SymbolCooldown) active: bool = True created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) def to_db_json(self) -> dict[str, Any]: """Serialize the full config to a JSON-compatible dict for DB storage.""" return self.model_dump(mode="json") @classmethod def from_db_json(cls, data: dict[str, Any]) -> PortfolioRiskConfig: """Deserialize from a DB JSON column.""" return cls.model_validate(data) # --------------------------------------------------------------------------- # Account risk state (runtime snapshot) # --------------------------------------------------------------------------- class AccountRiskState(BaseModel): """Runtime snapshot of an account's risk posture. Computed from broker positions, today's trades, and current P&L. Used by risk checks to evaluate whether a new order is allowed. 
""" account_id: str = "" portfolio_value: float = 0.0 cash: float = 0.0 buying_power: float = 0.0 daily_pnl: float = 0.0 daily_trade_count: int = 0 open_position_count: int = 0 positions_by_symbol: dict[str, float] = Field( default_factory=dict, description="Map of ticker → current market value", ) positions_by_sector: dict[str, float] = Field( default_factory=dict, description="Map of sector → total market value", ) last_trade_times: dict[str, datetime] = Field( default_factory=dict, description="Map of ticker → last trade timestamp for cooldown checks", ) locked_symbols: dict[str, datetime] = Field( default_factory=dict, description="Map of ticker → lockout expiry for news-shock lockouts", ) snapshot_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) # --------------------------------------------------------------------------- # Risk check output (Requirement 8.3 - full decision trace) # --------------------------------------------------------------------------- class RiskCheckDetail(BaseModel): """Result of a single risk check.""" check_name: str result: RiskCheckResult message: str = "" threshold: float | None = None actual: float | None = None class RiskEvaluation(BaseModel): """Complete risk evaluation for a proposed order. Captures every check performed so the full decision trace is reproducible (Requirement 8.3). 
""" evaluation_id: str = Field(default_factory=lambda: str(uuid.uuid4())) recommendation_id: str | None = None ticker: str = "" eligible: bool = False allowed_mode: TradingMode = TradingMode.DISABLED checks: list[RiskCheckDetail] = Field(default_factory=list) rejection_reasons: list[str] = Field(default_factory=list) config_snapshot: PortfolioRiskConfig | None = None state_snapshot: AccountRiskState | None = None evaluated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) @property def passed(self) -> bool: return self.eligible and len(self.rejection_reasons) == 0 # --------------------------------------------------------------------------- # Default configuration # --------------------------------------------------------------------------- DEFAULT_RISK_CONFIG = PortfolioRiskConfig() # --------------------------------------------------------------------------- # Proposed order (input to risk evaluation) # --------------------------------------------------------------------------- class ProposedOrder(BaseModel): """A proposed order to be evaluated by the risk engine before submission. This is the input to evaluate_order(). It carries enough context for every risk check to run without external lookups. """ recommendation_id: str | None = None ticker: str sector: str = "" action: str = "buy" # buy | sell quantity: float = 0.0 estimated_value: float = 0.0 confidence: float = 0.0 # --------------------------------------------------------------------------- # Order clamping — auto-scale to fit within position limits # --------------------------------------------------------------------------- def clamp_order_to_position_limits( order: ProposedOrder, config: PortfolioRiskConfig, state: AccountRiskState, ) -> ProposedOrder: """Clamp a buy order's quantity/value to fit within position limits. 
Instead of hard-rejecting orders that exceed max_position_pct or max_position_value, this function computes the maximum allowed order size and returns a new ProposedOrder scaled down to fit. Sell orders are returned unchanged (they reduce exposure). If the order already fits, it is returned unchanged. If the clamped quantity rounds to zero, the order is returned with quantity=0 and estimated_value=0 so the caller can reject it. """ if order.action == "sell" or order.quantity <= 0: return order limits = config.position_limits existing_value = state.positions_by_symbol.get(order.ticker, 0.0) # Compute per-share price from the order price_per_share = ( order.estimated_value / order.quantity if order.quantity > 0 and order.estimated_value > 0 else 0.0 ) if price_per_share <= 0: return order # Can't clamp without a price; let risk checks handle it # Compute the maximum additional value we can add to this position max_allowed_value = limits.max_position_value - existing_value # Also enforce max_position_pct if portfolio value is known if state.portfolio_value > 0: max_pct_value = (limits.max_position_pct * state.portfolio_value) - existing_value max_allowed_value = min(max_allowed_value, max_pct_value) # If already at or over the limit, clamp to zero if max_allowed_value <= 0: return order.model_copy(update={"quantity": 0.0, "estimated_value": 0.0}) # If the order already fits, return unchanged if order.estimated_value <= max_allowed_value: return order # Clamp: compute the maximum whole shares that fit clamped_shares = math.floor(max_allowed_value / price_per_share) # Also respect max_shares_per_order clamped_shares = min(clamped_shares, int(limits.max_shares_per_order)) if clamped_shares <= 0: return order.model_copy(update={"quantity": 0.0, "estimated_value": 0.0}) clamped_value = clamped_shares * price_per_share return order.model_copy(update={ "quantity": float(clamped_shares), "estimated_value": clamped_value, }) # 
# ---------------------------------------------------------------------------
# Individual risk checks (Requirement 8.4)
# ---------------------------------------------------------------------------

# Verb used in check messages when a sell is passed because it reduces exposure.
_SELL_REDUCES_EXPOSURE = "within (sell reduces exposure)"


def _pass_fail(ok: bool) -> tuple[RiskCheckResult, str]:
    """Map a limit comparison to its (result, message verb) pair."""
    if ok:
        return RiskCheckResult.PASS, "within"
    return RiskCheckResult.FAIL, "exceeds"


def _check_trading_mode(
    config: PortfolioRiskConfig,
) -> RiskCheckDetail:
    """Block all orders when trading is disabled.

    Returns FAIL when ``config.trading_mode`` is DISABLED, otherwise PASS
    with the active mode in the message.
    """
    if config.trading_mode == TradingMode.DISABLED:
        return RiskCheckDetail(
            check_name="trading_mode",
            result=RiskCheckResult.FAIL,
            message="Trading is disabled",
        )
    return RiskCheckDetail(
        check_name="trading_mode",
        result=RiskCheckResult.PASS,
        message=f"Trading mode: {config.trading_mode.value}",
    )


def _check_max_position_size(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> list[RiskCheckDetail]:
    """Enforce per-position size limits (value, percentage, shares).

    Returns three details in order: ``max_position_value``,
    ``max_position_pct`` and ``max_shares_per_order``. Sells always pass the
    value and percentage checks because they reduce exposure; the share cap
    applies to buys and sells alike.
    """
    limits = config.position_limits
    is_sell = order.action == "sell"

    existing_value = state.positions_by_symbol.get(order.ticker, 0.0)
    if is_sell:
        new_total_value = max(existing_value - order.estimated_value, 0.0)
    else:
        new_total_value = existing_value + order.estimated_value

    # A non-positive portfolio value is treated as 100% concentration for any
    # non-empty position, so buys fail closed when it is unknown.
    if state.portfolio_value > 0:
        position_pct = new_total_value / state.portfolio_value
    else:
        position_pct = 1.0 if new_total_value > 0 else 0.0

    # Blocking a sell on an over-concentrated position would prevent the
    # user from fixing it, so sells always pass these two checks.
    if is_sell:
        value_result, value_verb = RiskCheckResult.PASS, _SELL_REDUCES_EXPOSURE
        pct_result, pct_verb = RiskCheckResult.PASS, _SELL_REDUCES_EXPOSURE
    else:
        value_result, value_verb = _pass_fail(new_total_value <= limits.max_position_value)
        pct_result, pct_verb = _pass_fail(position_pct <= limits.max_position_pct)

    shares_result, shares_verb = _pass_fail(order.quantity <= limits.max_shares_per_order)

    return [
        RiskCheckDetail(
            check_name="max_position_value",
            result=value_result,
            message=(
                f"Position value {new_total_value:.2f} "
                f"{value_verb} "
                f"limit {limits.max_position_value:.2f}"
            ),
            threshold=limits.max_position_value,
            actual=new_total_value,
        ),
        RiskCheckDetail(
            check_name="max_position_pct",
            result=pct_result,
            message=(
                f"Position {position_pct:.4f} of portfolio "
                f"{pct_verb} "
                f"limit {limits.max_position_pct:.4f}"
            ),
            threshold=limits.max_position_pct,
            actual=position_pct,
        ),
        RiskCheckDetail(
            check_name="max_shares_per_order",
            result=shares_result,
            message=(
                f"Order quantity {order.quantity:.0f} "
                f"{shares_verb} "
                f"limit {limits.max_shares_per_order:.0f}"
            ),
            threshold=limits.max_shares_per_order,
            actual=order.quantity,
        ),
    ]


def _check_sector_exposure(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> RiskCheckDetail:
    """Enforce sector concentration limits.

    Returns WARN when the order carries no sector. Sells reduce the sector's
    exposure and always pass, mirroring ``_check_max_position_size``; buys
    fail when the resulting sector share of the portfolio exceeds
    ``max_sector_pct``.
    """
    limits = config.sector_exposure

    if not order.sector:
        return RiskCheckDetail(
            check_name="sector_exposure",
            result=RiskCheckResult.WARN,
            message="No sector provided on order; skipping sector check",
        )

    is_sell = order.action == "sell"
    existing_sector_value = state.positions_by_sector.get(order.sector, 0.0)
    if is_sell:
        new_sector_value = max(existing_sector_value - order.estimated_value, 0.0)
    else:
        new_sector_value = existing_sector_value + order.estimated_value

    if state.portfolio_value > 0:
        sector_pct = new_sector_value / state.portfolio_value
    else:
        sector_pct = 1.0 if new_sector_value > 0 else 0.0

    # Previously a sell's value was *added* to sector exposure, so a sell
    # out of an over-concentrated sector could be rejected.
    if is_sell:
        result, verb = RiskCheckResult.PASS, _SELL_REDUCES_EXPOSURE
    else:
        result, verb = _pass_fail(sector_pct <= limits.max_sector_pct)

    return RiskCheckDetail(
        check_name="sector_exposure",
        result=result,
        message=(
            f"Sector '{order.sector}' exposure {sector_pct:.4f} "
            f"{verb} "
            f"limit {limits.max_sector_pct:.4f}"
        ),
        threshold=limits.max_sector_pct,
        actual=sector_pct,
    )


def _check_daily_loss(
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> list[RiskCheckDetail]:
    """Enforce daily loss and trade count limits.

    Returns three details in order: ``daily_loss_pct``, ``daily_loss_value``
    and ``daily_trade_count``. Only negative ``daily_pnl`` counts as a loss;
    the percentage is reported as 0 when the portfolio value is not positive.
    """
    limits = config.daily_loss

    abs_loss = abs(min(state.daily_pnl, 0.0))
    loss_pct = abs_loss / state.portfolio_value if state.portfolio_value > 0 else 0.0

    pct_result, pct_verb = _pass_fail(loss_pct <= limits.max_daily_loss_pct)
    value_result, value_verb = _pass_fail(abs_loss <= limits.max_daily_loss_value)

    # Strict `<`: the order being evaluated would itself be one more trade.
    trades_ok = state.daily_trade_count < limits.max_daily_trades

    return [
        RiskCheckDetail(
            check_name="daily_loss_pct",
            result=pct_result,
            message=(
                f"Daily loss {loss_pct:.4f} "
                f"{pct_verb} "
                f"limit {limits.max_daily_loss_pct:.4f}"
            ),
            threshold=limits.max_daily_loss_pct,
            actual=loss_pct,
        ),
        RiskCheckDetail(
            check_name="daily_loss_value",
            result=value_result,
            message=(
                f"Daily loss ${abs_loss:.2f} "
                f"{value_verb} "
                f"limit ${limits.max_daily_loss_value:.2f}"
            ),
            threshold=limits.max_daily_loss_value,
            actual=abs_loss,
        ),
        RiskCheckDetail(
            check_name="daily_trade_count",
            result=RiskCheckResult.PASS if trades_ok else RiskCheckResult.FAIL,
            message=(
                f"Daily trades {state.daily_trade_count} "
                f"{'within' if trades_ok else 'at/exceeds'} "
                f"limit {limits.max_daily_trades}"
            ),
            threshold=float(limits.max_daily_trades),
            actual=float(state.daily_trade_count),
        ),
    ]


def _check_news_shock_lockout(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
    now: datetime | None = None,
) -> RiskCheckDetail:
    """Block trading on symbols under news-shock lockout.

    Fails while ``now`` is before the expiry stored for the ticker in
    ``state.locked_symbols``. This function only reads existing expiries; it
    does not create them.

    NOTE(review): comparing the aware ``now`` default with a timezone-naive
    expiry raises TypeError; callers presumably store aware UTC datetimes.
    """
    if not config.news_shock.enabled:
        return RiskCheckDetail(
            check_name="news_shock_lockout",
            result=RiskCheckResult.PASS,
            message="News-shock lockout is disabled",
        )

    now = now or datetime.now(timezone.utc)
    lockout_expiry = state.locked_symbols.get(order.ticker)

    if lockout_expiry is None or now >= lockout_expiry:
        return RiskCheckDetail(
            check_name="news_shock_lockout",
            result=RiskCheckResult.PASS,
            message=f"No active lockout for {order.ticker}",
        )

    remaining = lockout_expiry - now
    return RiskCheckDetail(
        check_name="news_shock_lockout",
        result=RiskCheckResult.FAIL,
        message=(
            f"Symbol {order.ticker} locked out until "
            f"{lockout_expiry.isoformat()} "
            f"({remaining.total_seconds():.0f}s remaining)"
        ),
    )


def _check_symbol_cooldown(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
    now: datetime | None = None,
) -> RiskCheckDetail:
    """Enforce per-symbol cooldown between trades.

    Fails while ``now`` is within ``cooldown_minutes`` of the ticker's last
    trade time in ``state.last_trade_times``.

    NOTE(review): comparing the aware ``now`` default with a timezone-naive
    trade time raises TypeError; callers presumably store aware UTC datetimes.
    """
    now = now or datetime.now(timezone.utc)
    last_trade = state.last_trade_times.get(order.ticker)

    if last_trade is not None:
        cooldown_end = last_trade + timedelta(
            minutes=config.symbol_cooldown.cooldown_minutes
        )
        if now < cooldown_end:
            remaining = cooldown_end - now
            return RiskCheckDetail(
                check_name="symbol_cooldown",
                result=RiskCheckResult.FAIL,
                message=(
                    f"Symbol {order.ticker} in cooldown until "
                    f"{cooldown_end.isoformat()} "
                    f"({remaining.total_seconds():.0f}s remaining)"
                ),
            )

    return RiskCheckDetail(
        check_name="symbol_cooldown",
        result=RiskCheckResult.PASS,
        message=f"No active cooldown for {order.ticker}",
    )


# ---------------------------------------------------------------------------
# Main evaluation entry point (Requirements 8.3, 8.4, 8.5)
# ---------------------------------------------------------------------------


def evaluate_order(
    order: ProposedOrder,
    config: PortfolioRiskConfig = DEFAULT_RISK_CONFIG,
    state: AccountRiskState | None = None,
    now: datetime | None = None,
) -> RiskEvaluation:
    """Evaluate a proposed order against all risk controls.

    Runs every hard-block check and returns a RiskEvaluation capturing the
    full decision trace (Requirement 8.3). If any check fails, the order is
    rejected before broker submission (Requirement 8.4).

    The engine fails closed: if ``state`` is not supplied, an
    ``account_state`` FAIL check is recorded and the order is rejected
    (Requirement 8.5). The remaining checks still run against an empty
    placeholder snapshot so the trace stays complete.

    Args:
        order: The order to evaluate.
        config: Risk configuration; defaults to the module-level default.
        state: Current account snapshot. ``None`` means "unknown".
        now: Evaluation time; defaults to the current UTC time.

    Returns:
        A RiskEvaluation whose ``rejection_reasons`` are the messages of all
        FAIL checks, in check order. ``allowed_mode`` is the configured
        trading mode when eligible, otherwise DISABLED.

    NOTE(review): ``sector_exposure.max_sectors`` and
    ``symbol_cooldown.max_open_positions_per_symbol`` are not read by any
    check here — confirm whether they are enforced elsewhere.
    """
    state_missing = state is None
    if state is None:
        state = AccountRiskState()
    now = now or datetime.now(timezone.utc)

    all_checks: list[RiskCheckDetail] = []

    # 0. Fail-closed gate. The placeholder snapshot has no positions and no
    # P&L, under which e.g. sell orders would otherwise pass every check.
    if state_missing:
        all_checks.append(RiskCheckDetail(
            check_name="account_state",
            result=RiskCheckResult.FAIL,
            message="Account risk state unavailable; failing closed",
        ))

    # 1. Trading mode gate
    all_checks.append(_check_trading_mode(config))
    # 2. Position size limits
    all_checks.extend(_check_max_position_size(order, config, state))
    # 3. Sector exposure
    all_checks.append(_check_sector_exposure(order, config, state))
    # 4. Daily loss limits
    all_checks.extend(_check_daily_loss(config, state))
    # 5. News-shock lockout
    all_checks.append(_check_news_shock_lockout(order, config, state, now))
    # 6. Symbol cooldown
    all_checks.append(_check_symbol_cooldown(order, config, state, now))

    rejection_reasons = [
        check.message for check in all_checks if check.result == RiskCheckResult.FAIL
    ]

    # Determine eligibility and allowed mode
    eligible = not rejection_reasons
    allowed_mode = config.trading_mode if eligible else TradingMode.DISABLED

    return RiskEvaluation(
        recommendation_id=order.recommendation_id,
        ticker=order.ticker,
        eligible=eligible,
        allowed_mode=allowed_mode,
        checks=all_checks,
        rejection_reasons=rejection_reasons,
        config_snapshot=config,
        state_snapshot=state,
        evaluated_at=now,
    )