432 lines
17 KiB
Python
432 lines
17 KiB
Python
"""Tests for the portfolio and account risk configuration model and enforcement."""
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from services.risk.engine import (
|
|
DEFAULT_RISK_CONFIG,
|
|
AccountRiskState,
|
|
DailyLossLimits,
|
|
NewsShockLockout,
|
|
OperatorApproval,
|
|
PortfolioRiskConfig,
|
|
PositionLimits,
|
|
ProposedOrder,
|
|
RiskCheckDetail,
|
|
RiskCheckResult,
|
|
RiskEvaluation,
|
|
SectorExposureLimits,
|
|
SymbolCooldown,
|
|
TradingMode,
|
|
evaluate_order,
|
|
)
|
|
|
|
|
|
def test_default_risk_config_is_paper_mode():
    """A config built with no arguments is active and set to PAPER mode."""
    default_cfg = PortfolioRiskConfig()
    mode = default_cfg.trading_mode
    assert mode == TradingMode.PAPER
    assert default_cfg.active is True
|
|
|
|
|
|
def test_position_limits_defaults():
    """PositionLimits() carries the expected default caps."""
    defaults = PositionLimits()
    expected = {
        "max_position_pct": 0.05,
        "max_position_value": 10_000.0,
        "max_shares_per_order": 1000.0,
    }
    # Checked in the same order as the field list above.
    for field_name, value in expected.items():
        assert getattr(defaults, field_name) == value
|
|
|
|
|
|
def test_sector_exposure_defaults():
    """SectorExposureLimits() carries the expected default caps."""
    defaults = SectorExposureLimits()
    expected = {"max_sector_pct": 0.25, "max_sectors": 10}
    for field_name, value in expected.items():
        assert getattr(defaults, field_name) == value
|
|
|
|
|
|
def test_daily_loss_defaults():
    """DailyLossLimits() carries the expected default loss and trade-count caps."""
    defaults = DailyLossLimits()
    expected = {
        "max_daily_loss_pct": 0.02,
        "max_daily_loss_value": 1_000.0,
        "max_daily_trades": 20,
    }
    for field_name, value in expected.items():
        assert getattr(defaults, field_name) == value
|
|
|
|
|
|
def test_news_shock_lockout_defaults():
    """NewsShockLockout() is enabled by default with the expected settings."""
    defaults = NewsShockLockout()
    assert defaults.enabled is True
    assert defaults.lockout_minutes == 60
    assert defaults.impact_threshold == 0.80
    catalysts = defaults.catalyst_types
    assert "earnings" in catalysts
|
|
|
|
|
|
def test_operator_approval_defaults():
    """OperatorApproval() defaults: live needs approval, paper auto-approves."""
    defaults = OperatorApproval()
    # Both flags must be the literal True, not merely truthy.
    assert defaults.require_approval_for_live is True
    assert defaults.auto_approve_paper is True
    timeout = defaults.approval_timeout_minutes
    assert timeout == 30
|
|
|
|
|
|
def test_symbol_cooldown_defaults():
    """SymbolCooldown() carries the expected default cooldown settings."""
    defaults = SymbolCooldown()
    expected = {"cooldown_minutes": 15, "max_open_positions_per_symbol": 1}
    for field_name, value in expected.items():
        assert getattr(defaults, field_name) == value
|
|
|
|
|
|
def test_portfolio_config_roundtrip_json():
    """A config serialized with to_db_json() is rebuilt intact by from_db_json()."""
    custom_positions = PositionLimits(max_position_pct=0.10)
    custom_daily = DailyLossLimits(max_daily_trades=5)
    original = PortfolioRiskConfig(
        name="test-profile",
        trading_mode=TradingMode.LIVE,
        position_limits=custom_positions,
        daily_loss=custom_daily,
    )

    round_tripped = PortfolioRiskConfig.from_db_json(original.to_db_json())

    # Explicitly customized fields come back with their custom values.
    assert round_tripped.name == "test-profile"
    assert round_tripped.trading_mode == TradingMode.LIVE
    assert round_tripped.position_limits.max_position_pct == 0.10
    assert round_tripped.daily_loss.max_daily_trades == 5
    # Nested defaults should survive
    assert round_tripped.sector_exposure.max_sector_pct == 0.25
    assert round_tripped.news_shock.enabled is True
|
|
|
|
|
|
def test_account_risk_state_defaults():
    """A state built from only an account id starts with zeroed, empty fields."""
    fresh = AccountRiskState(account_id="test-acct")
    assert fresh.portfolio_value == 0.0
    assert fresh.daily_trade_count == 0
    # The three mapping fields all start out empty.
    for mapping_field in ("positions_by_symbol", "positions_by_sector", "locked_symbols"):
        assert getattr(fresh, mapping_field) == {}
|
|
|
|
|
|
def test_account_risk_state_with_positions():
    """Position maps and daily P&L given to the constructor can be read back."""
    by_symbol = {"AAPL": 10_000.0, "MSFT": 5_000.0}
    by_sector = {"Technology": 15_000.0}
    populated = AccountRiskState(
        account_id="acct-1",
        portfolio_value=100_000.0,
        cash=50_000.0,
        daily_pnl=-500.0,
        daily_trade_count=3,
        positions_by_symbol=by_symbol,
        positions_by_sector=by_sector,
    )
    assert populated.positions_by_symbol["AAPL"] == 10_000.0
    assert populated.positions_by_sector["Technology"] == 15_000.0
    assert populated.daily_pnl == -500.0
|
|
|
|
|
|
def test_risk_evaluation_passed_property():
    """passed should be True only when eligible and no rejections."""
    # Case 1: eligible, single passing check, no rejection reasons.
    ok_check = RiskCheckDetail(check_name="position_size", result=RiskCheckResult.PASS)
    accepted = RiskEvaluation(
        ticker="AAPL",
        eligible=True,
        allowed_mode=TradingMode.PAPER,
        checks=[ok_check],
    )
    assert accepted.passed is True

    # Case 2: ineligible, one failing check and a recorded rejection reason.
    loss_check = RiskCheckDetail(
        check_name="daily_loss",
        result=RiskCheckResult.FAIL,
        message="Daily loss limit exceeded",
        threshold=0.02,
        actual=0.03,
    )
    rejected = RiskEvaluation(
        ticker="AAPL",
        eligible=False,
        allowed_mode=TradingMode.DISABLED,
        rejection_reasons=["max_daily_loss_exceeded"],
        checks=[loss_check],
    )
    assert rejected.passed is False
|
|
|
|
|
|
def test_risk_evaluation_captures_config_snapshot():
    """A RiskEvaluation keeps the config and state snapshots it was given."""
    snapshot_cfg = PortfolioRiskConfig(name="snapshot-test")
    snapshot_state = AccountRiskState(account_id="acct-1", portfolio_value=50_000.0)

    record = RiskEvaluation(
        ticker="TSLA",
        eligible=True,
        allowed_mode=TradingMode.PAPER,
        config_snapshot=snapshot_cfg,
        state_snapshot=snapshot_state,
    )

    stored_cfg = record.config_snapshot
    assert stored_cfg is not None
    assert stored_cfg.name == "snapshot-test"
    stored_state = record.state_snapshot
    assert stored_state is not None
    assert stored_state.portfolio_value == 50_000.0
|
|
|
|
|
|
def test_trading_mode_disabled():
    """A config can be constructed with trading_mode set to DISABLED."""
    halted = PortfolioRiskConfig(trading_mode=TradingMode.DISABLED)
    mode = halted.trading_mode
    assert mode == TradingMode.DISABLED
|
|
|
|
|
|
def test_default_risk_config_singleton():
    """DEFAULT_RISK_CONFIG is in PAPER mode and is named "default"."""
    shared = DEFAULT_RISK_CONFIG
    assert shared.trading_mode == TradingMode.PAPER
    assert shared.name == "default"
|
|
|
|
|
|
# ===================================================================
|
|
# Enforcement logic tests (hard blocks)
|
|
# ===================================================================
|
|
|
|
|
|
def _make_config(**overrides) -> PortfolioRiskConfig:
    """Build a PortfolioRiskConfig for the enforcement tests.

    The config starts in PAPER mode with a default-constructed instance of
    each limit group; ``overrides`` are applied on top of those values.

    Args:
        **overrides: Keyword arguments for ``PortfolioRiskConfig``. Any of the
            six pre-filled fields may be replaced. Other keywords are
            forwarded to the constructor as well, so a misspelled or
            additional field name reaches the model instead of being
            silently dropped by this helper.

    Returns:
        The constructed configuration.
    """
    settings = {
        "trading_mode": TradingMode.PAPER,
        "position_limits": PositionLimits(),
        "sector_exposure": SectorExposureLimits(),
        "daily_loss": DailyLossLimits(),
        "news_shock": NewsShockLockout(),
        "symbol_cooldown": SymbolCooldown(),
    }
    # Merge instead of cherry-picking known keys so that no override is ignored.
    settings.update(overrides)
    return PortfolioRiskConfig(**settings)
|
|
|
|
|
|
def _make_state(**overrides) -> AccountRiskState:
    """Build an AccountRiskState for the enforcement tests.

    The baseline is account ``"test-acct"`` with a 100,000 portfolio value,
    50,000 cash, zero daily P&L, zero trades, and empty position, trade-time
    and lockout maps; ``overrides`` are applied on top of those values.

    Args:
        **overrides: Keyword arguments for ``AccountRiskState``. Any baseline
            field may be replaced. Other keywords are forwarded to the
            constructor as well, so a misspelled or additional field name
            reaches the model instead of being silently dropped by this
            helper.

    Returns:
        The constructed account state.
    """
    # Fresh dict objects are created on every call, so states never share maps.
    baseline = {
        "account_id": "test-acct",
        "portfolio_value": 100_000.0,
        "cash": 50_000.0,
        "daily_pnl": 0.0,
        "daily_trade_count": 0,
        "positions_by_symbol": {},
        "positions_by_sector": {},
        "last_trade_times": {},
        "locked_symbols": {},
    }
    # Merge instead of cherry-picking known keys so that no override is ignored.
    baseline.update(overrides)
    return AccountRiskState(**baseline)
|
|
|
|
|
|
# --- Trading mode gate ---
|
|
|
|
|
|
def test_evaluate_order_disabled_mode_blocks():
    """An order evaluated under DISABLED mode fails with a "disabled" reason."""
    halted_cfg = _make_config(trading_mode=TradingMode.DISABLED)
    buy = ProposedOrder(ticker="AAPL", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, halted_cfg, _make_state())

    assert outcome.passed is False
    # Case-insensitive search across every recorded rejection reason.
    mentions_disabled = ("disabled" in reason.lower() for reason in outcome.rejection_reasons)
    assert any(mentions_disabled)
|
|
|
|
|
|
def test_evaluate_order_paper_mode_passes():
    """A small order against the default helpers passes and is allowed in PAPER mode."""
    paper_cfg = _make_config()
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, paper_cfg, _make_state())

    assert outcome.passed is True
    assert outcome.allowed_mode == TradingMode.PAPER
|
|
|
|
|
|
# --- Max position size ---
|
|
|
|
|
|
def test_position_value_exceeded():
    """A 6000 order against a 5000 position-value cap fails max_position_value."""
    capped_cfg = _make_config(position_limits=PositionLimits(max_position_value=5000))
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=6000, quantity=10)

    outcome = evaluate_order(buy, capped_cfg, _make_state())

    assert outcome.passed is False
    failed_names = [chk.check_name for chk in outcome.checks if chk.result == RiskCheckResult.FAIL]
    assert "max_position_value" in failed_names
|
|
|
|
|
|
def test_position_value_includes_existing():
    """Existing position value is added to the new order value."""
    capped_cfg = _make_config(position_limits=PositionLimits(max_position_value=5000))
    holding_state = _make_state(positions_by_symbol={"AAPL": 3000.0})
    add_on = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=2500, quantity=5)

    outcome = evaluate_order(add_on, capped_cfg, holding_state)

    assert outcome.passed is False
    value_check = next(
        filter(lambda chk: chk.check_name == "max_position_value", outcome.checks)
    )
    # 3000 already held + 2500 proposed.
    assert value_check.actual == 5500.0
|
|
|
|
|
|
def test_position_pct_exceeded():
    """A 6000 order on a 100,000 portfolio fails a 5% max_position_pct cap."""
    pct_cfg = _make_config(position_limits=PositionLimits(max_position_pct=0.05))
    account = _make_state(portfolio_value=100_000)
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=6000, quantity=10)

    outcome = evaluate_order(buy, pct_cfg, account)

    failed_names = [chk.check_name for chk in outcome.checks if chk.result == RiskCheckResult.FAIL]
    assert "max_position_pct" in failed_names
|
|
|
|
|
|
def test_sell_on_over_limit_position_allowed():
    """Selling a position that is already over the pct cap passes max_position_pct."""
    five_pct_cfg = _make_config(position_limits=PositionLimits(max_position_pct=0.05))
    holdings = {"AVGO": 5200.0}  # 5.2% — over the 5% limit
    account = _make_state(portfolio_value=100_000, positions_by_symbol=holdings)
    sell_all = ProposedOrder(
        ticker="AVGO",
        sector="Technology",
        action="sell",
        estimated_value=5200.0,
        quantity=13,
    )

    outcome = evaluate_order(sell_all, five_pct_cfg, account)

    pct_check = next(chk for chk in outcome.checks if chk.check_name == "max_position_pct")
    assert pct_check.result == RiskCheckResult.PASS, (
        f"Sell on over-limit position should pass, got: {pct_check.message}"
    )
|
|
|
|
|
|
def test_max_shares_exceeded():
    """A 200-share order against a 100-share cap fails max_shares_per_order."""
    share_cfg = _make_config(position_limits=PositionLimits(max_shares_per_order=100))
    big_order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=200)

    outcome = evaluate_order(big_order, share_cfg, _make_state())

    failed_names = [chk.check_name for chk in outcome.checks if chk.result == RiskCheckResult.FAIL]
    assert "max_shares_per_order" in failed_names
|
|
|
|
|
|
# --- Sector exposure ---
|
|
|
|
|
|
def test_sector_exposure_exceeded():
    """Adding 6000 to a 20,000 Technology exposure fails a 25% cap on 100,000."""
    sector_cfg = _make_config(sector_exposure=SectorExposureLimits(max_sector_pct=0.25))
    tech_heavy = _make_state(
        portfolio_value=100_000,
        positions_by_sector={"Technology": 20_000.0},
    )
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=6000, quantity=10)

    outcome = evaluate_order(buy, sector_cfg, tech_heavy)

    failed_names = [chk.check_name for chk in outcome.checks if chk.result == RiskCheckResult.FAIL]
    assert "sector_exposure" in failed_names
|
|
|
|
|
|
def test_sector_exposure_no_sector_warns():
    """Missing sector on order produces a warning, not a failure."""
    default_cfg = _make_config()
    sectorless = ProposedOrder(ticker="AAPL", estimated_value=1000, quantity=10)

    outcome = evaluate_order(sectorless, default_cfg, _make_state())

    exposure_check = next(
        filter(lambda chk: chk.check_name == "sector_exposure", outcome.checks)
    )
    assert exposure_check.result == RiskCheckResult.WARN
|
|
|
|
|
|
# --- Daily loss limits ---
|
|
|
|
|
|
def test_daily_loss_pct_exceeded():
    """A -2500 day on a 100,000 portfolio fails a 2% daily_loss_pct limit."""
    loss_cfg = _make_config(daily_loss=DailyLossLimits(max_daily_loss_pct=0.02))
    losing_day = _make_state(portfolio_value=100_000, daily_pnl=-2500)
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, loss_cfg, losing_day)

    failed_names = [chk.check_name for chk in outcome.checks if chk.result == RiskCheckResult.FAIL]
    assert "daily_loss_pct" in failed_names
|
|
|
|
|
|
def test_daily_loss_value_exceeded():
    """A -600 day fails a 500 daily_loss_value limit."""
    loss_cfg = _make_config(daily_loss=DailyLossLimits(max_daily_loss_value=500))
    losing_day = _make_state(daily_pnl=-600)
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, loss_cfg, losing_day)

    failed_names = [chk.check_name for chk in outcome.checks if chk.result == RiskCheckResult.FAIL]
    assert "daily_loss_value" in failed_names
|
|
|
|
|
|
def test_daily_trade_count_exceeded():
    """With 5 trades already made and a cap of 5, daily_trade_count fails."""
    trade_cfg = _make_config(daily_loss=DailyLossLimits(max_daily_trades=5))
    busy_day = _make_state(daily_trade_count=5)
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, trade_cfg, busy_day)

    failed_names = [chk.check_name for chk in outcome.checks if chk.result == RiskCheckResult.FAIL]
    assert "daily_trade_count" in failed_names
|
|
|
|
|
|
def test_positive_pnl_does_not_trigger_loss_limit():
    """Positive P&L should not trigger daily loss checks."""
    loss_cfg = _make_config(daily_loss=DailyLossLimits(max_daily_loss_pct=0.02))
    winning_day = _make_state(portfolio_value=100_000, daily_pnl=5000)
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, loss_cfg, winning_day)

    # Every check whose name begins with "daily_loss" must have passed.
    daily_loss_checks = [chk for chk in outcome.checks if chk.check_name.startswith("daily_loss")]
    for chk in daily_loss_checks:
        assert chk.result == RiskCheckResult.PASS
|
|
|
|
|
|
# --- News-shock lockout ---
|
|
|
|
|
|
def test_news_shock_lockout_blocks():
    """A lockout expiring 30 minutes after `now` fails news_shock_lockout."""
    fixed_now = datetime(2026, 4, 11, 12, tzinfo=timezone.utc)
    expires_at = fixed_now + timedelta(minutes=30)
    locked_state = _make_state(locked_symbols={"AAPL": expires_at})
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, _make_config(), locked_state, now=fixed_now)

    failed_names = [chk.check_name for chk in outcome.checks if chk.result == RiskCheckResult.FAIL]
    assert "news_shock_lockout" in failed_names
|
|
|
|
|
|
def test_news_shock_lockout_expired_passes():
    """A lockout that expired 5 minutes before `now` passes news_shock_lockout."""
    fixed_now = datetime(2026, 4, 11, 12, tzinfo=timezone.utc)
    expired_at = fixed_now - timedelta(minutes=5)  # already expired
    stale_lock_state = _make_state(locked_symbols={"AAPL": expired_at})
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, _make_config(), stale_lock_state, now=fixed_now)

    lock_check = next(
        filter(lambda chk: chk.check_name == "news_shock_lockout", outcome.checks)
    )
    assert lock_check.result == RiskCheckResult.PASS
|
|
|
|
|
|
def test_news_shock_lockout_disabled_passes():
    """With the lockout feature disabled, an active lock still passes the check."""
    fixed_now = datetime(2026, 4, 11, 12, tzinfo=timezone.utc)
    expires_at = fixed_now + timedelta(minutes=30)
    no_lockout_cfg = _make_config(news_shock=NewsShockLockout(enabled=False))
    locked_state = _make_state(locked_symbols={"AAPL": expires_at})
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, no_lockout_cfg, locked_state, now=fixed_now)

    lock_check = next(
        filter(lambda chk: chk.check_name == "news_shock_lockout", outcome.checks)
    )
    assert lock_check.result == RiskCheckResult.PASS
|
|
|
|
|
|
# --- Symbol cooldown ---
|
|
|
|
|
|
def test_symbol_cooldown_blocks():
    """A trade 5 minutes before `now` fails the symbol_cooldown check."""
    fixed_now = datetime(2026, 4, 11, 12, tzinfo=timezone.utc)
    traded_at = fixed_now - timedelta(minutes=5)  # 5 min ago, default cooldown is 15
    recent_trade_state = _make_state(last_trade_times={"AAPL": traded_at})
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, _make_config(), recent_trade_state, now=fixed_now)

    failed_names = [chk.check_name for chk in outcome.checks if chk.result == RiskCheckResult.FAIL]
    assert "symbol_cooldown" in failed_names
|
|
|
|
|
|
def test_symbol_cooldown_expired_passes():
    """A trade 20 minutes before `now` passes the symbol_cooldown check."""
    fixed_now = datetime(2026, 4, 11, 12, tzinfo=timezone.utc)
    traded_at = fixed_now - timedelta(minutes=20)  # 20 min ago, cooldown is 15
    old_trade_state = _make_state(last_trade_times={"AAPL": traded_at})
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, _make_config(), old_trade_state, now=fixed_now)

    wait_check = next(
        filter(lambda chk: chk.check_name == "symbol_cooldown", outcome.checks)
    )
    assert wait_check.result == RiskCheckResult.PASS
|
|
|
|
|
|
# --- Combined scenarios ---
|
|
|
|
|
|
def test_multiple_failures_all_captured():
    """When multiple checks fail, all rejection reasons are captured."""
    tight_positions = PositionLimits(max_position_value=500)
    tight_losses = DailyLossLimits(max_daily_loss_value=100)
    strict_cfg = _make_config(position_limits=tight_positions, daily_loss=tight_losses)
    losing_day = _make_state(daily_pnl=-200)
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, strict_cfg, losing_day)

    assert outcome.passed is False
    # The order breaches both the position-value cap and the loss-value cap.
    reason_count = len(outcome.rejection_reasons)
    assert reason_count >= 2
|
|
|
|
|
|
def test_evaluation_captures_snapshots():
    """evaluate_order() attaches config and state snapshots to its result."""
    default_cfg = _make_config()
    account = _make_state(portfolio_value=75_000)
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, default_cfg, account)

    assert outcome.config_snapshot is not None
    stored_state = outcome.state_snapshot
    assert stored_state is not None
    assert stored_state.portfolio_value == 75_000
|
|
|
|
|
|
def test_fail_closed_no_state():
    """With a zero portfolio value, a non-zero order fails max_position_pct."""
    default_cfg = _make_config()
    empty_account = _make_state(portfolio_value=0.0)
    buy = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10)

    outcome = evaluate_order(buy, default_cfg, empty_account)

    # position_pct = 1.0 when portfolio is 0 and order value > 0 → exceeds 0.05
    failed_names = [chk.check_name for chk in outcome.checks if chk.result == RiskCheckResult.FAIL]
    assert "max_position_pct" in failed_names
|