phase 14-15: docker build validation and helm deployment
This commit is contained in:
@@ -0,0 +1,101 @@
|
||||
"""Risk Engine API - FastAPI application for order risk evaluation and approval workflow."""
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import asyncpg
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.risk.approval import (
|
||||
expire_stale_approvals,
|
||||
get_approval_by_id,
|
||||
get_pending_approvals,
|
||||
review_approval,
|
||||
)
|
||||
from services.risk.engine import (
|
||||
AccountRiskState,
|
||||
PortfolioRiskConfig,
|
||||
ProposedOrder,
|
||||
RiskEvaluation,
|
||||
evaluate_order,
|
||||
)
|
||||
from services.shared.config import load_config
|
||||
from services.shared.logging import setup_logging
|
||||
|
||||
config = load_config()
|
||||
pool: asyncpg.Pool | None = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global pool
|
||||
setup_logging("risk_engine", level=config.log_level, json_output=config.json_logs)
|
||||
pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
|
||||
yield
|
||||
if pool:
|
||||
await pool.close()
|
||||
|
||||
|
||||
app = FastAPI(title="Stonks Oracle - Risk Engine", lifespan=lifespan)
|
||||
|
||||
|
||||
class EvaluateRequest(BaseModel):
|
||||
order: ProposedOrder
|
||||
config: PortfolioRiskConfig | None = None
|
||||
state: AccountRiskState | None = None
|
||||
|
||||
|
||||
@app.post("/evaluate", response_model=RiskEvaluation)
|
||||
async def evaluate(req: EvaluateRequest) -> RiskEvaluation:
|
||||
risk_config = req.config or PortfolioRiskConfig()
|
||||
return evaluate_order(req.order, risk_config, req.state)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
class ReviewRequest(BaseModel):
|
||||
approved: bool
|
||||
reviewed_by: str = "operator"
|
||||
review_note: str = ""
|
||||
|
||||
|
||||
@app.get("/approvals/pending")
|
||||
async def list_pending():
|
||||
if not pool:
|
||||
raise HTTPException(503, "Database not ready")
|
||||
requests = await get_pending_approvals(pool)
|
||||
return [r.to_dict() for r in requests]
|
||||
|
||||
|
||||
@app.get("/approvals/{approval_id}")
|
||||
async def get_approval(approval_id: str):
|
||||
if not pool:
|
||||
raise HTTPException(503, "Database not ready")
|
||||
req = await get_approval_by_id(pool, approval_id)
|
||||
if not req:
|
||||
raise HTTPException(404, "Approval not found")
|
||||
return req.to_dict()
|
||||
|
||||
|
||||
@app.post("/approvals/{approval_id}/review")
|
||||
async def review(approval_id: str, body: ReviewRequest):
|
||||
if not pool:
|
||||
raise HTTPException(503, "Database not ready")
|
||||
status = await review_approval(
|
||||
pool, approval_id, body.approved, body.reviewed_by, body.review_note,
|
||||
)
|
||||
if status is None:
|
||||
raise HTTPException(404, "Approval not found or no longer pending")
|
||||
return {"approval_id": approval_id, "status": status.value}
|
||||
|
||||
|
||||
@app.post("/approvals/expire")
|
||||
async def expire():
|
||||
if not pool:
|
||||
raise HTTPException(503, "Database not ready")
|
||||
expired = await expire_stale_approvals(pool)
|
||||
return {"expired": expired}
|
||||
@@ -0,0 +1,300 @@
|
||||
"""Operator approval workflow for live trading mode.
|
||||
|
||||
When live trading is enabled and operator approval is required,
|
||||
orders are held in a pending state until an operator explicitly
|
||||
approves or rejects them. Expired approvals are treated as rejections.
|
||||
|
||||
Requirements: 8.2
|
||||
Design: Section 4.8 - Risk Engine (operator approval rules)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import asyncpg
|
||||
|
||||
from services.risk.engine import (
|
||||
OperatorApproval,
|
||||
PortfolioRiskConfig,
|
||||
TradingMode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("operator_approval")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enums
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ApprovalStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
REJECTED = "rejected"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core logic: does this order need approval?
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def requires_approval(
|
||||
config: PortfolioRiskConfig,
|
||||
trading_mode: TradingMode | None = None,
|
||||
) -> bool:
|
||||
"""Determine whether an order requires operator approval.
|
||||
|
||||
Paper orders are auto-approved when auto_approve_paper is True.
|
||||
Live orders require approval when require_approval_for_live is True.
|
||||
Disabled mode always returns False (orders are blocked upstream).
|
||||
"""
|
||||
mode = trading_mode or config.trading_mode
|
||||
|
||||
if mode == TradingMode.DISABLED:
|
||||
return False
|
||||
|
||||
if mode == TradingMode.PAPER:
|
||||
return not config.operator_approval.auto_approve_paper
|
||||
|
||||
# Live mode
|
||||
return config.operator_approval.require_approval_for_live
|
||||
|
||||
|
||||
def compute_expiry(
|
||||
config: PortfolioRiskConfig,
|
||||
now: datetime | None = None,
|
||||
) -> datetime:
|
||||
"""Compute the expiry timestamp for a new approval request."""
|
||||
now = now or datetime.now(timezone.utc)
|
||||
return now + timedelta(minutes=config.operator_approval.approval_timeout_minutes)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Approval request model (in-memory representation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ApprovalRequest:
|
||||
"""Represents a pending operator approval request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
approval_id: str | None = None,
|
||||
order_job: dict[str, Any] | None = None,
|
||||
recommendation_id: str | None = None,
|
||||
ticker: str = "",
|
||||
side: str = "buy",
|
||||
quantity: float = 0.0,
|
||||
estimated_value: float = 0.0,
|
||||
risk_evaluation_id: str | None = None,
|
||||
status: ApprovalStatus = ApprovalStatus.PENDING,
|
||||
requested_by: str = "system",
|
||||
reviewed_by: str | None = None,
|
||||
review_note: str | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
requested_at: datetime | None = None,
|
||||
reviewed_at: datetime | None = None,
|
||||
) -> None:
|
||||
self.approval_id = approval_id or str(uuid.uuid4())
|
||||
self.order_job = order_job or {}
|
||||
self.recommendation_id = recommendation_id
|
||||
self.ticker = ticker
|
||||
self.side = side
|
||||
self.quantity = quantity
|
||||
self.estimated_value = estimated_value
|
||||
self.risk_evaluation_id = risk_evaluation_id
|
||||
self.status = status
|
||||
self.requested_by = requested_by
|
||||
self.reviewed_by = reviewed_by
|
||||
self.review_note = review_note
|
||||
self.expires_at = expires_at or (datetime.now(timezone.utc) + timedelta(minutes=30))
|
||||
self.requested_at = requested_at or datetime.now(timezone.utc)
|
||||
self.reviewed_at = reviewed_at
|
||||
|
||||
@property
|
||||
def is_pending(self) -> bool:
|
||||
return self.status == ApprovalStatus.PENDING
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
if self.status == ApprovalStatus.EXPIRED:
|
||||
return True
|
||||
if self.status == ApprovalStatus.PENDING:
|
||||
return datetime.now(timezone.utc) >= self.expires_at
|
||||
return False
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"approval_id": self.approval_id,
|
||||
"recommendation_id": self.recommendation_id,
|
||||
"ticker": self.ticker,
|
||||
"side": self.side,
|
||||
"quantity": self.quantity,
|
||||
"estimated_value": self.estimated_value,
|
||||
"risk_evaluation_id": self.risk_evaluation_id,
|
||||
"status": self.status.value,
|
||||
"requested_by": self.requested_by,
|
||||
"reviewed_by": self.reviewed_by,
|
||||
"review_note": self.review_note,
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"requested_at": self.requested_at.isoformat() if self.requested_at else None,
|
||||
"reviewed_at": self.reviewed_at.isoformat() if self.reviewed_at else None,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB persistence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_INSERT_APPROVAL = """
|
||||
INSERT INTO operator_approvals (
|
||||
id, order_job, recommendation_id, ticker, side, quantity,
|
||||
estimated_value, status, risk_evaluation_id, requested_by,
|
||||
expires_at, requested_at
|
||||
) VALUES (
|
||||
$1::uuid, $2::jsonb, $3, $4, $5, $6,
|
||||
$7, $8, $9, $10,
|
||||
$11, $12
|
||||
)
|
||||
"""
|
||||
|
||||
_UPDATE_APPROVAL_STATUS = """
|
||||
UPDATE operator_approvals
|
||||
SET status = $2, reviewed_by = $3, review_note = $4, reviewed_at = $5, updated_at = NOW()
|
||||
WHERE id = $1::uuid AND status = 'pending'
|
||||
RETURNING id, status
|
||||
"""
|
||||
|
||||
_EXPIRE_STALE_APPROVALS = """
|
||||
UPDATE operator_approvals
|
||||
SET status = 'expired', updated_at = NOW()
|
||||
WHERE status = 'pending' AND expires_at <= $1
|
||||
RETURNING id, ticker
|
||||
"""
|
||||
|
||||
_FETCH_PENDING_APPROVALS = """
|
||||
SELECT id, order_job, recommendation_id, ticker, side, quantity,
|
||||
estimated_value, status, risk_evaluation_id, requested_by,
|
||||
reviewed_by, review_note, expires_at, requested_at, reviewed_at
|
||||
FROM operator_approvals
|
||||
WHERE status = 'pending'
|
||||
ORDER BY requested_at ASC
|
||||
"""
|
||||
|
||||
_FETCH_APPROVAL_BY_ID = """
|
||||
SELECT id, order_job, recommendation_id, ticker, side, quantity,
|
||||
estimated_value, status, risk_evaluation_id, requested_by,
|
||||
reviewed_by, review_note, expires_at, requested_at, reviewed_at
|
||||
FROM operator_approvals
|
||||
WHERE id = $1::uuid
|
||||
"""
|
||||
|
||||
|
||||
def _row_to_request(row: Any) -> ApprovalRequest:
|
||||
"""Convert a DB row to an ApprovalRequest."""
|
||||
order_job = row["order_job"]
|
||||
if isinstance(order_job, str):
|
||||
order_job = json.loads(order_job)
|
||||
return ApprovalRequest(
|
||||
approval_id=str(row["id"]),
|
||||
order_job=order_job,
|
||||
recommendation_id=str(row["recommendation_id"]) if row["recommendation_id"] else None,
|
||||
ticker=row["ticker"],
|
||||
side=row["side"],
|
||||
quantity=float(row["quantity"]),
|
||||
estimated_value=float(row["estimated_value"]),
|
||||
risk_evaluation_id=str(row["risk_evaluation_id"]) if row.get("risk_evaluation_id") else None,
|
||||
status=ApprovalStatus(row["status"]),
|
||||
requested_by=row["requested_by"],
|
||||
reviewed_by=row["reviewed_by"],
|
||||
review_note=row["review_note"],
|
||||
expires_at=row["expires_at"],
|
||||
requested_at=row["requested_at"],
|
||||
reviewed_at=row["reviewed_at"],
|
||||
)
|
||||
|
||||
|
||||
async def create_approval_request(
|
||||
pool: asyncpg.Pool,
|
||||
request: ApprovalRequest,
|
||||
) -> str:
|
||||
"""Persist a new approval request. Returns the approval ID."""
|
||||
await pool.execute(
|
||||
_INSERT_APPROVAL,
|
||||
request.approval_id,
|
||||
json.dumps(request.order_job, default=str),
|
||||
request.recommendation_id,
|
||||
request.ticker,
|
||||
request.side,
|
||||
request.quantity,
|
||||
request.estimated_value,
|
||||
request.status.value,
|
||||
request.risk_evaluation_id,
|
||||
request.requested_by,
|
||||
request.expires_at,
|
||||
request.requested_at,
|
||||
)
|
||||
return request.approval_id
|
||||
|
||||
|
||||
async def review_approval(
|
||||
pool: asyncpg.Pool,
|
||||
approval_id: str,
|
||||
approved: bool,
|
||||
reviewed_by: str = "operator",
|
||||
review_note: str = "",
|
||||
) -> ApprovalStatus | None:
|
||||
"""Approve or reject a pending approval request.
|
||||
|
||||
Returns the new status, or None if the approval was not found
|
||||
or was no longer pending (already expired/reviewed).
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
new_status = ApprovalStatus.APPROVED if approved else ApprovalStatus.REJECTED
|
||||
|
||||
row = await pool.fetchrow(
|
||||
_UPDATE_APPROVAL_STATUS,
|
||||
approval_id,
|
||||
new_status.value,
|
||||
reviewed_by,
|
||||
review_note,
|
||||
now,
|
||||
)
|
||||
if row:
|
||||
return ApprovalStatus(row["status"])
|
||||
return None
|
||||
|
||||
|
||||
async def expire_stale_approvals(
|
||||
pool: asyncpg.Pool,
|
||||
now: datetime | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
"""Mark all expired pending approvals. Returns list of expired items."""
|
||||
now = now or datetime.now(timezone.utc)
|
||||
rows = await pool.fetch(_EXPIRE_STALE_APPROVALS, now)
|
||||
return [{"id": str(r["id"]), "ticker": r["ticker"]} for r in rows]
|
||||
|
||||
|
||||
async def get_pending_approvals(
|
||||
pool: asyncpg.Pool,
|
||||
) -> list[ApprovalRequest]:
|
||||
"""Fetch all pending approval requests, oldest first."""
|
||||
rows = await pool.fetch(_FETCH_PENDING_APPROVALS)
|
||||
return [_row_to_request(r) for r in rows]
|
||||
|
||||
|
||||
async def get_approval_by_id(
|
||||
pool: asyncpg.Pool,
|
||||
approval_id: str,
|
||||
) -> ApprovalRequest | None:
|
||||
"""Fetch a single approval request by ID."""
|
||||
row = await pool.fetchrow(_FETCH_APPROVAL_BY_ID, approval_id)
|
||||
if row:
|
||||
return _row_to_request(row)
|
||||
return None
|
||||
+616
-1
@@ -1 +1,616 @@
|
||||
"""Risk engine - enforces guardrails, position limits, and trade eligibility checks."""
|
||||
"""Risk engine - portfolio and account risk configuration and enforcement.
|
||||
|
||||
Defines the configuration and state models used to enforce guardrails
|
||||
on trade execution: max position size, sector exposure, daily loss limits,
|
||||
news-shock lockouts, and operator approval rules.
|
||||
|
||||
Also implements the hard-block evaluation logic that decides whether a
|
||||
proposed order is allowed before it reaches the broker adapter.
|
||||
|
||||
Requirements: 8.1, 8.2, 8.3, 8.4, 8.5
|
||||
Design: Section 4.8 - Risk Engine
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enums
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TradingMode(str, Enum):
|
||||
"""Execution environment separation (Requirement 8.1)."""
|
||||
PAPER = "paper"
|
||||
LIVE = "live"
|
||||
DISABLED = "disabled"
|
||||
|
||||
|
||||
class RiskCheckResult(str, Enum):
|
||||
"""Outcome of a single risk check."""
|
||||
PASS = "pass"
|
||||
FAIL = "fail"
|
||||
WARN = "warn"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Portfolio-level risk configuration (Requirement 8.2, 8.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PositionLimits(BaseModel):
|
||||
"""Per-position size constraints."""
|
||||
max_position_pct: float = Field(
|
||||
default=0.05, ge=0, le=1,
|
||||
description="Maximum portfolio percentage for a single position",
|
||||
)
|
||||
max_position_value: float = Field(
|
||||
default=10_000.0, ge=0,
|
||||
description="Maximum dollar value for a single position",
|
||||
)
|
||||
max_shares_per_order: float = Field(
|
||||
default=1000.0, ge=0,
|
||||
description="Maximum shares in a single order",
|
||||
)
|
||||
|
||||
|
||||
class SectorExposureLimits(BaseModel):
|
||||
"""Sector-level concentration limits."""
|
||||
max_sector_pct: float = Field(
|
||||
default=0.25, ge=0, le=1,
|
||||
description="Maximum portfolio percentage exposed to one sector",
|
||||
)
|
||||
max_sectors: int = Field(
|
||||
default=10, ge=1,
|
||||
description="Maximum number of sectors with open positions",
|
||||
)
|
||||
|
||||
|
||||
class DailyLossLimits(BaseModel):
|
||||
"""Daily drawdown controls."""
|
||||
max_daily_loss_pct: float = Field(
|
||||
default=0.02, ge=0, le=1,
|
||||
description="Maximum portfolio loss percentage in a single day before halting",
|
||||
)
|
||||
max_daily_loss_value: float = Field(
|
||||
default=1_000.0, ge=0,
|
||||
description="Maximum dollar loss in a single day before halting",
|
||||
)
|
||||
max_daily_trades: int = Field(
|
||||
default=20, ge=0,
|
||||
description="Maximum number of trades per day",
|
||||
)
|
||||
|
||||
|
||||
class NewsShockLockout(BaseModel):
|
||||
"""News-shock lockout configuration.
|
||||
|
||||
When a symbol has a high-impact news event, trading is paused
|
||||
for a configurable cooldown period.
|
||||
"""
|
||||
enabled: bool = True
|
||||
lockout_minutes: int = Field(
|
||||
default=60, ge=0,
|
||||
description="Minutes to lock out trading after a high-impact news event",
|
||||
)
|
||||
impact_threshold: float = Field(
|
||||
default=0.80, ge=0, le=1,
|
||||
description="Minimum impact_score from document intelligence to trigger lockout",
|
||||
)
|
||||
catalyst_types: list[str] = Field(
|
||||
default_factory=lambda: ["earnings", "legal", "m_and_a"],
|
||||
description="Catalyst types that trigger lockout when above threshold",
|
||||
)
|
||||
|
||||
|
||||
class OperatorApproval(BaseModel):
|
||||
"""Operator approval workflow for live trading (Requirement 8.2)."""
|
||||
require_approval_for_live: bool = Field(
|
||||
default=True,
|
||||
description="Whether live orders require operator approval",
|
||||
)
|
||||
auto_approve_paper: bool = Field(
|
||||
default=True,
|
||||
description="Whether paper orders are auto-approved",
|
||||
)
|
||||
approval_timeout_minutes: int = Field(
|
||||
default=30, ge=1,
|
||||
description="Minutes before a pending approval expires",
|
||||
)
|
||||
|
||||
|
||||
class SymbolCooldown(BaseModel):
|
||||
"""Per-symbol cooldown after a trade."""
|
||||
cooldown_minutes: int = Field(
|
||||
default=15, ge=0,
|
||||
description="Minutes to wait before trading the same symbol again",
|
||||
)
|
||||
max_open_positions_per_symbol: int = Field(
|
||||
default=1, ge=1,
|
||||
description="Maximum concurrent open positions for a single symbol",
|
||||
)
|
||||
|
||||
|
||||
class PortfolioRiskConfig(BaseModel):
|
||||
"""Complete portfolio-level risk configuration.
|
||||
|
||||
This is the top-level config that governs all risk checks.
|
||||
Persisted in PostgreSQL and loaded at engine startup.
|
||||
"""
|
||||
config_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
name: str = "default"
|
||||
trading_mode: TradingMode = TradingMode.PAPER
|
||||
position_limits: PositionLimits = Field(default_factory=PositionLimits)
|
||||
sector_exposure: SectorExposureLimits = Field(default_factory=SectorExposureLimits)
|
||||
daily_loss: DailyLossLimits = Field(default_factory=DailyLossLimits)
|
||||
news_shock: NewsShockLockout = Field(default_factory=NewsShockLockout)
|
||||
operator_approval: OperatorApproval = Field(default_factory=OperatorApproval)
|
||||
symbol_cooldown: SymbolCooldown = Field(default_factory=SymbolCooldown)
|
||||
active: bool = True
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def to_db_json(self) -> dict[str, Any]:
|
||||
"""Serialize the full config to a JSON-compatible dict for DB storage."""
|
||||
return self.model_dump(mode="json")
|
||||
|
||||
@classmethod
|
||||
def from_db_json(cls, data: dict[str, Any]) -> PortfolioRiskConfig:
|
||||
"""Deserialize from a DB JSON column."""
|
||||
return cls.model_validate(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Account risk state (runtime snapshot)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AccountRiskState(BaseModel):
|
||||
"""Runtime snapshot of an account's risk posture.
|
||||
|
||||
Computed from broker positions, today's trades, and current P&L.
|
||||
Used by risk checks to evaluate whether a new order is allowed.
|
||||
"""
|
||||
account_id: str = ""
|
||||
portfolio_value: float = 0.0
|
||||
cash: float = 0.0
|
||||
buying_power: float = 0.0
|
||||
daily_pnl: float = 0.0
|
||||
daily_trade_count: int = 0
|
||||
open_position_count: int = 0
|
||||
positions_by_symbol: dict[str, float] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of ticker → current market value",
|
||||
)
|
||||
positions_by_sector: dict[str, float] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of sector → total market value",
|
||||
)
|
||||
last_trade_times: dict[str, datetime] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of ticker → last trade timestamp for cooldown checks",
|
||||
)
|
||||
locked_symbols: dict[str, datetime] = Field(
|
||||
default_factory=dict,
|
||||
description="Map of ticker → lockout expiry for news-shock lockouts",
|
||||
)
|
||||
snapshot_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Risk check output (Requirement 8.3 - full decision trace)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RiskCheckDetail(BaseModel):
|
||||
"""Result of a single risk check."""
|
||||
check_name: str
|
||||
result: RiskCheckResult
|
||||
message: str = ""
|
||||
threshold: float | None = None
|
||||
actual: float | None = None
|
||||
|
||||
|
||||
class RiskEvaluation(BaseModel):
|
||||
"""Complete risk evaluation for a proposed order.
|
||||
|
||||
Captures every check performed so the full decision trace
|
||||
is reproducible (Requirement 8.3).
|
||||
"""
|
||||
evaluation_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
recommendation_id: str | None = None
|
||||
ticker: str = ""
|
||||
eligible: bool = False
|
||||
allowed_mode: TradingMode = TradingMode.DISABLED
|
||||
checks: list[RiskCheckDetail] = Field(default_factory=list)
|
||||
rejection_reasons: list[str] = Field(default_factory=list)
|
||||
config_snapshot: PortfolioRiskConfig | None = None
|
||||
state_snapshot: AccountRiskState | None = None
|
||||
evaluated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
@property
|
||||
def passed(self) -> bool:
|
||||
return self.eligible and len(self.rejection_reasons) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
DEFAULT_RISK_CONFIG = PortfolioRiskConfig()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Proposed order (input to risk evaluation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ProposedOrder(BaseModel):
|
||||
"""A proposed order to be evaluated by the risk engine before submission.
|
||||
|
||||
This is the input to evaluate_order(). It carries enough context
|
||||
for every risk check to run without external lookups.
|
||||
"""
|
||||
recommendation_id: str | None = None
|
||||
ticker: str
|
||||
sector: str = ""
|
||||
action: str = "buy" # buy | sell
|
||||
quantity: float = 0.0
|
||||
estimated_value: float = 0.0
|
||||
confidence: float = 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Individual risk checks (Requirement 8.4)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_trading_mode(
|
||||
config: PortfolioRiskConfig,
|
||||
) -> RiskCheckDetail:
|
||||
"""Block all orders when trading is disabled."""
|
||||
if config.trading_mode == TradingMode.DISABLED:
|
||||
return RiskCheckDetail(
|
||||
check_name="trading_mode",
|
||||
result=RiskCheckResult.FAIL,
|
||||
message="Trading is disabled",
|
||||
)
|
||||
return RiskCheckDetail(
|
||||
check_name="trading_mode",
|
||||
result=RiskCheckResult.PASS,
|
||||
message=f"Trading mode: {config.trading_mode.value}",
|
||||
)
|
||||
|
||||
|
||||
def _check_max_position_size(
|
||||
order: ProposedOrder,
|
||||
config: PortfolioRiskConfig,
|
||||
state: AccountRiskState,
|
||||
) -> list[RiskCheckDetail]:
|
||||
"""Enforce per-position size limits (value, percentage, shares)."""
|
||||
checks: list[RiskCheckDetail] = []
|
||||
limits = config.position_limits
|
||||
|
||||
# Check max position value
|
||||
existing_value = state.positions_by_symbol.get(order.ticker, 0.0)
|
||||
new_total_value = existing_value + order.estimated_value
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="max_position_value",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if new_total_value <= limits.max_position_value
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Position value {new_total_value:.2f} "
|
||||
f"{'within' if new_total_value <= limits.max_position_value else 'exceeds'} "
|
||||
f"limit {limits.max_position_value:.2f}"
|
||||
),
|
||||
threshold=limits.max_position_value,
|
||||
actual=new_total_value,
|
||||
))
|
||||
|
||||
# Check max position percentage of portfolio
|
||||
if state.portfolio_value > 0:
|
||||
position_pct = new_total_value / state.portfolio_value
|
||||
else:
|
||||
position_pct = 1.0 if new_total_value > 0 else 0.0
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="max_position_pct",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if position_pct <= limits.max_position_pct
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Position {position_pct:.4f} of portfolio "
|
||||
f"{'within' if position_pct <= limits.max_position_pct else 'exceeds'} "
|
||||
f"limit {limits.max_position_pct:.4f}"
|
||||
),
|
||||
threshold=limits.max_position_pct,
|
||||
actual=position_pct,
|
||||
))
|
||||
|
||||
# Check max shares per order
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="max_shares_per_order",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if order.quantity <= limits.max_shares_per_order
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Order quantity {order.quantity:.0f} "
|
||||
f"{'within' if order.quantity <= limits.max_shares_per_order else 'exceeds'} "
|
||||
f"limit {limits.max_shares_per_order:.0f}"
|
||||
),
|
||||
threshold=limits.max_shares_per_order,
|
||||
actual=order.quantity,
|
||||
))
|
||||
|
||||
return checks
|
||||
|
||||
|
||||
def _check_sector_exposure(
|
||||
order: ProposedOrder,
|
||||
config: PortfolioRiskConfig,
|
||||
state: AccountRiskState,
|
||||
) -> RiskCheckDetail:
|
||||
"""Enforce sector concentration limits."""
|
||||
limits = config.sector_exposure
|
||||
|
||||
if not order.sector:
|
||||
return RiskCheckDetail(
|
||||
check_name="sector_exposure",
|
||||
result=RiskCheckResult.WARN,
|
||||
message="No sector provided on order; skipping sector check",
|
||||
)
|
||||
|
||||
existing_sector_value = state.positions_by_sector.get(order.sector, 0.0)
|
||||
new_sector_value = existing_sector_value + order.estimated_value
|
||||
|
||||
if state.portfolio_value > 0:
|
||||
sector_pct = new_sector_value / state.portfolio_value
|
||||
else:
|
||||
sector_pct = 1.0 if new_sector_value > 0 else 0.0
|
||||
|
||||
return RiskCheckDetail(
|
||||
check_name="sector_exposure",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if sector_pct <= limits.max_sector_pct
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Sector '{order.sector}' exposure {sector_pct:.4f} "
|
||||
f"{'within' if sector_pct <= limits.max_sector_pct else 'exceeds'} "
|
||||
f"limit {limits.max_sector_pct:.4f}"
|
||||
),
|
||||
threshold=limits.max_sector_pct,
|
||||
actual=sector_pct,
|
||||
)
|
||||
|
||||
|
||||
def _check_daily_loss(
|
||||
config: PortfolioRiskConfig,
|
||||
state: AccountRiskState,
|
||||
) -> list[RiskCheckDetail]:
|
||||
"""Enforce daily loss and trade count limits."""
|
||||
checks: list[RiskCheckDetail] = []
|
||||
limits = config.daily_loss
|
||||
|
||||
# Daily loss percentage
|
||||
if state.portfolio_value > 0:
|
||||
loss_pct = abs(min(state.daily_pnl, 0.0)) / state.portfolio_value
|
||||
else:
|
||||
loss_pct = 0.0
|
||||
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="daily_loss_pct",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if loss_pct <= limits.max_daily_loss_pct
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Daily loss {loss_pct:.4f} "
|
||||
f"{'within' if loss_pct <= limits.max_daily_loss_pct else 'exceeds'} "
|
||||
f"limit {limits.max_daily_loss_pct:.4f}"
|
||||
),
|
||||
threshold=limits.max_daily_loss_pct,
|
||||
actual=loss_pct,
|
||||
))
|
||||
|
||||
# Daily loss absolute value
|
||||
abs_loss = abs(min(state.daily_pnl, 0.0))
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="daily_loss_value",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if abs_loss <= limits.max_daily_loss_value
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Daily loss ${abs_loss:.2f} "
|
||||
f"{'within' if abs_loss <= limits.max_daily_loss_value else 'exceeds'} "
|
||||
f"limit ${limits.max_daily_loss_value:.2f}"
|
||||
),
|
||||
threshold=limits.max_daily_loss_value,
|
||||
actual=abs_loss,
|
||||
))
|
||||
|
||||
# Daily trade count
|
||||
checks.append(RiskCheckDetail(
|
||||
check_name="daily_trade_count",
|
||||
result=(
|
||||
RiskCheckResult.PASS
|
||||
if state.daily_trade_count < limits.max_daily_trades
|
||||
else RiskCheckResult.FAIL
|
||||
),
|
||||
message=(
|
||||
f"Daily trades {state.daily_trade_count} "
|
||||
f"{'within' if state.daily_trade_count < limits.max_daily_trades else 'at/exceeds'} "
|
||||
f"limit {limits.max_daily_trades}"
|
||||
),
|
||||
threshold=float(limits.max_daily_trades),
|
||||
actual=float(state.daily_trade_count),
|
||||
))
|
||||
|
||||
return checks
|
||||
|
||||
|
||||
def _check_news_shock_lockout(
|
||||
order: ProposedOrder,
|
||||
config: PortfolioRiskConfig,
|
||||
state: AccountRiskState,
|
||||
now: datetime | None = None,
|
||||
) -> RiskCheckDetail:
|
||||
"""Block trading on symbols under news-shock lockout."""
|
||||
lockout_cfg = config.news_shock
|
||||
|
||||
if not lockout_cfg.enabled:
|
||||
return RiskCheckDetail(
|
||||
check_name="news_shock_lockout",
|
||||
result=RiskCheckResult.PASS,
|
||||
message="News-shock lockout is disabled",
|
||||
)
|
||||
|
||||
now = now or datetime.now(timezone.utc)
|
||||
lockout_expiry = state.locked_symbols.get(order.ticker)
|
||||
|
||||
if lockout_expiry is not None and now < lockout_expiry:
|
||||
remaining = lockout_expiry - now
|
||||
return RiskCheckDetail(
|
||||
check_name="news_shock_lockout",
|
||||
result=RiskCheckResult.FAIL,
|
||||
message=(
|
||||
f"Symbol {order.ticker} locked out until "
|
||||
f"{lockout_expiry.isoformat()} "
|
||||
f"({remaining.total_seconds():.0f}s remaining)"
|
||||
),
|
||||
)
|
||||
|
||||
return RiskCheckDetail(
|
||||
check_name="news_shock_lockout",
|
||||
result=RiskCheckResult.PASS,
|
||||
message=f"No active lockout for {order.ticker}",
|
||||
)
|
||||
|
||||
|
||||
def _check_symbol_cooldown(
|
||||
order: ProposedOrder,
|
||||
config: PortfolioRiskConfig,
|
||||
state: AccountRiskState,
|
||||
now: datetime | None = None,
|
||||
) -> RiskCheckDetail:
|
||||
"""Enforce per-symbol cooldown between trades."""
|
||||
cooldown_cfg = config.symbol_cooldown
|
||||
now = now or datetime.now(timezone.utc)
|
||||
|
||||
last_trade = state.last_trade_times.get(order.ticker)
|
||||
if last_trade is not None:
|
||||
cooldown_end = last_trade + timedelta(minutes=cooldown_cfg.cooldown_minutes)
|
||||
if now < cooldown_end:
|
||||
remaining = cooldown_end - now
|
||||
return RiskCheckDetail(
|
||||
check_name="symbol_cooldown",
|
||||
result=RiskCheckResult.FAIL,
|
||||
message=(
|
||||
f"Symbol {order.ticker} in cooldown until "
|
||||
f"{cooldown_end.isoformat()} "
|
||||
f"({remaining.total_seconds():.0f}s remaining)"
|
||||
),
|
||||
)
|
||||
|
||||
return RiskCheckDetail(
|
||||
check_name="symbol_cooldown",
|
||||
result=RiskCheckResult.PASS,
|
||||
message=f"No active cooldown for {order.ticker}",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main evaluation entry point (Requirements 8.3, 8.4, 8.5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def evaluate_order(
|
||||
order: ProposedOrder,
|
||||
config: PortfolioRiskConfig = DEFAULT_RISK_CONFIG,
|
||||
state: AccountRiskState | None = None,
|
||||
now: datetime | None = None,
|
||||
) -> RiskEvaluation:
|
||||
"""Evaluate a proposed order against all risk controls.
|
||||
|
||||
Runs every hard-block check and returns a RiskEvaluation capturing
|
||||
the full decision trace (Requirement 8.3). If any check fails,
|
||||
the order is rejected before broker submission (Requirement 8.4).
|
||||
|
||||
The engine fails closed: if state is missing or ambiguous, the
|
||||
order is rejected (Requirement 8.5).
|
||||
"""
|
||||
state = state or AccountRiskState()
|
||||
now = now or datetime.now(timezone.utc)
|
||||
|
||||
all_checks: list[RiskCheckDetail] = []
|
||||
rejection_reasons: list[str] = []
|
||||
|
||||
# 1. Trading mode gate
|
||||
mode_check = _check_trading_mode(config)
|
||||
all_checks.append(mode_check)
|
||||
if mode_check.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(mode_check.message)
|
||||
|
||||
# 2. Position size limits
|
||||
position_checks = _check_max_position_size(order, config, state)
|
||||
all_checks.extend(position_checks)
|
||||
for c in position_checks:
|
||||
if c.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(c.message)
|
||||
|
||||
# 3. Sector exposure
|
||||
sector_check = _check_sector_exposure(order, config, state)
|
||||
all_checks.append(sector_check)
|
||||
if sector_check.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(sector_check.message)
|
||||
|
||||
# 4. Daily loss limits
|
||||
daily_checks = _check_daily_loss(config, state)
|
||||
all_checks.extend(daily_checks)
|
||||
for c in daily_checks:
|
||||
if c.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(c.message)
|
||||
|
||||
# 5. News-shock lockout
|
||||
lockout_check = _check_news_shock_lockout(order, config, state, now)
|
||||
all_checks.append(lockout_check)
|
||||
if lockout_check.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(lockout_check.message)
|
||||
|
||||
# 6. Symbol cooldown
|
||||
cooldown_check = _check_symbol_cooldown(order, config, state, now)
|
||||
all_checks.append(cooldown_check)
|
||||
if cooldown_check.result == RiskCheckResult.FAIL:
|
||||
rejection_reasons.append(cooldown_check.message)
|
||||
|
||||
# Determine eligibility and allowed mode
|
||||
eligible = len(rejection_reasons) == 0
|
||||
allowed_mode = config.trading_mode if eligible else TradingMode.DISABLED
|
||||
|
||||
return RiskEvaluation(
|
||||
recommendation_id=order.recommendation_id,
|
||||
ticker=order.ticker,
|
||||
eligible=eligible,
|
||||
allowed_mode=allowed_mode,
|
||||
checks=all_checks,
|
||||
rejection_reasons=rejection_reasons,
|
||||
config_snapshot=config,
|
||||
state_snapshot=state,
|
||||
evaluated_at=now,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user