Files

705 lines
24 KiB
Python

"""Risk engine - portfolio and account risk configuration and enforcement.
Defines the configuration and state models used to enforce guardrails
on trade execution: max position size, sector exposure, daily loss limits,
news-shock lockouts, and operator approval rules.
Also implements the hard-block evaluation logic that decides whether a
proposed order is allowed before it reaches the broker adapter.
Requirements: 8.1, 8.2, 8.3, 8.4, 8.5
Design: Section 4.8 - Risk Engine
"""
from __future__ import annotations
import math
import uuid
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class TradingMode(str, Enum):
    """Execution environment separation (Requirement 8.1).

    Members subclass ``str``, so each compares equal to its raw value
    (``TradingMode.PAPER == "paper"``) and serializes as that string.
    """

    PAPER = "paper"
    LIVE = "live"
    # The trading-mode check rejects every order while this mode is active.
    DISABLED = "disabled"
class RiskCheckResult(str, Enum):
    """Outcome of a single risk check.

    Only ``FAIL`` blocks an order; ``WARN`` is recorded in the decision
    trace without producing a rejection reason.
    """

    PASS = "pass"
    FAIL = "fail"
    WARN = "warn"
# ---------------------------------------------------------------------------
# Portfolio-level risk configuration (Requirement 8.2, 8.4)
# ---------------------------------------------------------------------------
class PositionLimits(BaseModel):
    """Per-position size constraints.

    Read by the position-size risk check and by
    ``clamp_order_to_position_limits`` when scaling buy orders down.
    """

    # Fraction (0..1) of portfolio value a single symbol's position may reach.
    max_position_pct: float = Field(
        0.05,
        ge=0,
        le=1,
        description="Maximum portfolio percentage for a single position",
    )
    # Absolute cap on a single symbol's position value.
    max_position_value: float = Field(
        10_000.0,
        ge=0,
        description="Maximum dollar value for a single position",
    )
    # Cap on the quantity of one order (applies to buys and sells).
    max_shares_per_order: float = Field(
        1000.0,
        ge=0,
        description="Maximum shares in a single order",
    )
class SectorExposureLimits(BaseModel):
    """Sector-level concentration limits.

    NOTE(review): within this module only ``max_sector_pct`` is read (by the
    sector-exposure check); ``max_sectors`` is not enforced here — confirm
    whether another component enforces it.
    """

    # Fraction (0..1) of portfolio value that one sector may reach.
    max_sector_pct: float = Field(
        0.25,
        ge=0,
        le=1,
        description="Maximum portfolio percentage exposed to one sector",
    )
    max_sectors: int = Field(
        10,
        ge=1,
        description="Maximum number of sectors with open positions",
    )
class DailyLossLimits(BaseModel):
    """Daily drawdown controls.

    All three limits are evaluated by the daily-loss risk check; breaching
    any of them blocks further orders for the day.
    """

    # Fraction (0..1) of portfolio value that may be lost in one day.
    max_daily_loss_pct: float = Field(
        0.02,
        ge=0,
        le=1,
        description="Maximum portfolio loss percentage in a single day before halting",
    )
    # Absolute loss allowed in one day.
    max_daily_loss_value: float = Field(
        1_000.0,
        ge=0,
        description="Maximum dollar loss in a single day before halting",
    )
    # The check requires today's count to be strictly below this value,
    # so 0 blocks every order.
    max_daily_trades: int = Field(
        20,
        ge=0,
        description="Maximum number of trades per day",
    )
class NewsShockLockout(BaseModel):
    """News-shock lockout configuration.

    When a symbol has a high-impact news event, trading is paused for a
    configurable cooldown period.

    Within this module only ``enabled`` is read: the lockout check consults
    the per-symbol expiries stored in ``AccountRiskState.locked_symbols``.
    NOTE(review): the other fields presumably drive whatever component
    populates those expiries — confirm with that producer.
    """

    enabled: bool = True
    lockout_minutes: int = Field(
        60,
        ge=0,
        description="Minutes to lock out trading after a high-impact news event",
    )
    # Fraction in 0..1 compared against a document's impact_score.
    impact_threshold: float = Field(
        0.80,
        ge=0,
        le=1,
        description="Minimum impact_score from document intelligence to trigger lockout",
    )
    # A fresh list per instance (default_factory), never shared between configs.
    catalyst_types: list[str] = Field(
        default_factory=lambda: ["earnings", "legal", "m_and_a"],
        description="Catalyst types that trigger lockout when above threshold",
    )
class OperatorApproval(BaseModel):
    """Operator approval workflow for live trading (Requirement 8.2).

    NOTE(review): none of these fields is read by the risk checks in this
    module; presumably the order-submission layer applies them — confirm.
    """

    require_approval_for_live: bool = Field(
        True,
        description="Whether live orders require operator approval",
    )
    auto_approve_paper: bool = Field(
        True,
        description="Whether paper orders are auto-approved",
    )
    approval_timeout_minutes: int = Field(
        30,
        ge=1,
        description="Minutes before a pending approval expires",
    )
class SymbolCooldown(BaseModel):
    """Per-symbol cooldown after a trade.

    ``cooldown_minutes`` is read by the symbol-cooldown risk check.
    NOTE(review): ``max_open_positions_per_symbol`` is not read by any check
    in this module — confirm whether it is enforced elsewhere.
    """

    # 0 disables the cooldown (the window ends at the last trade time).
    cooldown_minutes: int = Field(
        15,
        ge=0,
        description="Minutes to wait before trading the same symbol again",
    )
    max_open_positions_per_symbol: int = Field(
        1,
        ge=1,
        description="Maximum concurrent open positions for a single symbol",
    )
class PortfolioRiskConfig(BaseModel):
    """Complete portfolio-level risk configuration.

    The top-level config that governs every risk check. Per the design it
    is persisted in PostgreSQL and loaded at engine startup; the
    ``to_db_json`` / ``from_db_json`` pair performs that JSON round-trip.
    """

    # A new random UUID string per instance unless one is supplied.
    config_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str = "default"
    trading_mode: TradingMode = TradingMode.PAPER
    # Each sub-config is a fresh default instance, never shared between configs.
    position_limits: PositionLimits = Field(default_factory=PositionLimits)
    sector_exposure: SectorExposureLimits = Field(default_factory=SectorExposureLimits)
    daily_loss: DailyLossLimits = Field(default_factory=DailyLossLimits)
    news_shock: NewsShockLockout = Field(default_factory=NewsShockLockout)
    operator_approval: OperatorApproval = Field(default_factory=OperatorApproval)
    symbol_cooldown: SymbolCooldown = Field(default_factory=SymbolCooldown)
    active: bool = True
    # Timezone-aware UTC timestamps taken at construction time.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    def to_db_json(self) -> dict[str, Any]:
        """Return the whole config as a JSON-compatible dict for DB storage.

        ``mode="json"`` turns enums, datetimes, etc. into plain JSON values.
        """
        payload = self.model_dump(mode="json")
        return payload

    @classmethod
    def from_db_json(cls, data: dict[str, Any]) -> PortfolioRiskConfig:
        """Rebuild (and validate) a config from a DB JSON column's dict."""
        return cls.model_validate(data)
# ---------------------------------------------------------------------------
# Account risk state (runtime snapshot)
# ---------------------------------------------------------------------------
class AccountRiskState(BaseModel):
    """Runtime snapshot of an account's risk posture.

    Computed by the caller from broker positions, today's trades, and
    current P&L, then passed to the risk checks, which read it without
    modifying it.
    """

    account_id: str = ""
    # Percentage checks divide by this only when it is positive.
    portfolio_value: float = 0.0
    cash: float = 0.0
    buying_power: float = 0.0
    # Negative values are losses; gains never count toward loss limits.
    daily_pnl: float = 0.0
    daily_trade_count: int = 0
    open_position_count: int = 0
    positions_by_symbol: dict[str, float] = Field(
        default_factory=dict,
        description="Map of ticker → current market value",
    )
    positions_by_sector: dict[str, float] = Field(
        default_factory=dict,
        description="Map of sector → total market value",
    )
    # Compared against the evaluation time; keep these timezone-aware
    # to match the UTC default of `now` in the checks.
    last_trade_times: dict[str, datetime] = Field(
        default_factory=dict,
        description="Map of ticker → last trade timestamp for cooldown checks",
    )
    locked_symbols: dict[str, datetime] = Field(
        default_factory=dict,
        description="Map of ticker → lockout expiry for news-shock lockouts",
    )
    snapshot_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
# ---------------------------------------------------------------------------
# Risk check output (Requirement 8.3 - full decision trace)
# ---------------------------------------------------------------------------
class RiskCheckDetail(BaseModel):
    """Outcome of one named risk check, as recorded in the decision trace.

    ``threshold`` and ``actual`` are populated by the numeric checks and
    remain ``None`` for gate-style checks (trading mode, lockout, cooldown).
    """

    check_name: str
    result: RiskCheckResult
    # Human-readable explanation; messages of FAIL results become the
    # evaluation's rejection reasons.
    message: str = ""
    threshold: float | None = None
    actual: float | None = None
class RiskEvaluation(BaseModel):
    """Complete risk evaluation for a proposed order.

    Records every check that ran, the messages of the failing ones, and the
    config/state the decision was made against, so the full decision trace
    is reproducible (Requirement 8.3).
    """

    evaluation_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    recommendation_id: str | None = None
    ticker: str = ""
    eligible: bool = False
    # Stays DISABLED unless the evaluation marks the order eligible.
    allowed_mode: TradingMode = TradingMode.DISABLED
    checks: list[RiskCheckDetail] = Field(default_factory=list)
    rejection_reasons: list[str] = Field(default_factory=list)
    config_snapshot: PortfolioRiskConfig | None = None
    state_snapshot: AccountRiskState | None = None
    evaluated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    @property
    def passed(self) -> bool:
        """Return True only if the order is eligible and nothing rejected it."""
        return self.eligible and not self.rejection_reasons
# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------
# Module-wide default configuration. It is also the default `config`
# argument of evaluate_order, so a single instance is shared by every call
# that omits `config` — treat it as read-only.
DEFAULT_RISK_CONFIG: PortfolioRiskConfig = PortfolioRiskConfig()
# ---------------------------------------------------------------------------
# Proposed order (input to risk evaluation)
# ---------------------------------------------------------------------------
class ProposedOrder(BaseModel):
    """A proposed order to be evaluated by the risk engine before submission.

    The input to ``evaluate_order()``: it carries enough context for every
    risk check to run without external lookups.
    """

    recommendation_id: str | None = None
    ticker: str
    # An empty sector makes the sector check emit WARN instead of evaluating.
    sector: str = ""
    # The checks special-case only the exact string "sell"; any other
    # value is treated as a buy.
    action: str = "buy"  # buy | sell
    quantity: float = 0.0
    # Notional of the order; clamping derives the per-share price as
    # estimated_value / quantity.
    estimated_value: float = 0.0
    # NOTE(review): not read by any risk check in this module.
    confidence: float = 0.0
# ---------------------------------------------------------------------------
# Order clamping — auto-scale to fit within position limits
# ---------------------------------------------------------------------------
def clamp_order_to_position_limits(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> ProposedOrder:
    """Clamp a buy order's quantity/value to fit within position limits.

    Instead of hard-rejecting orders that exceed ``max_position_pct``,
    ``max_position_value`` or ``max_shares_per_order``, this computes the
    largest whole-share order satisfying all three limits and returns a new
    ProposedOrder scaled down to fit.

    - Sell orders, and orders with non-positive quantity, are returned
      unchanged (sells reduce exposure).
    - Orders without a usable price (``estimated_value`` not positive) are
      returned unchanged; the risk checks decide them.
    - Orders that already fit every limit are returned unchanged, so
      fractional quantities are preserved.
    - If the clamped quantity rounds to zero (or no headroom remains), the
      order is returned with ``quantity=0`` and ``estimated_value=0`` so the
      caller can reject it.

    Args:
        order: The proposed order.
        config: Risk configuration supplying ``position_limits``.
        state: Account snapshot supplying existing position value and
            portfolio value.

    Returns:
        The original order, or a copy with reduced quantity/value.
    """
    if order.action == "sell" or order.quantity <= 0:
        return order

    # `not > 0` (rather than `<= 0`) also catches a NaN estimate.
    if not order.estimated_value > 0:
        return order  # Can't clamp without a price; let risk checks handle it
    price_per_share = order.estimated_value / order.quantity

    limits = config.position_limits
    existing_value = state.positions_by_symbol.get(order.ticker, 0.0)

    # Headroom left under the dollar cap...
    max_allowed_value = limits.max_position_value - existing_value
    # ...and under the percentage cap, when portfolio value is known.
    if state.portfolio_value > 0:
        max_pct_value = limits.max_position_pct * state.portfolio_value - existing_value
        max_allowed_value = min(max_allowed_value, max_pct_value)

    zeroed = {"quantity": 0.0, "estimated_value": 0.0}

    # Already at/over the limit (or headroom is NaN): nothing can be added.
    if not max_allowed_value > 0:
        return order.model_copy(update=zeroed)

    fits_value = order.estimated_value <= max_allowed_value
    fits_shares = order.quantity <= limits.max_shares_per_order
    if fits_value and fits_shares:
        return order

    # Largest whole-share count satisfying both the value headroom and the
    # per-order share cap. It is always below the original quantity, so the
    # order is never enlarged.
    clamped_shares = min(
        math.floor(max_allowed_value / price_per_share),
        int(limits.max_shares_per_order),
    )
    if clamped_shares <= 0:
        return order.model_copy(update=zeroed)

    return order.model_copy(update={
        "quantity": float(clamped_shares),
        "estimated_value": clamped_shares * price_per_share,
    })
# ---------------------------------------------------------------------------
# Individual risk checks (Requirement 8.4)
# ---------------------------------------------------------------------------
def _check_trading_mode(
    config: PortfolioRiskConfig,
) -> RiskCheckDetail:
    """Block all orders when trading is disabled.

    Returns FAIL when ``config.trading_mode`` is DISABLED; otherwise returns
    PASS with a message naming the active mode.
    """
    mode = config.trading_mode
    is_disabled = mode == TradingMode.DISABLED
    return RiskCheckDetail(
        check_name="trading_mode",
        result=RiskCheckResult.FAIL if is_disabled else RiskCheckResult.PASS,
        message="Trading is disabled" if is_disabled else f"Trading mode: {mode.value}",
    )
def _check_max_position_size(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> list[RiskCheckDetail]:
    """Enforce per-position size limits (value, percentage, shares).

    Produces three details, in this order:

    1. ``max_position_value``: projected position value vs. the dollar cap.
    2. ``max_position_pct``: projected position value as a fraction of
       portfolio value vs. the percentage cap. When portfolio value is not
       positive, any non-zero projected position counts as 100%.
    3. ``max_shares_per_order``: order quantity vs. the share cap. This one
       applies to buys and sells alike.

    The projected position is ``existing + estimated_value`` for buys and
    ``max(existing - estimated_value, 0)`` for sells. Sells always pass the
    value and percentage checks: they reduce exposure, and blocking them
    would prevent an over-concentrated position from being trimmed.
    """
    limits = config.position_limits
    is_sell = order.action == "sell"
    held = state.positions_by_symbol.get(order.ticker, 0.0)
    projected = (
        max(held - order.estimated_value, 0.0)
        if is_sell
        else held + order.estimated_value
    )

    def _grade(within_limit: bool) -> tuple[RiskCheckResult, str]:
        # Sells bypass the exposure limits; buys are graded on the comparison.
        if is_sell:
            return RiskCheckResult.PASS, "within (sell reduces exposure)"
        if within_limit:
            return RiskCheckResult.PASS, "within"
        return RiskCheckResult.FAIL, "exceeds"

    value_result, value_verb = _grade(projected <= limits.max_position_value)

    if state.portfolio_value > 0:
        portfolio_share = projected / state.portfolio_value
    else:
        # Unknown portfolio size: treat any exposure as the whole portfolio.
        portfolio_share = 1.0 if projected > 0 else 0.0
    pct_result, pct_verb = _grade(portfolio_share <= limits.max_position_pct)

    quantity_ok = order.quantity <= limits.max_shares_per_order

    return [
        RiskCheckDetail(
            check_name="max_position_value",
            result=value_result,
            message=(
                f"Position value {projected:.2f} "
                f"{value_verb} "
                f"limit {limits.max_position_value:.2f}"
            ),
            threshold=limits.max_position_value,
            actual=projected,
        ),
        RiskCheckDetail(
            check_name="max_position_pct",
            result=pct_result,
            message=(
                f"Position {portfolio_share:.4f} of portfolio "
                f"{pct_verb} "
                f"limit {limits.max_position_pct:.4f}"
            ),
            threshold=limits.max_position_pct,
            actual=portfolio_share,
        ),
        RiskCheckDetail(
            check_name="max_shares_per_order",
            result=RiskCheckResult.PASS if quantity_ok else RiskCheckResult.FAIL,
            message=(
                f"Order quantity {order.quantity:.0f} "
                f"{'within' if quantity_ok else 'exceeds'} "
                f"limit {limits.max_shares_per_order:.0f}"
            ),
            threshold=limits.max_shares_per_order,
            actual=order.quantity,
        ),
    ]
def _check_sector_exposure(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> RiskCheckDetail:
    """Enforce sector concentration limits.

    Projects the sector's market value after the order and compares it, as
    a fraction of portfolio value, against ``max_sector_pct``.

    Buys add ``estimated_value`` to the sector. Sells subtract it (floored
    at zero) and always pass, because a sell reduces sector exposure. Treating
    it as added exposure could block the very order that fixes an
    over-concentrated sector. This matches how ``_check_max_position_size``
    treats sells.

    Returns a WARN detail (which does not reject the order) when the order
    carries no sector. When portfolio value is not positive, any non-zero
    projected exposure counts as 100% so the check fails closed.
    """
    limits = config.sector_exposure
    if not order.sector:
        return RiskCheckDetail(
            check_name="sector_exposure",
            result=RiskCheckResult.WARN,
            message="No sector provided on order; skipping sector check",
        )

    is_sell = order.action == "sell"
    existing_sector_value = state.positions_by_sector.get(order.sector, 0.0)
    if is_sell:
        new_sector_value = max(existing_sector_value - order.estimated_value, 0.0)
    else:
        new_sector_value = existing_sector_value + order.estimated_value

    if state.portfolio_value > 0:
        sector_pct = new_sector_value / state.portfolio_value
    else:
        sector_pct = 1.0 if new_sector_value > 0 else 0.0

    if is_sell:
        result, verb = RiskCheckResult.PASS, "within (sell reduces exposure)"
    elif sector_pct <= limits.max_sector_pct:
        result, verb = RiskCheckResult.PASS, "within"
    else:
        result, verb = RiskCheckResult.FAIL, "exceeds"

    return RiskCheckDetail(
        check_name="sector_exposure",
        result=result,
        message=(
            f"Sector '{order.sector}' exposure {sector_pct:.4f} "
            f"{verb} "
            f"limit {limits.max_sector_pct:.4f}"
        ),
        threshold=limits.max_sector_pct,
        actual=sector_pct,
    )
def _check_daily_loss(
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> list[RiskCheckDetail]:
    """Enforce daily loss and trade count limits.

    Produces three details, in this order:

    1. ``daily_loss_pct``: today's loss as a fraction of portfolio value.
    2. ``daily_loss_value``: today's loss as an absolute amount.
    3. ``daily_trade_count``: trades already made today. It fails once the
       count reaches ``max_daily_trades``, so a limit of 0 blocks all orders.

    Only losses count; a non-negative ``daily_pnl`` contributes zero. When
    portfolio value is not positive the percentage cannot be computed, so
    any loss is treated as 100% of the portfolio. The check therefore fails
    closed, consistent with the position and sector percentage checks.
    """
    limits = config.daily_loss

    # Losses are negative P&L: clip gains to zero and take the magnitude.
    abs_loss = abs(min(state.daily_pnl, 0.0))

    if state.portfolio_value > 0:
        loss_pct = abs_loss / state.portfolio_value
    else:
        loss_pct = 1.0 if abs_loss > 0 else 0.0

    pct_ok = loss_pct <= limits.max_daily_loss_pct
    value_ok = abs_loss <= limits.max_daily_loss_value
    count_ok = state.daily_trade_count < limits.max_daily_trades

    return [
        RiskCheckDetail(
            check_name="daily_loss_pct",
            result=RiskCheckResult.PASS if pct_ok else RiskCheckResult.FAIL,
            message=(
                f"Daily loss {loss_pct:.4f} "
                f"{'within' if pct_ok else 'exceeds'} "
                f"limit {limits.max_daily_loss_pct:.4f}"
            ),
            threshold=limits.max_daily_loss_pct,
            actual=loss_pct,
        ),
        RiskCheckDetail(
            check_name="daily_loss_value",
            result=RiskCheckResult.PASS if value_ok else RiskCheckResult.FAIL,
            message=(
                f"Daily loss ${abs_loss:.2f} "
                f"{'within' if value_ok else 'exceeds'} "
                f"limit ${limits.max_daily_loss_value:.2f}"
            ),
            threshold=limits.max_daily_loss_value,
            actual=abs_loss,
        ),
        RiskCheckDetail(
            check_name="daily_trade_count",
            result=RiskCheckResult.PASS if count_ok else RiskCheckResult.FAIL,
            message=(
                f"Daily trades {state.daily_trade_count} "
                f"{'within' if count_ok else 'at/exceeds'} "
                f"limit {limits.max_daily_trades}"
            ),
            threshold=float(limits.max_daily_trades),
            actual=float(state.daily_trade_count),
        ),
    ]
def _check_news_shock_lockout(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
    now: datetime | None = None,
) -> RiskCheckDetail:
    """Block trading on symbols under news-shock lockout.

    The per-ticker lockout expiry is read from ``state.locked_symbols``. The
    order is blocked while the evaluation time is strictly before that
    expiry. When the feature is disabled in config, the check passes
    without consulting state.

    Args:
        now: Evaluation time; defaults to the current UTC time. It must be
            comparable with the stored expiry (both timezone-aware or both
            naive), otherwise the ordering comparison raises TypeError.
    """
    if not config.news_shock.enabled:
        return RiskCheckDetail(
            check_name="news_shock_lockout",
            result=RiskCheckResult.PASS,
            message="News-shock lockout is disabled",
        )

    current_time = now or datetime.now(timezone.utc)
    expires_at = state.locked_symbols.get(order.ticker)
    if expires_at is None or not current_time < expires_at:
        return RiskCheckDetail(
            check_name="news_shock_lockout",
            result=RiskCheckResult.PASS,
            message=f"No active lockout for {order.ticker}",
        )

    seconds_left = (expires_at - current_time).total_seconds()
    return RiskCheckDetail(
        check_name="news_shock_lockout",
        result=RiskCheckResult.FAIL,
        message=(
            f"Symbol {order.ticker} locked out until "
            f"{expires_at.isoformat()} "
            f"({seconds_left:.0f}s remaining)"
        ),
    )
def _check_symbol_cooldown(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
    now: datetime | None = None,
) -> RiskCheckDetail:
    """Enforce per-symbol cooldown between trades.

    The cooldown window runs from the ticker's last trade time (from
    ``state.last_trade_times``) for ``cooldown_minutes``. The order is
    blocked while the evaluation time is strictly inside that window.
    Tickers with no recorded trade always pass.

    Args:
        now: Evaluation time; defaults to the current UTC time. It must be
            comparable with the stored trade timestamps (both aware or both
            naive).
    """
    current_time = now or datetime.now(timezone.utc)
    previous_trade = state.last_trade_times.get(order.ticker)
    window_end = (
        None
        if previous_trade is None
        else previous_trade + timedelta(minutes=config.symbol_cooldown.cooldown_minutes)
    )

    if window_end is None or not current_time < window_end:
        return RiskCheckDetail(
            check_name="symbol_cooldown",
            result=RiskCheckResult.PASS,
            message=f"No active cooldown for {order.ticker}",
        )

    seconds_left = (window_end - current_time).total_seconds()
    return RiskCheckDetail(
        check_name="symbol_cooldown",
        result=RiskCheckResult.FAIL,
        message=(
            f"Symbol {order.ticker} in cooldown until "
            f"{window_end.isoformat()} "
            f"({seconds_left:.0f}s remaining)"
        ),
    )
# ---------------------------------------------------------------------------
# Main evaluation entry point (Requirements 8.3, 8.4, 8.5)
# ---------------------------------------------------------------------------
def evaluate_order(
    order: ProposedOrder,
    config: PortfolioRiskConfig = DEFAULT_RISK_CONFIG,
    state: AccountRiskState | None = None,
    now: datetime | None = None,
) -> RiskEvaluation:
    """Evaluate a proposed order against all risk controls.

    Runs every hard-block check and returns a RiskEvaluation capturing the
    full decision trace (Requirement 8.3). If any check fails, the order is
    rejected before broker submission (Requirement 8.4).

    The engine fails closed (Requirement 8.5). When ``state`` is missing,
    an ``account_state`` FAIL is recorded and the order is rejected. The
    remaining checks still run against an empty snapshot so the trace
    stays complete.

    Args:
        order: The order to evaluate.
        config: Risk configuration; defaults to the shared
            ``DEFAULT_RISK_CONFIG`` instance.
        state: Current account risk snapshot. ``None`` means unknown and
            causes rejection.
        now: Evaluation time used by the time-based checks and stamped on
            the result; defaults to the current UTC time.

    Returns:
        A RiskEvaluation whose ``checks`` list every check in evaluation
        order and whose ``rejection_reasons`` are the messages of the FAIL
        checks. The config/state snapshots are deep copies, so later
        mutation of the live objects cannot rewrite the recorded trace.
    """
    state_missing = state is None
    if state is None:
        state = AccountRiskState()
    now = now or datetime.now(timezone.utc)

    checks: list[RiskCheckDetail] = []
    if state_missing:
        checks.append(RiskCheckDetail(
            check_name="account_state",
            result=RiskCheckResult.FAIL,
            message="Account risk state not provided; failing closed",
        ))
    # Order matters only for the trace; every check always runs.
    checks.append(_check_trading_mode(config))
    checks.extend(_check_max_position_size(order, config, state))
    checks.append(_check_sector_exposure(order, config, state))
    checks.extend(_check_daily_loss(config, state))
    checks.append(_check_news_shock_lockout(order, config, state, now))
    checks.append(_check_symbol_cooldown(order, config, state, now))

    # Only FAIL blocks; WARN results are kept in the trace but not rejected.
    rejection_reasons = [c.message for c in checks if c.result == RiskCheckResult.FAIL]
    eligible = not rejection_reasons
    allowed_mode = config.trading_mode if eligible else TradingMode.DISABLED

    return RiskEvaluation(
        recommendation_id=order.recommendation_id,
        ticker=order.ticker,
        eligible=eligible,
        allowed_mode=allowed_mode,
        checks=checks,
        rejection_reasons=rejection_reasons,
        # Deep copies: the trace must not alias the live (mutable) config,
        # which by default is the module-wide shared instance.
        config_snapshot=config.model_copy(deep=True),
        state_snapshot=state.model_copy(deep=True),
        evaluated_at=now,
    )