Files
stonks-oracle/services/risk/engine.py
T

617 lines
20 KiB
Python

"""Risk engine - portfolio and account risk configuration and enforcement.
Defines the configuration and state models used to enforce guardrails
on trade execution: max position size, sector exposure, daily loss limits,
news-shock lockouts, and operator approval rules.
Also implements the hard-block evaluation logic that decides whether a
proposed order is allowed before it reaches the broker adapter.
Requirements: 8.1, 8.2, 8.3, 8.4, 8.5
Design: Section 4.8 - Risk Engine
"""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class TradingMode(str, Enum):
    """Trading environment selector (Requirement 8.1).

    Because the enum mixes in ``str``, each member compares equal to its
    lowercase string value. ``DISABLED`` is the mode under which the risk
    engine rejects every order.
    """

    PAPER = "paper"
    LIVE = "live"
    DISABLED = "disabled"
class RiskCheckResult(str, Enum):
    """Verdict produced by one individual risk check.

    Only ``FAIL`` blocks an order; ``WARN`` is recorded in the trace
    without rejecting.
    """

    PASS = "pass"
    FAIL = "fail"
    WARN = "warn"
# ---------------------------------------------------------------------------
# Portfolio-level risk configuration (Requirement 8.2, 8.4)
# ---------------------------------------------------------------------------
class PositionLimits(BaseModel):
    """Size caps applied to a single position and to each individual order."""

    # Largest fraction (0..1) of total portfolio value one ticker may reach.
    max_position_pct: float = Field(
        0.05,
        ge=0,
        le=1,
        description="Maximum portfolio percentage for a single position",
    )
    # Absolute dollar ceiling for a single ticker's position.
    max_position_value: float = Field(
        10000.0,
        ge=0,
        description="Maximum dollar value for a single position",
    )
    # Cap on the share quantity carried by one order (not the total position).
    max_shares_per_order: float = Field(
        1000.0,
        ge=0,
        description="Maximum shares in a single order",
    )
class SectorExposureLimits(BaseModel):
    """Caps on how concentrated the portfolio may be within sectors."""

    # Fraction (0..1) of portfolio value permitted in any one sector.
    max_sector_pct: float = Field(
        0.25,
        ge=0,
        le=1,
        description="Maximum portfolio percentage exposed to one sector",
    )
    # NOTE(review): no check in this module reads max_sectors — confirm
    # whether it is enforced elsewhere.
    max_sectors: int = Field(
        10,
        ge=1,
        description="Maximum number of sectors with open positions",
    )
class DailyLossLimits(BaseModel):
    """Intraday drawdown and trade-frequency limits."""

    # Largest tolerated same-day loss as a fraction (0..1) of portfolio value.
    max_daily_loss_pct: float = Field(
        0.02,
        ge=0,
        le=1,
        description="Maximum portfolio loss percentage in a single day before halting",
    )
    # Largest tolerated same-day loss in dollars.
    max_daily_loss_value: float = Field(
        1000.0,
        ge=0,
        description="Maximum dollar loss in a single day before halting",
    )
    # Number of trades allowed per day; 0 blocks all trading.
    max_daily_trades: int = Field(
        20,
        ge=0,
        description="Maximum number of trades per day",
    )
class NewsShockLockout(BaseModel):
    """Settings for pausing trading in a symbol after high-impact news.

    When a symbol has a high-impact news event, trading in it is paused
    for a configurable cooldown period.
    """

    # Master switch; when False the lockout check always passes.
    enabled: bool = True
    # NOTE(review): the lockout check in this module does not read the three
    # fields below; it only consults the precomputed expiry times in
    # AccountRiskState.locked_symbols. Presumably whatever fills that map
    # uses them — confirm.
    lockout_minutes: int = Field(
        60,
        ge=0,
        description="Minutes to lock out trading after a high-impact news event",
    )
    impact_threshold: float = Field(
        0.80,
        ge=0,
        le=1,
        description="Minimum impact_score from document intelligence to trigger lockout",
    )
    # A fresh list per instance (default_factory) so instances never share it.
    catalyst_types: list[str] = Field(
        description="Catalyst types that trigger lockout when above threshold",
        default_factory=lambda: ["earnings", "legal", "m_and_a"],
    )
class OperatorApproval(BaseModel):
    """Operator sign-off settings for order execution (Requirement 8.2).

    NOTE(review): evaluate_order in this module does not read these fields;
    presumably the approval workflow lives elsewhere — confirm.
    """

    require_approval_for_live: bool = Field(
        True,
        description="Whether live orders require operator approval",
    )
    auto_approve_paper: bool = Field(
        True,
        description="Whether paper orders are auto-approved",
    )
    # Must be at least one minute.
    approval_timeout_minutes: int = Field(
        30,
        ge=1,
        description="Minutes before a pending approval expires",
    )
class SymbolCooldown(BaseModel):
    """Per-symbol re-trade spacing and open-position limits."""

    # Minutes after a symbol's last trade during which it may not trade again;
    # 0 disables the cooldown.
    cooldown_minutes: int = Field(
        15,
        ge=0,
        description="Minutes to wait before trading the same symbol again",
    )
    # NOTE(review): not read by any check in this module — confirm enforcement.
    max_open_positions_per_symbol: int = Field(
        1,
        ge=1,
        description="Maximum concurrent open positions for a single symbol",
    )
class PortfolioRiskConfig(BaseModel):
    """Top-level risk configuration consulted by every risk check.

    Bundles the per-area limit models with the trading mode and some
    bookkeeping metadata. Persisted in PostgreSQL and loaded at engine
    startup; ``to_db_json`` / ``from_db_json`` handle the JSON round trip.
    """

    # Random UUID4 string generated per instance.
    config_id: str = Field(default_factory=lambda: f"{uuid.uuid4()}")
    name: str = "default"
    trading_mode: TradingMode = TradingMode.PAPER

    # Each sub-config gets its own fresh default instance.
    position_limits: PositionLimits = Field(default_factory=PositionLimits)
    sector_exposure: SectorExposureLimits = Field(default_factory=SectorExposureLimits)
    daily_loss: DailyLossLimits = Field(default_factory=DailyLossLimits)
    news_shock: NewsShockLockout = Field(default_factory=NewsShockLockout)
    operator_approval: OperatorApproval = Field(default_factory=OperatorApproval)
    symbol_cooldown: SymbolCooldown = Field(default_factory=SymbolCooldown)

    active: bool = True
    # Timezone-aware UTC timestamps taken at construction time.
    created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))

    def to_db_json(self) -> dict[str, Any]:
        """Return every field as a JSON-compatible dict (enums/datetimes as strings)."""
        payload = self.model_dump(mode="json")
        return payload

    @classmethod
    def from_db_json(cls, data: dict[str, Any]) -> PortfolioRiskConfig:
        """Validate *data* (e.g. the output of ``to_db_json``) into a config."""
        return cls.model_validate(data)
# ---------------------------------------------------------------------------
# Account risk state (runtime snapshot)
# ---------------------------------------------------------------------------
class AccountRiskState(BaseModel):
    """Point-in-time view of an account's risk posture.

    Computed from broker positions, today's trades, and current P&L, and
    handed to the risk checks to decide whether a new order is allowed.
    """

    account_id: str = ""

    # Balances. The limit checks read portfolio_value; cash and buying_power
    # are carried for context.
    portfolio_value: float = 0.0
    cash: float = 0.0
    buying_power: float = 0.0

    # Intraday activity consumed by the daily-loss checks.
    daily_pnl: float = 0.0
    daily_trade_count: int = 0
    open_position_count: int = 0

    # Exposure maps consumed by the position and sector checks.
    positions_by_symbol: dict[str, float] = Field(
        description="Map of ticker → current market value",
        default_factory=dict,
    )
    positions_by_sector: dict[str, float] = Field(
        description="Map of sector → total market value",
        default_factory=dict,
    )

    # Time-based blocks consumed by the cooldown and news-shock checks.
    last_trade_times: dict[str, datetime] = Field(
        description="Map of ticker → last trade timestamp for cooldown checks",
        default_factory=dict,
    )
    locked_symbols: dict[str, datetime] = Field(
        description="Map of ticker → lockout expiry for news-shock lockouts",
        default_factory=dict,
    )

    snapshot_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
# ---------------------------------------------------------------------------
# Risk check output (Requirement 8.3 - full decision trace)
# ---------------------------------------------------------------------------
class RiskCheckDetail(BaseModel):
    """Outcome of one named risk check, for the decision trace."""

    check_name: str
    result: RiskCheckResult
    # Human-readable explanation; FAIL messages become rejection reasons.
    message: str = ""
    # Limit and observed value for numeric checks; None for non-numeric ones.
    threshold: float | None = None
    actual: float | None = None
class RiskEvaluation(BaseModel):
    """Full risk verdict for one proposed order.

    Records every check that ran, together with the config and state it
    ran against, so the decision trace is reproducible (Requirement 8.3).
    """

    evaluation_id: str = Field(default_factory=lambda: f"{uuid.uuid4()}")
    recommendation_id: str | None = None
    ticker: str = ""

    # Defaults are the closed position: not eligible, trading disabled.
    eligible: bool = False
    allowed_mode: TradingMode = TradingMode.DISABLED

    checks: list[RiskCheckDetail] = Field(default_factory=list)
    rejection_reasons: list[str] = Field(default_factory=list)

    config_snapshot: PortfolioRiskConfig | None = None
    state_snapshot: AccountRiskState | None = None
    evaluated_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))

    @property
    def passed(self) -> bool:
        """True only when the order is eligible and no rejection was recorded."""
        if not self.eligible:
            return False
        return not self.rejection_reasons
# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------
# Module-wide fallback config, used as evaluate_order()'s default argument.
# NOTE: this is a single shared instance; mutating it affects every later
# call that relies on the default.
DEFAULT_RISK_CONFIG: PortfolioRiskConfig = PortfolioRiskConfig()
# ---------------------------------------------------------------------------
# Proposed order (input to risk evaluation)
# ---------------------------------------------------------------------------
class ProposedOrder(BaseModel):
    """Order candidate submitted to the risk engine before broker routing.

    Input to evaluate_order(). It carries enough context for every risk
    check to run without external lookups.
    """

    recommendation_id: str | None = None
    ticker: str  # required
    # Empty string means "unknown"; the sector check then only warns.
    sector: str = ""
    action: str = "buy"  # expected values: "buy" or "sell"
    quantity: float = 0.0
    # Dollar notional of the order, compared against position/sector limits.
    estimated_value: float = 0.0
    confidence: float = 0.0
# ---------------------------------------------------------------------------
# Individual risk checks (Requirement 8.4)
# ---------------------------------------------------------------------------
def _check_trading_mode(
    config: PortfolioRiskConfig,
) -> RiskCheckDetail:
    """Gate every order on the configured trading mode.

    Returns a FAIL detail when ``config.trading_mode`` is DISABLED, and
    otherwise a PASS detail naming the active mode.
    """
    mode = config.trading_mode
    is_disabled = mode == TradingMode.DISABLED
    return RiskCheckDetail(
        check_name="trading_mode",
        result=RiskCheckResult.FAIL if is_disabled else RiskCheckResult.PASS,
        message=(
            "Trading is disabled"
            if is_disabled
            else f"Trading mode: {mode.value}"
        ),
    )
def _check_max_position_size(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> list[RiskCheckDetail]:
    """Enforce per-position size limits (value, percentage, shares).

    The position is evaluated as it would stand *after* the order fills.
    A ``"sell"`` (case-insensitive) subtracts ``order.estimated_value`` from
    the ticker's existing market value; any other action adds it. The
    projected exposure is the absolute value of that result. So trimming a
    long position is never treated as growing it, while a sell larger than
    the holding is sized as the resulting short.

    If ``state.portfolio_value`` is not positive, any non-zero projected
    exposure counts as 100% of the portfolio, so the percentage check fails
    closed.

    Args:
        order: The proposed order.
        config: Supplies ``position_limits``.
        state: Supplies ``positions_by_symbol`` and ``portfolio_value``.

    Returns:
        Three checks, in order: ``max_position_value``, ``max_position_pct``
        and ``max_shares_per_order``.
    """
    limits = config.position_limits

    # Project post-fill exposure. Previously a sell was *added* to the
    # position, so de-risking a position that had drifted above the limit
    # was itself rejected.
    is_sell = order.action.strip().lower() == "sell"
    signed_value = -order.estimated_value if is_sell else order.estimated_value
    existing_value = state.positions_by_symbol.get(order.ticker, 0.0)
    new_total_value = abs(existing_value + signed_value)

    if state.portfolio_value > 0:
        position_pct = new_total_value / state.portfolio_value
    else:
        # Unknown or zero portfolio value: fail closed for any exposure.
        position_pct = 1.0 if new_total_value > 0 else 0.0

    value_ok = new_total_value <= limits.max_position_value
    pct_ok = position_pct <= limits.max_position_pct
    shares_ok = order.quantity <= limits.max_shares_per_order

    return [
        RiskCheckDetail(
            check_name="max_position_value",
            result=RiskCheckResult.PASS if value_ok else RiskCheckResult.FAIL,
            message=(
                f"Position value {new_total_value:.2f} "
                f"{'within' if value_ok else 'exceeds'} "
                f"limit {limits.max_position_value:.2f}"
            ),
            threshold=limits.max_position_value,
            actual=new_total_value,
        ),
        RiskCheckDetail(
            check_name="max_position_pct",
            result=RiskCheckResult.PASS if pct_ok else RiskCheckResult.FAIL,
            message=(
                f"Position {position_pct:.4f} of portfolio "
                f"{'within' if pct_ok else 'exceeds'} "
                f"limit {limits.max_position_pct:.4f}"
            ),
            threshold=limits.max_position_pct,
            actual=position_pct,
        ),
        # Share cap applies to the order size itself, regardless of side.
        RiskCheckDetail(
            check_name="max_shares_per_order",
            result=RiskCheckResult.PASS if shares_ok else RiskCheckResult.FAIL,
            message=(
                f"Order quantity {order.quantity:.0f} "
                f"{'within' if shares_ok else 'exceeds'} "
                f"limit {limits.max_shares_per_order:.0f}"
            ),
            threshold=limits.max_shares_per_order,
            actual=order.quantity,
        ),
    ]
def _check_sector_exposure(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> RiskCheckDetail:
    """Enforce sector concentration limits.

    Returns WARN (not FAIL) when the order has no sector, because the check
    cannot be evaluated. Otherwise it projects the sector's exposure after
    the fill and compares it, as a fraction of ``state.portfolio_value``,
    with ``config.sector_exposure.max_sector_pct``.

    The projection mirrors the position-size check. A ``"sell"``
    (case-insensitive) subtracts ``order.estimated_value`` from the sector
    total and any other action adds it; the absolute value of the result is
    the projected exposure. If the portfolio value is not positive, any
    non-zero exposure counts as 100%, so the check fails closed.
    """
    limits = config.sector_exposure
    if not order.sector:
        return RiskCheckDetail(
            check_name="sector_exposure",
            result=RiskCheckResult.WARN,
            message="No sector provided on order; skipping sector check",
        )

    # Previously a sell was *added* to sector exposure, so reducing an
    # over-concentrated sector was itself blocked.
    is_sell = order.action.strip().lower() == "sell"
    signed_value = -order.estimated_value if is_sell else order.estimated_value
    existing_sector_value = state.positions_by_sector.get(order.sector, 0.0)
    new_sector_value = abs(existing_sector_value + signed_value)

    if state.portfolio_value > 0:
        sector_pct = new_sector_value / state.portfolio_value
    else:
        # Unknown or zero portfolio value: fail closed for any exposure.
        sector_pct = 1.0 if new_sector_value > 0 else 0.0

    within = sector_pct <= limits.max_sector_pct
    return RiskCheckDetail(
        check_name="sector_exposure",
        result=RiskCheckResult.PASS if within else RiskCheckResult.FAIL,
        message=(
            f"Sector '{order.sector}' exposure {sector_pct:.4f} "
            f"{'within' if within else 'exceeds'} "
            f"limit {limits.max_sector_pct:.4f}"
        ),
        threshold=limits.max_sector_pct,
        actual=sector_pct,
    )
def _check_daily_loss(
    config: PortfolioRiskConfig,
    state: AccountRiskState,
) -> list[RiskCheckDetail]:
    """Enforce daily loss and trade count limits.

    Only losses count: a non-negative ``state.daily_pnl`` contributes 0.
    When ``state.portfolio_value`` is not positive, the loss percentage
    cannot be verified. Any loss is then treated as a 100% loss, so the
    percentage check fails closed (Requirement 8.5), consistent with the
    position and sector checks.

    Args:
        config: Supplies ``daily_loss`` limits.
        state: Supplies ``daily_pnl``, ``portfolio_value`` and
            ``daily_trade_count``.

    Returns:
        Three checks, in order: ``daily_loss_pct``, ``daily_loss_value`` and
        ``daily_trade_count``.
    """
    limits = config.daily_loss

    # Magnitude of today's loss (0.0 when P&L is flat or positive).
    abs_loss = abs(min(state.daily_pnl, 0.0))

    if state.portfolio_value > 0:
        loss_pct = abs_loss / state.portfolio_value
    else:
        # Previously 0.0 here, which let the percentage check pass even with
        # a realised loss and an unknown portfolio value (fail-open).
        loss_pct = 1.0 if abs_loss > 0 else 0.0

    pct_ok = loss_pct <= limits.max_daily_loss_pct
    value_ok = abs_loss <= limits.max_daily_loss_value
    # The count is taken before this trade is placed, so the limit is hit
    # once the count equals max_daily_trades.
    trades_ok = state.daily_trade_count < limits.max_daily_trades

    return [
        RiskCheckDetail(
            check_name="daily_loss_pct",
            result=RiskCheckResult.PASS if pct_ok else RiskCheckResult.FAIL,
            message=(
                f"Daily loss {loss_pct:.4f} "
                f"{'within' if pct_ok else 'exceeds'} "
                f"limit {limits.max_daily_loss_pct:.4f}"
            ),
            threshold=limits.max_daily_loss_pct,
            actual=loss_pct,
        ),
        RiskCheckDetail(
            check_name="daily_loss_value",
            result=RiskCheckResult.PASS if value_ok else RiskCheckResult.FAIL,
            message=(
                f"Daily loss ${abs_loss:.2f} "
                f"{'within' if value_ok else 'exceeds'} "
                f"limit ${limits.max_daily_loss_value:.2f}"
            ),
            threshold=limits.max_daily_loss_value,
            actual=abs_loss,
        ),
        RiskCheckDetail(
            check_name="daily_trade_count",
            result=RiskCheckResult.PASS if trades_ok else RiskCheckResult.FAIL,
            message=(
                f"Daily trades {state.daily_trade_count} "
                f"{'within' if trades_ok else 'at/exceeds'} "
                f"limit {limits.max_daily_trades}"
            ),
            threshold=float(limits.max_daily_trades),
            actual=float(state.daily_trade_count),
        ),
    ]
def _check_news_shock_lockout(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
    now: datetime | None = None,
) -> RiskCheckDetail:
    """Reject orders for a ticker whose news-shock lockout has not expired.

    The lockout expiry comes from ``state.locked_symbols``. The check passes
    when lockouts are disabled in config, when the ticker has no entry, or
    when *now* (default: current UTC time) is at or past the expiry.
    """
    check = "news_shock_lockout"
    if not config.news_shock.enabled:
        return RiskCheckDetail(
            check_name=check,
            result=RiskCheckResult.PASS,
            message="News-shock lockout is disabled",
        )

    current = now or datetime.now(timezone.utc)
    expiry = state.locked_symbols.get(order.ticker)
    still_locked = expiry is not None and current < expiry
    if not still_locked:
        return RiskCheckDetail(
            check_name=check,
            result=RiskCheckResult.PASS,
            message=f"No active lockout for {order.ticker}",
        )

    seconds_left = (expiry - current).total_seconds()
    return RiskCheckDetail(
        check_name=check,
        result=RiskCheckResult.FAIL,
        message=(
            f"Symbol {order.ticker} locked out until "
            f"{expiry.isoformat()} "
            f"({seconds_left:.0f}s remaining)"
        ),
    )
def _check_symbol_cooldown(
    order: ProposedOrder,
    config: PortfolioRiskConfig,
    state: AccountRiskState,
    now: datetime | None = None,
) -> RiskCheckDetail:
    """Reject orders for a ticker traded within the configured cooldown.

    The window runs from ``state.last_trade_times[ticker]`` for
    ``config.symbol_cooldown.cooldown_minutes``. Tickers with no recorded
    trade, or evaluated at/after the window end, pass. *now* defaults to
    the current UTC time.
    """
    current = now or datetime.now(timezone.utc)
    last = state.last_trade_times.get(order.ticker)
    window_end = (
        None
        if last is None
        else last + timedelta(minutes=config.symbol_cooldown.cooldown_minutes)
    )

    if window_end is None or current >= window_end:
        return RiskCheckDetail(
            check_name="symbol_cooldown",
            result=RiskCheckResult.PASS,
            message=f"No active cooldown for {order.ticker}",
        )

    seconds_left = (window_end - current).total_seconds()
    return RiskCheckDetail(
        check_name="symbol_cooldown",
        result=RiskCheckResult.FAIL,
        message=(
            f"Symbol {order.ticker} in cooldown until "
            f"{window_end.isoformat()} "
            f"({seconds_left:.0f}s remaining)"
        ),
    )
# ---------------------------------------------------------------------------
# Main evaluation entry point (Requirements 8.3, 8.4, 8.5)
# ---------------------------------------------------------------------------
def evaluate_order(
    order: ProposedOrder,
    config: PortfolioRiskConfig = DEFAULT_RISK_CONFIG,
    state: AccountRiskState | None = None,
    now: datetime | None = None,
) -> RiskEvaluation:
    """Evaluate a proposed order against all risk controls.

    Runs every hard-block check and returns a RiskEvaluation capturing
    the full decision trace (Requirement 8.3). If any check fails,
    the order is rejected before broker submission (Requirement 8.4).

    The engine fails closed (Requirement 8.5). When ``state`` is omitted,
    an explicit ``account_state`` FAIL check is recorded first, and the
    remaining checks run against an empty placeholder snapshot. The order
    therefore cannot be approved without real account data.

    Args:
        order: The order to evaluate.
        config: Risk configuration; defaults to the shared
            ``DEFAULT_RISK_CONFIG``.
        state: Current account snapshot, or None if unavailable.
        now: Time used by the lockout/cooldown checks and stamped on the
            result; defaults to the current UTC time.

    Returns:
        A RiskEvaluation that is ``eligible`` only if no check returned
        FAIL. ``config_snapshot`` and ``state_snapshot`` are deep copies, so
        later mutation of the caller's objects (or of the shared default
        config) cannot rewrite the recorded trace.
    """
    now = now or datetime.now(timezone.utc)
    all_checks: list[RiskCheckDetail] = []

    # 0. Account state presence. Without a real snapshot the limit checks
    # would run against all-zero values, and e.g. an order with
    # estimated_value 0 would pass everything. Reject explicitly instead.
    if state is None:
        all_checks.append(RiskCheckDetail(
            check_name="account_state",
            result=RiskCheckResult.FAIL,
            message="Account risk state unavailable; failing closed",
        ))
        state = AccountRiskState()

    # 1. Trading mode gate
    all_checks.append(_check_trading_mode(config))
    # 2. Position size limits (value, % of portfolio, shares per order)
    all_checks.extend(_check_max_position_size(order, config, state))
    # 3. Sector exposure
    all_checks.append(_check_sector_exposure(order, config, state))
    # 4. Daily loss and trade-count limits
    all_checks.extend(_check_daily_loss(config, state))
    # 5. News-shock lockout
    all_checks.append(_check_news_shock_lockout(order, config, state, now))
    # 6. Symbol cooldown
    all_checks.append(_check_symbol_cooldown(order, config, state, now))

    # Every FAIL is a hard block; WARN results are traced but do not reject.
    rejection_reasons = [
        check.message
        for check in all_checks
        if check.result == RiskCheckResult.FAIL
    ]

    eligible = not rejection_reasons
    allowed_mode = config.trading_mode if eligible else TradingMode.DISABLED

    return RiskEvaluation(
        recommendation_id=order.recommendation_id,
        ticker=order.ticker,
        eligible=eligible,
        allowed_mode=allowed_mode,
        checks=all_checks,
        rejection_reasons=rejection_reasons,
        # Pydantic does not copy model instances assigned to fields, so take
        # real snapshots to keep the trace reproducible.
        config_snapshot=config.model_copy(deep=True),
        state_snapshot=state.model_copy(deep=True),
        evaluated_at=now,
    )