feat: override trade tab — manual order entry with auto-registration
Backend: - OverrideOrderRequest/Response Pydantic models with ticker, quantity, price validators - POST /api/trading/override/order endpoint (enqueue to Redis broker queue) - auto_register_symbol() module for untracked ticker registration via Symbol Registry - Unit tests (17) and property-based tests (3 x 100 examples) Frontend: - OverrideTradePanel component (order form + positions display) - Override tab in TradingEngine page with URL search param navigation - Override Trade button on Trading Controls page - useSubmitOverrideOrder mutation hook - MSW handler and 13 component/integration tests Steering: - Updated steering docs for Ubuntu dev machine with nvm/Node 24
This commit is contained in:
@@ -0,0 +1,812 @@
|
||||
"""Unit tests for auto_register_symbol() in services/trading/override.py.
|
||||
|
||||
Validates:
|
||||
- New symbol registration (company + sources + watchlist)
|
||||
- Existing symbol skip
|
||||
- 409 conflict handling
|
||||
- Source/watchlist failure tolerance
|
||||
|
||||
Requirements: 4.1, 4.2, 4.3, 4.5, 4.6
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from services.trading.override import auto_register_symbol
|
||||
|
||||
REGISTRY_BASE = "http://symbol-registry:8000"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _json_response(status_code: int, json_data) -> httpx.Response:
|
||||
"""Build an httpx.Response with JSON body for MockTransport."""
|
||||
return httpx.Response(
|
||||
status_code,
|
||||
content=json.dumps(json_data).encode(),
|
||||
headers={"content-type": "application/json"},
|
||||
)
|
||||
|
||||
|
||||
def _text_response(status_code: int, text: str = "") -> httpx.Response:
|
||||
"""Build an httpx.Response with text body for MockTransport."""
|
||||
return httpx.Response(status_code, text=text)
|
||||
|
||||
|
||||
def _patched_client(handler):
|
||||
"""Return a patch context manager that replaces httpx.AsyncClient with a mock-transport client.
|
||||
|
||||
The handler is a sync callable: ``(request: httpx.Request) -> httpx.Response``.
|
||||
"""
|
||||
original_init = httpx.AsyncClient.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
kwargs.pop("timeout", None)
|
||||
kwargs["transport"] = httpx.MockTransport(handler)
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
return patch.object(httpx.AsyncClient, "__init__", patched_init)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. New symbol registration — full happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAutoRegisterNewSymbol:
|
||||
"""Requirement 4.1, 4.2, 4.3: New symbol creates company, sources, watchlist membership."""
|
||||
|
||||
async def test_new_symbol_full_registration(self):
|
||||
"""GET /companies returns empty → POST /companies 201 → sources → watchlist → (True, id)."""
|
||||
company_id = "aaaaaaaa-1111-2222-3333-444444444444"
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
url = str(request.url)
|
||||
method = request.method
|
||||
|
||||
if method == "GET" and "/companies" in url and "/watchlists" not in url:
|
||||
return _json_response(200, [])
|
||||
|
||||
if method == "POST" and url.endswith("/companies"):
|
||||
return _json_response(201, {"id": company_id, "ticker": "TSLA"})
|
||||
|
||||
if method == "POST" and "/sources" in url:
|
||||
return _json_response(201, {"id": "src-1"})
|
||||
|
||||
if method == "GET" and "/watchlists" in url:
|
||||
return _json_response(200, [{"id": "wl-1", "name": "Default", "active": True}])
|
||||
|
||||
if method == "POST" and "/members/" in url:
|
||||
return _json_response(201, {})
|
||||
|
||||
return _text_response(404)
|
||||
|
||||
with _patched_client(handler):
|
||||
auto_registered, cid = await auto_register_symbol("TSLA", REGISTRY_BASE)
|
||||
|
||||
assert auto_registered is True
|
||||
assert cid == company_id
|
||||
|
||||
async def test_new_symbol_creates_watchlist_when_none_exist(self):
|
||||
"""When no active watchlists exist, creates 'Manual Overrides' watchlist."""
|
||||
company_id = "bbbbbbbb-1111-2222-3333-444444444444"
|
||||
watchlist_id = "wl-new-1"
|
||||
call_log: list[str] = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
url = str(request.url)
|
||||
method = request.method
|
||||
call_log.append(f"{method} {url}")
|
||||
|
||||
if method == "GET" and "/companies" in url and "/watchlists" not in url:
|
||||
return _json_response(200, [])
|
||||
|
||||
if method == "POST" and url.endswith("/companies"):
|
||||
return _json_response(201, {"id": company_id, "ticker": "NVDA"})
|
||||
|
||||
if method == "POST" and "/sources" in url:
|
||||
return _json_response(201, {"id": "src-1"})
|
||||
|
||||
if method == "GET" and "/watchlists" in url:
|
||||
return _json_response(200, []) # No active watchlists
|
||||
|
||||
if method == "POST" and url.endswith("/watchlists"):
|
||||
return _json_response(201, {"id": watchlist_id, "name": "Manual Overrides"})
|
||||
|
||||
if method == "POST" and "/members/" in url:
|
||||
return _json_response(201, {})
|
||||
|
||||
return _text_response(404)
|
||||
|
||||
with _patched_client(handler):
|
||||
auto_registered, cid = await auto_register_symbol("NVDA", REGISTRY_BASE)
|
||||
|
||||
assert auto_registered is True
|
||||
assert cid == company_id
|
||||
assert any("POST" in c and "/watchlists" in c and "/members" not in c for c in call_log)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Existing symbol skip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAutoRegisterExistingSymbol:
|
||||
"""Requirement 4.5: Existing symbols skip registration."""
|
||||
|
||||
async def test_existing_symbol_returns_false(self):
|
||||
"""GET /companies returns list containing the ticker → (False, company_id)."""
|
||||
existing_id = "cccccccc-1111-2222-3333-444444444444"
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
url = str(request.url)
|
||||
if request.method == "GET" and "/companies" in url:
|
||||
return _json_response(200, [
|
||||
{"id": existing_id, "ticker": "AAPL", "legal_name": "Apple Inc"},
|
||||
{"id": "other-id", "ticker": "MSFT", "legal_name": "Microsoft"},
|
||||
])
|
||||
return _text_response(404)
|
||||
|
||||
with _patched_client(handler):
|
||||
auto_registered, cid = await auto_register_symbol("AAPL", REGISTRY_BASE)
|
||||
|
||||
assert auto_registered is False
|
||||
assert cid == existing_id
|
||||
|
||||
async def test_existing_symbol_no_post_calls(self):
|
||||
"""When ticker exists, no POST calls should be made."""
|
||||
existing_id = "dddddddd-1111-2222-3333-444444444444"
|
||||
post_calls: list[str] = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
url = str(request.url)
|
||||
if request.method == "POST":
|
||||
post_calls.append(url)
|
||||
if request.method == "GET" and "/companies" in url:
|
||||
return _json_response(200, [{"id": existing_id, "ticker": "GOOG"}])
|
||||
return _text_response(404)
|
||||
|
||||
with _patched_client(handler):
|
||||
await auto_register_symbol("GOOG", REGISTRY_BASE)
|
||||
|
||||
assert post_calls == [], f"Unexpected POST calls: {post_calls}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. 409 conflict handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAutoRegister409Conflict:
|
||||
"""Requirement 4.6: 409 conflict treated as success."""
|
||||
|
||||
async def test_409_conflict_fetches_existing_company(self):
|
||||
"""POST /companies 409 → fetches existing company → returns with company_id."""
|
||||
existing_id = "eeeeeeee-1111-2222-3333-444444444444"
|
||||
get_call_count = 0
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
nonlocal get_call_count
|
||||
url = str(request.url)
|
||||
method = request.method
|
||||
|
||||
if method == "GET" and "/companies" in url and "/watchlists" not in url:
|
||||
get_call_count += 1
|
||||
if get_call_count == 1:
|
||||
return _json_response(200, [])
|
||||
else:
|
||||
return _json_response(200, [{"id": existing_id, "ticker": "AMD"}])
|
||||
|
||||
if method == "POST" and url.endswith("/companies"):
|
||||
return _json_response(409, {"detail": "Company already exists"})
|
||||
|
||||
if method == "POST" and "/sources" in url:
|
||||
return _json_response(201, {"id": "src-1"})
|
||||
|
||||
if method == "GET" and "/watchlists" in url:
|
||||
return _json_response(200, [{"id": "wl-1", "active": True}])
|
||||
|
||||
if method == "POST" and "/members/" in url:
|
||||
return _json_response(201, {})
|
||||
|
||||
return _text_response(404)
|
||||
|
||||
with _patched_client(handler):
|
||||
auto_registered, cid = await auto_register_symbol("AMD", REGISTRY_BASE)
|
||||
|
||||
assert auto_registered is True
|
||||
assert cid == existing_id
|
||||
|
||||
async def test_409_conflict_but_cannot_find_company_returns_false(self):
|
||||
"""POST /companies 409 but re-fetch can't find company → (False, '')."""
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
url = str(request.url)
|
||||
|
||||
if request.method == "GET" and "/companies" in url:
|
||||
return _json_response(200, [])
|
||||
|
||||
if request.method == "POST" and url.endswith("/companies"):
|
||||
return _json_response(409, {"detail": "Conflict"})
|
||||
|
||||
return _text_response(404)
|
||||
|
||||
with _patched_client(handler):
|
||||
auto_registered, cid = await auto_register_symbol("XYZ", REGISTRY_BASE)
|
||||
|
||||
assert auto_registered is False
|
||||
assert cid == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Source creation failure tolerance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSourceFailureTolerance:
|
||||
"""Source creation failures are best-effort — should not block registration."""
|
||||
|
||||
async def test_source_creation_500_still_succeeds(self):
|
||||
"""POST /companies/{id}/sources returns 500 → still returns (True, company_id)."""
|
||||
company_id = "ffffffff-1111-2222-3333-444444444444"
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
url = str(request.url)
|
||||
method = request.method
|
||||
|
||||
if method == "GET" and "/companies" in url and "/watchlists" not in url:
|
||||
return _json_response(200, [])
|
||||
|
||||
if method == "POST" and url.endswith("/companies"):
|
||||
return _json_response(201, {"id": company_id, "ticker": "PLTR"})
|
||||
|
||||
if method == "POST" and "/sources" in url:
|
||||
return _text_response(500, "Internal Server Error")
|
||||
|
||||
if method == "GET" and "/watchlists" in url:
|
||||
return _json_response(200, [{"id": "wl-1", "active": True}])
|
||||
|
||||
if method == "POST" and "/members/" in url:
|
||||
return _json_response(201, {})
|
||||
|
||||
return _text_response(404)
|
||||
|
||||
with _patched_client(handler):
|
||||
auto_registered, cid = await auto_register_symbol("PLTR", REGISTRY_BASE)
|
||||
|
||||
assert auto_registered is True
|
||||
assert cid == company_id
|
||||
|
||||
async def test_source_creation_network_error_still_succeeds(self):
|
||||
"""Source creation raises network error → still returns successfully."""
|
||||
company_id = "11111111-aaaa-bbbb-cccc-dddddddddddd"
|
||||
source_call_count = 0
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
nonlocal source_call_count
|
||||
url = str(request.url)
|
||||
method = request.method
|
||||
|
||||
if method == "GET" and "/companies" in url and "/watchlists" not in url:
|
||||
return _json_response(200, [])
|
||||
|
||||
if method == "POST" and url.endswith("/companies"):
|
||||
return _json_response(201, {"id": company_id, "ticker": "RIVN"})
|
||||
|
||||
if method == "POST" and "/sources" in url:
|
||||
source_call_count += 1
|
||||
raise httpx.ConnectError("connection refused")
|
||||
|
||||
if method == "GET" and "/watchlists" in url:
|
||||
return _json_response(200, [{"id": "wl-1", "active": True}])
|
||||
|
||||
if method == "POST" and "/members/" in url:
|
||||
return _json_response(201, {})
|
||||
|
||||
return _text_response(404)
|
||||
|
||||
with _patched_client(handler):
|
||||
auto_registered, cid = await auto_register_symbol("RIVN", REGISTRY_BASE)
|
||||
|
||||
assert auto_registered is True
|
||||
assert cid == company_id
|
||||
assert source_call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Watchlist failure tolerance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestWatchlistFailureTolerance:
|
||||
"""Watchlist failures are best-effort — should not block registration."""
|
||||
|
||||
async def test_watchlist_get_fails_still_succeeds(self):
|
||||
"""GET /watchlists raises exception → still returns (True, company_id)."""
|
||||
company_id = "22222222-aaaa-bbbb-cccc-dddddddddddd"
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
url = str(request.url)
|
||||
method = request.method
|
||||
|
||||
if method == "GET" and "/companies" in url and "/watchlists" not in url:
|
||||
return _json_response(200, [])
|
||||
|
||||
if method == "POST" and url.endswith("/companies"):
|
||||
return _json_response(201, {"id": company_id, "ticker": "SOFI"})
|
||||
|
||||
if method == "POST" and "/sources" in url:
|
||||
return _json_response(201, {"id": "src-1"})
|
||||
|
||||
if method == "GET" and "/watchlists" in url:
|
||||
raise httpx.ConnectError("connection refused")
|
||||
|
||||
return _text_response(404)
|
||||
|
||||
with _patched_client(handler):
|
||||
auto_registered, cid = await auto_register_symbol("SOFI", REGISTRY_BASE)
|
||||
|
||||
assert auto_registered is True
|
||||
assert cid == company_id
|
||||
|
||||
async def test_watchlist_member_add_fails_still_succeeds(self):
|
||||
"""POST /watchlists/{id}/members/{cid} returns 500 → still succeeds."""
|
||||
company_id = "33333333-aaaa-bbbb-cccc-dddddddddddd"
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
url = str(request.url)
|
||||
method = request.method
|
||||
|
||||
if method == "GET" and "/companies" in url and "/watchlists" not in url:
|
||||
return _json_response(200, [])
|
||||
|
||||
if method == "POST" and url.endswith("/companies"):
|
||||
return _json_response(201, {"id": company_id, "ticker": "COIN"})
|
||||
|
||||
if method == "POST" and "/sources" in url:
|
||||
return _json_response(201, {"id": "src-1"})
|
||||
|
||||
if method == "GET" and "/watchlists" in url:
|
||||
return _json_response(200, [{"id": "wl-1", "active": True}])
|
||||
|
||||
if method == "POST" and "/members/" in url:
|
||||
return _text_response(500, "Internal Server Error")
|
||||
|
||||
return _text_response(404)
|
||||
|
||||
with _patched_client(handler):
|
||||
auto_registered, cid = await auto_register_symbol("COIN", REGISTRY_BASE)
|
||||
|
||||
assert auto_registered is True
|
||||
assert cid == company_id
|
||||
|
||||
async def test_ticker_is_uppercased(self):
|
||||
"""Lowercase ticker input is normalized to uppercase."""
|
||||
company_id = "44444444-aaaa-bbbb-cccc-dddddddddddd"
|
||||
received_tickers: list[str] = []
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
url = str(request.url)
|
||||
method = request.method
|
||||
|
||||
if method == "GET" and "/companies" in url and "/watchlists" not in url:
|
||||
return _json_response(200, [])
|
||||
|
||||
if method == "POST" and url.endswith("/companies"):
|
||||
body = json.loads(request.content)
|
||||
received_tickers.append(body.get("ticker", ""))
|
||||
return _json_response(201, {"id": company_id, "ticker": body.get("ticker", "")})
|
||||
|
||||
if method == "POST" and "/sources" in url:
|
||||
return _json_response(201, {"id": "src-1"})
|
||||
|
||||
if method == "GET" and "/watchlists" in url:
|
||||
return _json_response(200, [{"id": "wl-1", "active": True}])
|
||||
|
||||
if method == "POST" and "/members/" in url:
|
||||
return _json_response(201, {})
|
||||
|
||||
return _text_response(404)
|
||||
|
||||
with _patched_client(handler):
|
||||
auto_registered, cid = await auto_register_symbol("tsla", REGISTRY_BASE)
|
||||
|
||||
assert auto_registered is True
|
||||
assert received_tickers == ["TSLA"]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Override Endpoint Tests (Task 3.1)
|
||||
# ===========================================================================
|
||||
#
|
||||
# Unit tests for POST /api/trading/override/order endpoint.
|
||||
# Requirements: 3.1, 3.2, 3.4, 3.5, 9.1
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch as _patch
|
||||
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
import services.trading.app as _app_module
|
||||
from services.trading.app import app
|
||||
|
||||
|
||||
def _make_fake_engine(redis_rpush_side_effect=None):
|
||||
"""Create a fake engine object with a mock Redis client."""
|
||||
fake_engine = MagicMock()
|
||||
fake_engine.running = True
|
||||
fake_engine.redis = AsyncMock()
|
||||
fake_engine.redis.rpush = AsyncMock(side_effect=redis_rpush_side_effect)
|
||||
return fake_engine
|
||||
|
||||
|
||||
def _override_client(fake_engine):
|
||||
"""Return a TestClient with the module-level engine replaced."""
|
||||
original = _app_module.engine
|
||||
_app_module.engine = fake_engine
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
return client, original
|
||||
|
||||
|
||||
def _restore_engine(original):
|
||||
_app_module.engine = original
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Valid order returns 202 with correct response shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOverrideEndpointValid:
|
||||
"""Requirement 3.1, 3.4: Valid order returns 202 with correct response shape."""
|
||||
|
||||
def test_valid_market_order_returns_202(self):
|
||||
"""A valid market order returns 202 with job_id, status, ticker, side, quantity, auto_registered."""
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
with _patch.object(_app_module, "auto_register_symbol", new_callable=AsyncMock, return_value=(False, "comp-1")):
|
||||
with _patched_client(lambda req: _json_response(200, [{"id": "comp-1", "ticker": "AAPL"}])):
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={
|
||||
"ticker": "AAPL",
|
||||
"side": "buy",
|
||||
"quantity": 10.0,
|
||||
"order_type": "market",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
data = resp.json()
|
||||
assert "job_id" in data
|
||||
assert data["status"] == "queued"
|
||||
assert data["ticker"] == "AAPL"
|
||||
assert data["side"] == "buy"
|
||||
assert data["quantity"] == 10.0
|
||||
assert isinstance(data["auto_registered"], bool)
|
||||
assert data["job_id"].startswith("override-")
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
def test_valid_limit_order_returns_202(self):
|
||||
"""A valid limit order with limit_price returns 202."""
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
with _patch.object(_app_module, "auto_register_symbol", new_callable=AsyncMock, return_value=(False, "comp-1")):
|
||||
with _patched_client(lambda req: _json_response(200, [{"id": "comp-1", "ticker": "TSLA"}])):
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={
|
||||
"ticker": "TSLA",
|
||||
"side": "sell",
|
||||
"quantity": 5.0,
|
||||
"order_type": "limit",
|
||||
"limit_price": 150.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
data = resp.json()
|
||||
assert data["ticker"] == "TSLA"
|
||||
assert data["side"] == "sell"
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Invalid ticker returns 422
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOverrideEndpointInvalidTicker:
|
||||
"""Requirement 3.5: Invalid ticker format returns 422."""
|
||||
|
||||
def test_numeric_ticker_returns_422(self):
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={"ticker": "123", "side": "buy", "quantity": 1.0},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
def test_empty_ticker_returns_422(self):
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={"ticker": "", "side": "buy", "quantity": 1.0},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
def test_ticker_with_spaces_returns_422(self):
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={"ticker": "AA BB", "side": "buy", "quantity": 1.0},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
def test_ticker_too_long_returns_422(self):
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={"ticker": "ABCDEFGHIJK", "side": "buy", "quantity": 1.0},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Missing limit_price for limit order returns 422
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOverrideEndpointMissingLimitPrice:
|
||||
"""Requirement 3.5: Missing limit_price for limit/stop_limit orders returns 422."""
|
||||
|
||||
def test_limit_order_without_limit_price_returns_422(self):
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={
|
||||
"ticker": "AAPL",
|
||||
"side": "buy",
|
||||
"quantity": 10.0,
|
||||
"order_type": "limit",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
def test_stop_limit_order_without_limit_price_returns_422(self):
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={
|
||||
"ticker": "AAPL",
|
||||
"side": "buy",
|
||||
"quantity": 10.0,
|
||||
"order_type": "stop_limit",
|
||||
"stop_price": 100.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. Missing stop_price for stop order returns 422
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOverrideEndpointMissingStopPrice:
|
||||
"""Requirement 3.5: Missing stop_price for stop/stop_limit orders returns 422."""
|
||||
|
||||
def test_stop_order_without_stop_price_returns_422(self):
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={
|
||||
"ticker": "AAPL",
|
||||
"side": "buy",
|
||||
"quantity": 10.0,
|
||||
"order_type": "stop",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
def test_stop_limit_order_without_stop_price_returns_422(self):
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={
|
||||
"ticker": "AAPL",
|
||||
"side": "buy",
|
||||
"quantity": 10.0,
|
||||
"order_type": "stop_limit",
|
||||
"limit_price": 150.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. Non-positive quantity returns 422
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOverrideEndpointNonPositiveQuantity:
|
||||
"""Requirement 3.5: Non-positive quantity returns 422."""
|
||||
|
||||
def test_zero_quantity_returns_422(self):
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={"ticker": "AAPL", "side": "buy", "quantity": 0.0},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
def test_negative_quantity_returns_422(self):
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={"ticker": "AAPL", "side": "buy", "quantity": -5.0},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 11. Enqueued job has correct structure and source: "manual_override"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOverrideEndpointJobStructure:
|
||||
"""Requirement 3.2, 9.1: Enqueued job has correct structure and source marker."""
|
||||
|
||||
def test_enqueued_job_has_correct_fields(self):
|
||||
"""The JSON payload RPUSH'd to Redis contains all required fields."""
|
||||
captured_payloads: list[str] = []
|
||||
|
||||
async def capture_rpush(key, value):
|
||||
captured_payloads.append(value)
|
||||
|
||||
fake_engine = _make_fake_engine()
|
||||
fake_engine.redis.rpush = AsyncMock(side_effect=capture_rpush)
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
with _patch.object(_app_module, "auto_register_symbol", new_callable=AsyncMock, return_value=(False, "comp-1")):
|
||||
with _patched_client(lambda req: _json_response(200, [{"id": "comp-1", "ticker": "MSFT"}])):
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={
|
||||
"ticker": "MSFT",
|
||||
"side": "buy",
|
||||
"quantity": 25.0,
|
||||
"order_type": "market",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
assert len(captured_payloads) == 1
|
||||
|
||||
payload = json.loads(captured_payloads[0])
|
||||
assert payload["ticker"] == "MSFT"
|
||||
assert payload["side"] == "buy"
|
||||
assert payload["quantity"] == 25.0
|
||||
assert payload["order_type"] == "market"
|
||||
assert payload["source"] == "manual_override"
|
||||
assert payload["idempotency_key"].startswith("override-")
|
||||
assert payload["limit_price"] is None
|
||||
assert payload["stop_price"] is None
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
def test_enqueued_job_includes_limit_and_stop_prices(self):
|
||||
"""When limit_price and stop_price are provided, they appear in the payload."""
|
||||
captured_payloads: list[str] = []
|
||||
|
||||
async def capture_rpush(key, value):
|
||||
captured_payloads.append(value)
|
||||
|
||||
fake_engine = _make_fake_engine()
|
||||
fake_engine.redis.rpush = AsyncMock(side_effect=capture_rpush)
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
with _patch.object(_app_module, "auto_register_symbol", new_callable=AsyncMock, return_value=(True, "comp-2")):
|
||||
with _patched_client(lambda req: _json_response(200, [])):
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={
|
||||
"ticker": "GOOG",
|
||||
"side": "sell",
|
||||
"quantity": 3.0,
|
||||
"order_type": "stop_limit",
|
||||
"limit_price": 140.0,
|
||||
"stop_price": 135.0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
assert len(captured_payloads) == 1
|
||||
|
||||
payload = json.loads(captured_payloads[0])
|
||||
assert payload["source"] == "manual_override"
|
||||
assert payload["limit_price"] == 140.0
|
||||
assert payload["stop_price"] == 135.0
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
def test_job_id_matches_response(self):
|
||||
"""The job_id in the response matches the idempotency_key in the enqueued payload."""
|
||||
captured_payloads: list[str] = []
|
||||
|
||||
async def capture_rpush(key, value):
|
||||
captured_payloads.append(value)
|
||||
|
||||
fake_engine = _make_fake_engine()
|
||||
fake_engine.redis.rpush = AsyncMock(side_effect=capture_rpush)
|
||||
client, original = _override_client(fake_engine)
|
||||
try:
|
||||
with _patch.object(_app_module, "auto_register_symbol", new_callable=AsyncMock, return_value=(False, "comp-1")):
|
||||
with _patched_client(lambda req: _json_response(200, [{"id": "comp-1", "ticker": "AMD"}])):
|
||||
resp = client.post(
|
||||
"/api/trading/override/order",
|
||||
json={
|
||||
"ticker": "AMD",
|
||||
"side": "buy",
|
||||
"quantity": 1.0,
|
||||
"order_type": "market",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
data = resp.json()
|
||||
payload = json.loads(captured_payloads[0])
|
||||
assert data["job_id"] == payload["idempotency_key"]
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
@@ -0,0 +1,334 @@
|
||||
"""Property-based tests for the Override Trade Tab.
|
||||
|
||||
Feature: override-trade-tab
|
||||
|
||||
Property 1: Ticker validation and normalization
|
||||
Property 2: Override job payload completeness
|
||||
Property 3: Invalid override order rejection
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
from pydantic import ValidationError
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
import services.trading.app as _app_module
|
||||
from services.trading.app import OverrideOrderRequest, app
|
||||
|
||||
TICKER_PATTERN = re.compile(r"^[A-Z]{1,10}$")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_fake_engine(redis_rpush_side_effect=None):
|
||||
"""Create a fake engine object with a mock Redis client."""
|
||||
fake_engine = MagicMock()
|
||||
fake_engine.running = True
|
||||
fake_engine.redis = AsyncMock()
|
||||
fake_engine.redis.rpush = AsyncMock(side_effect=redis_rpush_side_effect)
|
||||
return fake_engine
|
||||
|
||||
|
||||
def _override_client(fake_engine):
|
||||
"""Return a TestClient with the module-level engine replaced."""
|
||||
original = _app_module.engine
|
||||
_app_module.engine = fake_engine
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
return client, original
|
||||
|
||||
|
||||
def _restore_engine(original):
|
||||
_app_module.engine = original
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Strategy for valid tickers: 1-10 uppercase alpha characters
|
||||
valid_ticker_st = st.from_regex(r"[A-Za-z]{1,10}", fullmatch=True)
|
||||
|
||||
# Strategy for valid sides
|
||||
valid_side_st = st.sampled_from(["buy", "sell"])
|
||||
|
||||
# Strategy for valid positive quantities
|
||||
valid_quantity_st = st.floats(min_value=0.01, max_value=1_000_000.0, allow_nan=False, allow_infinity=False)
|
||||
|
||||
# Strategy for valid order types
|
||||
valid_order_type_st = st.sampled_from(["market", "limit", "stop", "stop_limit"])
|
||||
|
||||
# Strategy for valid positive prices
|
||||
valid_price_st = st.floats(min_value=0.01, max_value=1_000_000.0, allow_nan=False, allow_infinity=False)
|
||||
|
||||
|
||||
@st.composite
|
||||
def valid_override_order_st(draw):
|
||||
"""Generate a valid override order request dict."""
|
||||
ticker = draw(valid_ticker_st)
|
||||
side = draw(valid_side_st)
|
||||
quantity = draw(valid_quantity_st)
|
||||
order_type = draw(valid_order_type_st)
|
||||
|
||||
order = {
|
||||
"ticker": ticker,
|
||||
"side": side,
|
||||
"quantity": quantity,
|
||||
"order_type": order_type,
|
||||
}
|
||||
|
||||
if order_type in ("limit", "stop_limit"):
|
||||
order["limit_price"] = draw(valid_price_st)
|
||||
if order_type in ("stop", "stop_limit"):
|
||||
order["stop_price"] = draw(valid_price_st)
|
||||
|
||||
return order
|
||||
|
||||
|
||||
@st.composite
|
||||
def invalid_override_order_st(draw):
|
||||
"""Generate an override order request that violates at least one validation rule.
|
||||
|
||||
Possible violations:
|
||||
- Invalid ticker format (digits, spaces, special chars, empty, too long)
|
||||
- Non-positive quantity (zero or negative)
|
||||
- Missing limit_price for limit/stop_limit orders
|
||||
- Missing stop_price for stop/stop_limit orders
|
||||
"""
|
||||
violation = draw(st.sampled_from([
|
||||
"bad_ticker",
|
||||
"non_positive_quantity",
|
||||
"missing_limit_price",
|
||||
"missing_stop_price",
|
||||
]))
|
||||
|
||||
if violation == "bad_ticker":
|
||||
# Generate a ticker that does NOT match ^[A-Z]{1,10}$ after uppercasing
|
||||
bad_ticker = draw(st.sampled_from([
|
||||
"", # empty
|
||||
"ABCDEFGHIJK", # 11 chars — too long
|
||||
draw(st.from_regex(r"[0-9]{1,5}", fullmatch=True)), # digits
|
||||
draw(st.from_regex(r"[A-Z]{1,5} [A-Z]{1,5}", fullmatch=True)), # spaces
|
||||
draw(st.from_regex(r"[A-Z]{1,5}[^A-Za-z0-9]", fullmatch=True)), # special char
|
||||
]))
|
||||
return {
|
||||
"ticker": bad_ticker,
|
||||
"side": "buy",
|
||||
"quantity": 10.0,
|
||||
"order_type": "market",
|
||||
}
|
||||
|
||||
if violation == "non_positive_quantity":
|
||||
qty = draw(st.floats(min_value=-1_000_000.0, max_value=0.0, allow_nan=False, allow_infinity=False))
|
||||
return {
|
||||
"ticker": "AAPL",
|
||||
"side": "buy",
|
||||
"quantity": qty,
|
||||
"order_type": "market",
|
||||
}
|
||||
|
||||
if violation == "missing_limit_price":
|
||||
order_type = draw(st.sampled_from(["limit", "stop_limit"]))
|
||||
order = {
|
||||
"ticker": "AAPL",
|
||||
"side": "buy",
|
||||
"quantity": 10.0,
|
||||
"order_type": order_type,
|
||||
}
|
||||
# For stop_limit, provide stop_price but NOT limit_price
|
||||
if order_type == "stop_limit":
|
||||
order["stop_price"] = 100.0
|
||||
return order
|
||||
|
||||
# violation == "missing_stop_price"
|
||||
order_type = draw(st.sampled_from(["stop", "stop_limit"]))
|
||||
order = {
|
||||
"ticker": "AAPL",
|
||||
"side": "buy",
|
||||
"quantity": 10.0,
|
||||
"order_type": order_type,
|
||||
}
|
||||
# For stop_limit, provide limit_price but NOT stop_price
|
||||
if order_type == "stop_limit":
|
||||
order["limit_price"] = 150.0
|
||||
return order
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 1: Ticker validation and normalization
|
||||
# **Validates: Requirements 2.2, 8.1**
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestProperty1TickerValidationAndNormalization:
|
||||
"""Property 1: Ticker validation and normalization.
|
||||
|
||||
For any string input, the ticker validation function SHALL accept it
|
||||
if and only if, after uppercasing, it matches ^[A-Z]{1,10}$.
|
||||
The normalized output SHALL always be the uppercased version of the input.
|
||||
|
||||
**Validates: Requirements 2.2, 8.1**
|
||||
"""
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(ticker=st.text(min_size=0, max_size=20))
|
||||
def test_ticker_accepted_iff_matches_pattern(self, ticker: str) -> None:
|
||||
"""After uppercasing, accepted iff matches ^[A-Z]{1,10}$;
|
||||
normalized output is always uppercased input."""
|
||||
uppercased = ticker.upper()
|
||||
should_accept = bool(TICKER_PATTERN.match(uppercased))
|
||||
|
||||
try:
|
||||
req = OverrideOrderRequest(
|
||||
ticker=ticker,
|
||||
side="buy",
|
||||
quantity=1.0,
|
||||
order_type="market",
|
||||
)
|
||||
# Accepted — verify it should have been accepted
|
||||
assert should_accept, (
|
||||
f"Ticker {ticker!r} was accepted but uppercased form "
|
||||
f"{uppercased!r} does not match ^[A-Z]{{1,10}}$"
|
||||
)
|
||||
# Normalized output is always uppercased
|
||||
assert req.ticker == uppercased, (
|
||||
f"Expected normalized ticker {uppercased!r}, got {req.ticker!r}"
|
||||
)
|
||||
except ValidationError:
|
||||
# Rejected — verify it should have been rejected
|
||||
assert not should_accept, (
|
||||
f"Ticker {ticker!r} was rejected but uppercased form "
|
||||
f"{uppercased!r} matches ^[A-Z]{{1,10}}$"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 2: Override job payload completeness
|
||||
# **Validates: Requirements 3.2, 9.1**
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestProperty2OverrideJobPayloadCompleteness:
|
||||
"""Property 2: Override job payload completeness.
|
||||
|
||||
For any valid override order request, the job payload enqueued to the
|
||||
broker queue SHALL contain all required fields, source == "manual_override",
|
||||
and idempotency_key starts with "override-".
|
||||
|
||||
**Validates: Requirements 3.2, 9.1**
|
||||
"""
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(order=valid_override_order_st())
|
||||
def test_enqueued_payload_has_all_required_fields(self, order: dict) -> None:
|
||||
"""Enqueued payload contains all required fields with correct values."""
|
||||
captured_payloads: list[str] = []
|
||||
|
||||
async def capture_rpush(key, value):
|
||||
captured_payloads.append(value)
|
||||
|
||||
fake_engine = _make_fake_engine()
|
||||
fake_engine.redis.rpush = AsyncMock(side_effect=capture_rpush)
|
||||
client, original = _override_client(fake_engine)
|
||||
|
||||
try:
|
||||
with patch.object(
|
||||
_app_module,
|
||||
"auto_register_symbol",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(False, "comp-1"),
|
||||
):
|
||||
import httpx
|
||||
|
||||
def _mock_handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(
|
||||
200,
|
||||
content=json.dumps([{"id": "comp-1", "ticker": order["ticker"].upper()}]).encode(),
|
||||
headers={"content-type": "application/json"},
|
||||
)
|
||||
|
||||
original_init = httpx.AsyncClient.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
kwargs.pop("timeout", None)
|
||||
kwargs["transport"] = httpx.MockTransport(_mock_handler)
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
with patch.object(httpx.AsyncClient, "__init__", patched_init):
|
||||
resp = client.post("/api/trading/override/order", json=order)
|
||||
|
||||
assert resp.status_code == 202, (
|
||||
f"Expected 202 for valid order {order!r}, got {resp.status_code}: {resp.text}"
|
||||
)
|
||||
assert len(captured_payloads) == 1, "Expected exactly one RPUSH call"
|
||||
|
||||
payload = json.loads(captured_payloads[0])
|
||||
|
||||
# All required fields present
|
||||
expected_ticker = order["ticker"].upper()
|
||||
assert payload["ticker"] == expected_ticker
|
||||
assert payload["side"] == order["side"]
|
||||
assert payload["quantity"] == order["quantity"]
|
||||
assert payload["order_type"] == order["order_type"]
|
||||
assert payload["source"] == "manual_override"
|
||||
assert isinstance(payload["idempotency_key"], str)
|
||||
assert payload["idempotency_key"].startswith("override-")
|
||||
|
||||
# Conditional price fields
|
||||
if "limit_price" in order:
|
||||
assert payload["limit_price"] == order["limit_price"]
|
||||
if "stop_price" in order:
|
||||
assert payload["stop_price"] == order["stop_price"]
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Property 3: Invalid override order rejection
|
||||
# **Validates: Requirements 3.5, 2.6**
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestProperty3InvalidOverrideOrderRejection:
|
||||
"""Property 3: Invalid override order rejection.
|
||||
|
||||
For any override order request that violates at least one validation rule,
|
||||
the endpoint SHALL return a 422 status code and the response body SHALL
|
||||
contain at least one descriptive error message.
|
||||
|
||||
**Validates: Requirements 3.5, 2.6**
|
||||
"""
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(order=invalid_override_order_st())
|
||||
def test_invalid_order_returns_422_with_error_message(self, order: dict) -> None:
|
||||
"""Invalid orders return 422 with at least one descriptive error."""
|
||||
fake_engine = _make_fake_engine()
|
||||
client, original = _override_client(fake_engine)
|
||||
|
||||
try:
|
||||
resp = client.post("/api/trading/override/order", json=order)
|
||||
|
||||
assert resp.status_code == 422, (
|
||||
f"Expected 422 for invalid order {order!r}, got {resp.status_code}: {resp.text}"
|
||||
)
|
||||
|
||||
body = resp.json()
|
||||
# FastAPI returns validation errors in a "detail" field
|
||||
assert "detail" in body, (
|
||||
f"Expected 'detail' in 422 response body, got: {body}"
|
||||
)
|
||||
# At least one error message
|
||||
detail = body["detail"]
|
||||
assert isinstance(detail, list) and len(detail) >= 1, (
|
||||
f"Expected at least one validation error, got: {detail}"
|
||||
)
|
||||
finally:
|
||||
_restore_engine(original)
|
||||
Reference in New Issue
Block a user