"""Tests for the portfolio and account risk configuration model and enforcement.""" from datetime import datetime, timedelta, timezone from services.risk.engine import ( DEFAULT_RISK_CONFIG, AccountRiskState, DailyLossLimits, NewsShockLockout, OperatorApproval, PortfolioRiskConfig, PositionLimits, ProposedOrder, RiskCheckDetail, RiskCheckResult, RiskEvaluation, SectorExposureLimits, SymbolCooldown, TradingMode, evaluate_order, ) def test_default_risk_config_is_paper_mode(): """Default config should be paper trading mode.""" cfg = PortfolioRiskConfig() assert cfg.trading_mode == TradingMode.PAPER assert cfg.active is True def test_position_limits_defaults(): limits = PositionLimits() assert limits.max_position_pct == 0.05 assert limits.max_position_value == 10_000.0 assert limits.max_shares_per_order == 1000.0 def test_sector_exposure_defaults(): limits = SectorExposureLimits() assert limits.max_sector_pct == 0.25 assert limits.max_sectors == 10 def test_daily_loss_defaults(): limits = DailyLossLimits() assert limits.max_daily_loss_pct == 0.02 assert limits.max_daily_loss_value == 1_000.0 assert limits.max_daily_trades == 20 def test_news_shock_lockout_defaults(): lockout = NewsShockLockout() assert lockout.enabled is True assert lockout.lockout_minutes == 60 assert lockout.impact_threshold == 0.80 assert "earnings" in lockout.catalyst_types def test_operator_approval_defaults(): approval = OperatorApproval() assert approval.require_approval_for_live is True assert approval.auto_approve_paper is True assert approval.approval_timeout_minutes == 30 def test_symbol_cooldown_defaults(): cooldown = SymbolCooldown() assert cooldown.cooldown_minutes == 15 assert cooldown.max_open_positions_per_symbol == 1 def test_portfolio_config_roundtrip_json(): """Config should survive serialization to JSON and back.""" cfg = PortfolioRiskConfig( name="test-profile", trading_mode=TradingMode.LIVE, position_limits=PositionLimits(max_position_pct=0.10), daily_loss=DailyLossLimits(max_daily_trades=5), ) data = cfg.to_db_json() restored = PortfolioRiskConfig.from_db_json(data) assert restored.name == "test-profile" assert restored.trading_mode == TradingMode.LIVE assert restored.position_limits.max_position_pct == 0.10 assert restored.daily_loss.max_daily_trades == 5 # Nested defaults should survive assert restored.sector_exposure.max_sector_pct == 0.25 assert restored.news_shock.enabled is True def test_account_risk_state_defaults(): state = AccountRiskState(account_id="test-acct") assert state.portfolio_value == 0.0 assert state.daily_trade_count == 0 assert state.positions_by_symbol == {} assert state.positions_by_sector == {} assert state.locked_symbols == {} def test_account_risk_state_with_positions(): state = AccountRiskState( account_id="acct-1", portfolio_value=100_000.0, cash=50_000.0, daily_pnl=-500.0, daily_trade_count=3, positions_by_symbol={"AAPL": 10_000.0, "MSFT": 5_000.0}, positions_by_sector={"Technology": 15_000.0}, ) assert state.positions_by_symbol["AAPL"] == 10_000.0 assert state.positions_by_sector["Technology"] == 15_000.0 assert state.daily_pnl == -500.0 def test_risk_evaluation_passed_property(): """passed should be True only when eligible and no rejections.""" passing = RiskEvaluation( ticker="AAPL", eligible=True, allowed_mode=TradingMode.PAPER, checks=[ RiskCheckDetail(check_name="position_size", result=RiskCheckResult.PASS), ], ) assert passing.passed is True failing = RiskEvaluation( ticker="AAPL", eligible=False, allowed_mode=TradingMode.DISABLED, rejection_reasons=["max_daily_loss_exceeded"], checks=[ RiskCheckDetail( check_name="daily_loss", result=RiskCheckResult.FAIL, message="Daily loss limit exceeded", threshold=0.02, actual=0.03, ), ], ) assert failing.passed is False def test_risk_evaluation_captures_config_snapshot(): """Evaluation should be able to store the config used for reproducibility.""" cfg = PortfolioRiskConfig(name="snapshot-test") state = AccountRiskState(account_id="acct-1", portfolio_value=50_000.0) evaluation = RiskEvaluation( ticker="TSLA", eligible=True, allowed_mode=TradingMode.PAPER, config_snapshot=cfg, state_snapshot=state, ) assert evaluation.config_snapshot is not None assert evaluation.config_snapshot.name == "snapshot-test" assert evaluation.state_snapshot is not None assert evaluation.state_snapshot.portfolio_value == 50_000.0 def test_trading_mode_disabled(): """DISABLED mode should be available for halting all trading.""" cfg = PortfolioRiskConfig(trading_mode=TradingMode.DISABLED) assert cfg.trading_mode == TradingMode.DISABLED def test_default_risk_config_singleton(): """Module-level default should be a valid paper config.""" assert DEFAULT_RISK_CONFIG.trading_mode == TradingMode.PAPER assert DEFAULT_RISK_CONFIG.name == "default" # =================================================================== # Enforcement logic tests (hard blocks) # =================================================================== def _make_config(**overrides) -> PortfolioRiskConfig: return PortfolioRiskConfig( trading_mode=overrides.get("trading_mode", TradingMode.PAPER), position_limits=overrides.get("position_limits", PositionLimits()), sector_exposure=overrides.get("sector_exposure", SectorExposureLimits()), daily_loss=overrides.get("daily_loss", DailyLossLimits()), news_shock=overrides.get("news_shock", NewsShockLockout()), symbol_cooldown=overrides.get("symbol_cooldown", SymbolCooldown()), ) def _make_state(**overrides) -> AccountRiskState: return AccountRiskState( account_id=overrides.get("account_id", "test-acct"), portfolio_value=overrides.get("portfolio_value", 100_000.0), cash=overrides.get("cash", 50_000.0), daily_pnl=overrides.get("daily_pnl", 0.0), daily_trade_count=overrides.get("daily_trade_count", 0), positions_by_symbol=overrides.get("positions_by_symbol", {}), positions_by_sector=overrides.get("positions_by_sector", {}), last_trade_times=overrides.get("last_trade_times", {}), locked_symbols=overrides.get("locked_symbols", {}), ) # --- Trading mode gate --- def test_evaluate_order_disabled_mode_blocks(): """Orders are rejected when trading mode is DISABLED.""" config = _make_config(trading_mode=TradingMode.DISABLED) order = ProposedOrder(ticker="AAPL", estimated_value=1000, quantity=10) result = evaluate_order(order, config, _make_state()) assert result.passed is False assert any("disabled" in r.lower() for r in result.rejection_reasons) def test_evaluate_order_paper_mode_passes(): """A clean order in paper mode should pass all checks.""" config = _make_config() order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, config, _make_state()) assert result.passed is True assert result.allowed_mode == TradingMode.PAPER # --- Max position size --- def test_position_value_exceeded(): config = _make_config(position_limits=PositionLimits(max_position_value=5000)) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=6000, quantity=10) result = evaluate_order(order, config, _make_state()) assert result.passed is False assert any(c.check_name == "max_position_value" and c.result == RiskCheckResult.FAIL for c in result.checks) def test_position_value_includes_existing(): """Existing position value is added to the new order value.""" config = _make_config(position_limits=PositionLimits(max_position_value=5000)) state = _make_state(positions_by_symbol={"AAPL": 3000.0}) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=2500, quantity=5) result = evaluate_order(order, config, state) assert result.passed is False fail_check = next(c for c in result.checks if c.check_name == "max_position_value") assert fail_check.actual == 5500.0 def test_position_pct_exceeded(): config = _make_config(position_limits=PositionLimits(max_position_pct=0.05)) state = _make_state(portfolio_value=100_000) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=6000, quantity=10) result = evaluate_order(order, config, state) assert any(c.check_name == "max_position_pct" and c.result == RiskCheckResult.FAIL for c in result.checks) def test_max_shares_exceeded(): config = _make_config(position_limits=PositionLimits(max_shares_per_order=100)) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=200) result = evaluate_order(order, config, _make_state()) assert any(c.check_name == "max_shares_per_order" and c.result == RiskCheckResult.FAIL for c in result.checks) # --- Sector exposure --- def test_sector_exposure_exceeded(): config = _make_config(sector_exposure=SectorExposureLimits(max_sector_pct=0.25)) state = _make_state( portfolio_value=100_000, positions_by_sector={"Technology": 20_000.0}, ) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=6000, quantity=10) result = evaluate_order(order, config, state) assert any(c.check_name == "sector_exposure" and c.result == RiskCheckResult.FAIL for c in result.checks) def test_sector_exposure_no_sector_warns(): """Missing sector on order produces a warning, not a failure.""" config = _make_config() order = ProposedOrder(ticker="AAPL", estimated_value=1000, quantity=10) result = evaluate_order(order, config, _make_state()) sector_check = next(c for c in result.checks if c.check_name == "sector_exposure") assert sector_check.result == RiskCheckResult.WARN # --- Daily loss limits --- def test_daily_loss_pct_exceeded(): config = _make_config(daily_loss=DailyLossLimits(max_daily_loss_pct=0.02)) state = _make_state(portfolio_value=100_000, daily_pnl=-2500) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, config, state) assert any(c.check_name == "daily_loss_pct" and c.result == RiskCheckResult.FAIL for c in result.checks) def test_daily_loss_value_exceeded(): config = _make_config(daily_loss=DailyLossLimits(max_daily_loss_value=500)) state = _make_state(daily_pnl=-600) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, config, state) assert any(c.check_name == "daily_loss_value" and c.result == RiskCheckResult.FAIL for c in result.checks) def test_daily_trade_count_exceeded(): config = _make_config(daily_loss=DailyLossLimits(max_daily_trades=5)) state = _make_state(daily_trade_count=5) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, config, state) assert any(c.check_name == "daily_trade_count" and c.result == RiskCheckResult.FAIL for c in result.checks) def test_positive_pnl_does_not_trigger_loss_limit(): """Positive P&L should not trigger daily loss checks.""" config = _make_config(daily_loss=DailyLossLimits(max_daily_loss_pct=0.02)) state = _make_state(portfolio_value=100_000, daily_pnl=5000) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, config, state) loss_checks = [c for c in result.checks if c.check_name.startswith("daily_loss")] assert all(c.result == RiskCheckResult.PASS for c in loss_checks) # --- News-shock lockout --- def test_news_shock_lockout_blocks(): now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) lockout_expiry = now + timedelta(minutes=30) state = _make_state(locked_symbols={"AAPL": lockout_expiry}) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, _make_config(), state, now=now) assert any(c.check_name == "news_shock_lockout" and c.result == RiskCheckResult.FAIL for c in result.checks) def test_news_shock_lockout_expired_passes(): now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) lockout_expiry = now - timedelta(minutes=5) # already expired state = _make_state(locked_symbols={"AAPL": lockout_expiry}) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, _make_config(), state, now=now) lockout_check = next(c for c in result.checks if c.check_name == "news_shock_lockout") assert lockout_check.result == RiskCheckResult.PASS def test_news_shock_lockout_disabled_passes(): now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) lockout_expiry = now + timedelta(minutes=30) config = _make_config(news_shock=NewsShockLockout(enabled=False)) state = _make_state(locked_symbols={"AAPL": lockout_expiry}) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, config, state, now=now) lockout_check = next(c for c in result.checks if c.check_name == "news_shock_lockout") assert lockout_check.result == RiskCheckResult.PASS # --- Symbol cooldown --- def test_symbol_cooldown_blocks(): now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) last_trade = now - timedelta(minutes=5) # 5 min ago, default cooldown is 15 state = _make_state(last_trade_times={"AAPL": last_trade}) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, _make_config(), state, now=now) assert any(c.check_name == "symbol_cooldown" and c.result == RiskCheckResult.FAIL for c in result.checks) def test_symbol_cooldown_expired_passes(): now = datetime(2026, 4, 11, 12, 0, 0, tzinfo=timezone.utc) last_trade = now - timedelta(minutes=20) # 20 min ago, cooldown is 15 state = _make_state(last_trade_times={"AAPL": last_trade}) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, _make_config(), state, now=now) cooldown_check = next(c for c in result.checks if c.check_name == "symbol_cooldown") assert cooldown_check.result == RiskCheckResult.PASS # --- Combined scenarios --- def test_multiple_failures_all_captured(): """When multiple checks fail, all rejection reasons are captured.""" config = _make_config( position_limits=PositionLimits(max_position_value=500), daily_loss=DailyLossLimits(max_daily_loss_value=100), ) state = _make_state(daily_pnl=-200) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, config, state) assert result.passed is False assert len(result.rejection_reasons) >= 2 def test_evaluation_captures_snapshots(): """Config and state snapshots are stored for reproducibility.""" config = _make_config() state = _make_state(portfolio_value=75_000) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, config, state) assert result.config_snapshot is not None assert result.state_snapshot is not None assert result.state_snapshot.portfolio_value == 75_000 def test_fail_closed_no_state(): """With zero portfolio value, position pct check should fail-closed for non-zero orders.""" config = _make_config() state = _make_state(portfolio_value=0.0) order = ProposedOrder(ticker="AAPL", sector="Technology", estimated_value=1000, quantity=10) result = evaluate_order(order, config, state) # position_pct = 1.0 when portfolio is 0 and order value > 0 → exceeds 0.05 assert any(c.check_name == "max_position_pct" and c.result == RiskCheckResult.FAIL for c in result.checks)