Files
stonks-oracle/tests/test_override.py
T
Celes Renata 913fe8b0b3 feat: override trade tab — manual order entry with auto-registration
Backend:
- OverrideOrderRequest/Response Pydantic models with ticker, quantity, price validators
- POST /api/trading/override/order endpoint (enqueue to Redis broker queue)
- auto_register_symbol() module for untracked ticker registration via Symbol Registry
- Unit tests (17) and property-based tests (3 x 100 examples)

Frontend:
- OverrideTradePanel component (order form + positions display)
- Override tab in TradingEngine page with URL search param navigation
- Override Trade button on Trading Controls page
- useSubmitOverrideOrder mutation hook
- MSW handler and 13 component/integration tests

Steering:
- Updated steering docs for Ubuntu dev machine with nvm/Node 24
2026-04-17 07:02:30 +00:00

813 lines
30 KiB
Python

"""Unit tests for auto_register_symbol() in services/trading/override.py.
Validates:
- New symbol registration (company + sources + watchlist)
- Existing symbol skip
- 409 conflict handling
- Source/watchlist failure tolerance
Requirements: 4.1, 4.2, 4.3, 4.5, 4.6
"""
from __future__ import annotations
import json
from contextlib import asynccontextmanager
from unittest.mock import patch
import httpx
import pytest
from services.trading.override import auto_register_symbol
REGISTRY_BASE = "http://symbol-registry:8000"
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _json_response(status_code: int, json_data) -> httpx.Response:
    """Create a JSON ``httpx.Response`` suitable as a MockTransport reply.

    ``json_data`` is serialized with ``json.dumps``, encoded to bytes, and
    sent together with an ``application/json`` content-type header.
    """
    body = json.dumps(json_data).encode()
    json_headers = {"content-type": "application/json"}
    return httpx.Response(status_code, content=body, headers=json_headers)
def _text_response(status_code: int, text: str = "") -> httpx.Response:
    """Create a text-bodied ``httpx.Response`` suitable as a MockTransport reply.

    ``text`` defaults to the empty string, which is what the handlers in this
    module use for their catch-all 404 replies.
    """
    reply = httpx.Response(status_code, text=text)
    return reply
def _patched_client(handler):
    """Build a patcher that forces every new ``httpx.AsyncClient`` onto a mock transport.

    ``handler`` is a sync callable ``(request: httpx.Request) -> httpx.Response``
    that is handed to ``httpx.MockTransport``. The returned ``patch.object``
    patcher is not started here; use it as a context manager.
    """
    real_init = httpx.AsyncClient.__init__

    def _init_with_mock_transport(self, *args, **kwargs):
        # Discard any timeout the code under test passed in.
        kwargs.pop("timeout", None)
        # Replace (or add) the transport so requests go to ``handler``.
        kwargs.update(transport=httpx.MockTransport(handler))
        real_init(self, *args, **kwargs)

    return patch.object(httpx.AsyncClient, "__init__", _init_with_mock_transport)
# ---------------------------------------------------------------------------
# 1. New symbol registration — full happy path
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestAutoRegisterNewSymbol:
    """Requirement 4.1, 4.2, 4.3: a new symbol gets a company, sources and a watchlist membership."""

    async def test_new_symbol_full_registration(self):
        """Empty GET /companies, then 201 for company, sources and membership → (True, id)."""
        new_company_id = "aaaaaaaa-1111-2222-3333-444444444444"

        def registry(request: httpx.Request) -> httpx.Response:
            target = str(request.url)
            is_get = request.method == "GET"
            is_post = request.method == "POST"
            # Company lookup: nothing registered yet.
            if is_get and "/companies" in target and "/watchlists" not in target:
                return _json_response(200, [])
            # Company creation.
            if is_post and target.endswith("/companies"):
                return _json_response(201, {"id": new_company_id, "ticker": "TSLA"})
            # Source creation.
            if is_post and "/sources" in target:
                return _json_response(201, {"id": "src-1"})
            # Watchlist listing: one active watchlist is available.
            if is_get and "/watchlists" in target:
                return _json_response(200, [{"id": "wl-1", "name": "Default", "active": True}])
            # Watchlist membership.
            if is_post and "/members/" in target:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(registry):
            registered, returned_id = await auto_register_symbol("TSLA", REGISTRY_BASE)

        assert registered is True
        assert returned_id == new_company_id

    async def test_new_symbol_creates_watchlist_when_none_exist(self):
        """An empty watchlist listing must lead to a POST on /watchlists itself."""
        new_company_id = "bbbbbbbb-1111-2222-3333-444444444444"
        created_watchlist_id = "wl-new-1"
        seen_requests: list[str] = []

        def registry(request: httpx.Request) -> httpx.Response:
            target = str(request.url)
            verb = request.method
            seen_requests.append(f"{verb} {target}")
            is_get = verb == "GET"
            is_post = verb == "POST"
            if is_get and "/companies" in target and "/watchlists" not in target:
                return _json_response(200, [])
            if is_post and target.endswith("/companies"):
                return _json_response(201, {"id": new_company_id, "ticker": "NVDA"})
            if is_post and "/sources" in target:
                return _json_response(201, {"id": "src-1"})
            if is_get and "/watchlists" in target:
                # No active watchlists
                return _json_response(200, [])
            if is_post and target.endswith("/watchlists"):
                return _json_response(201, {"id": created_watchlist_id, "name": "Manual Overrides"})
            if is_post and "/members/" in target:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(registry):
            registered, returned_id = await auto_register_symbol("NVDA", REGISTRY_BASE)

        assert registered is True
        assert returned_id == new_company_id
        # At least one POST must have targeted a watchlist URL that is not a members URL.
        watchlist_creations = [
            entry
            for entry in seen_requests
            if "POST" in entry and "/watchlists" in entry and "/members" not in entry
        ]
        assert watchlist_creations
# ---------------------------------------------------------------------------
# 2. Existing symbol skip
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestAutoRegisterExistingSymbol:
    """Requirement 4.5: symbols that are already registered skip registration."""

    async def test_existing_symbol_returns_false(self):
        """GET /companies already lists the ticker → (False, company_id)."""
        known_id = "cccccccc-1111-2222-3333-444444444444"
        listing = [
            {"id": known_id, "ticker": "AAPL", "legal_name": "Apple Inc"},
            {"id": "other-id", "ticker": "MSFT", "legal_name": "Microsoft"},
        ]

        def registry(request: httpx.Request) -> httpx.Response:
            target = str(request.url)
            if request.method == "GET" and "/companies" in target:
                return _json_response(200, listing)
            return _text_response(404)

        with _patched_client(registry):
            registered, returned_id = await auto_register_symbol("AAPL", REGISTRY_BASE)

        assert registered is False
        assert returned_id == known_id

    async def test_existing_symbol_no_post_calls(self):
        """No POST request may be issued when the ticker is already listed."""
        known_id = "dddddddd-1111-2222-3333-444444444444"
        post_calls: list[str] = []

        def registry(request: httpx.Request) -> httpx.Response:
            target = str(request.url)
            verb = request.method
            if verb == "POST":
                # Record the offending call; it still falls through to the 404 below.
                post_calls.append(target)
            if verb == "GET" and "/companies" in target:
                return _json_response(200, [{"id": known_id, "ticker": "GOOG"}])
            return _text_response(404)

        with _patched_client(registry):
            await auto_register_symbol("GOOG", REGISTRY_BASE)

        assert post_calls == [], f"Unexpected POST calls: {post_calls}"
# ---------------------------------------------------------------------------
# 3. 409 conflict handling
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestAutoRegister409Conflict:
    """Requirement 4.6: a 409 conflict on company creation is treated as success."""

    async def test_409_conflict_fetches_existing_company(self):
        """POST /companies answers 409, the follow-up lookup finds the company → its id is returned."""
        existing_id = "eeeeeeee-1111-2222-3333-444444444444"
        company_lookups: list[str] = []

        def registry(request: httpx.Request) -> httpx.Response:
            target = str(request.url)
            is_get = request.method == "GET"
            is_post = request.method == "POST"
            if is_get and "/companies" in target and "/watchlists" not in target:
                company_lookups.append(target)
                # First lookup: empty. Every later lookup: the company exists.
                if len(company_lookups) == 1:
                    return _json_response(200, [])
                return _json_response(200, [{"id": existing_id, "ticker": "AMD"}])
            if is_post and target.endswith("/companies"):
                return _json_response(409, {"detail": "Company already exists"})
            if is_post and "/sources" in target:
                return _json_response(201, {"id": "src-1"})
            if is_get and "/watchlists" in target:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if is_post and "/members/" in target:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(registry):
            registered, returned_id = await auto_register_symbol("AMD", REGISTRY_BASE)

        assert registered is True
        assert returned_id == existing_id

    async def test_409_conflict_but_cannot_find_company_returns_false(self):
        """POST /companies answers 409 and every lookup stays empty → (False, '')."""

        def registry(request: httpx.Request) -> httpx.Response:
            target = str(request.url)
            verb = request.method
            if verb == "GET" and "/companies" in target:
                return _json_response(200, [])
            if verb == "POST" and target.endswith("/companies"):
                return _json_response(409, {"detail": "Conflict"})
            return _text_response(404)

        with _patched_client(registry):
            registered, returned_id = await auto_register_symbol("XYZ", REGISTRY_BASE)

        assert registered is False
        assert returned_id == ""
# ---------------------------------------------------------------------------
# 4. Source creation failure tolerance
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestSourceFailureTolerance:
    """Source creation is best-effort: its failures must not block registration."""

    async def test_source_creation_500_still_succeeds(self):
        """A 500 from the sources endpoint still yields (True, company_id)."""
        new_company_id = "ffffffff-1111-2222-3333-444444444444"

        def registry(request: httpx.Request) -> httpx.Response:
            target = str(request.url)
            is_get = request.method == "GET"
            is_post = request.method == "POST"
            if is_get and "/companies" in target and "/watchlists" not in target:
                return _json_response(200, [])
            if is_post and target.endswith("/companies"):
                return _json_response(201, {"id": new_company_id, "ticker": "PLTR"})
            if is_post and "/sources" in target:
                return _text_response(500, "Internal Server Error")
            if is_get and "/watchlists" in target:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if is_post and "/members/" in target:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(registry):
            registered, returned_id = await auto_register_symbol("PLTR", REGISTRY_BASE)

        assert registered is True
        assert returned_id == new_company_id

    async def test_source_creation_network_error_still_succeeds(self):
        """A connection error on every source POST still yields (True, company_id)."""
        new_company_id = "11111111-aaaa-bbbb-cccc-dddddddddddd"
        source_attempts: list[str] = []

        def registry(request: httpx.Request) -> httpx.Response:
            target = str(request.url)
            is_get = request.method == "GET"
            is_post = request.method == "POST"
            if is_get and "/companies" in target and "/watchlists" not in target:
                return _json_response(200, [])
            if is_post and target.endswith("/companies"):
                return _json_response(201, {"id": new_company_id, "ticker": "RIVN"})
            if is_post and "/sources" in target:
                # Record the attempt, then simulate the network failure.
                source_attempts.append(target)
                raise httpx.ConnectError("connection refused")
            if is_get and "/watchlists" in target:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if is_post and "/members/" in target:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(registry):
            registered, returned_id = await auto_register_symbol("RIVN", REGISTRY_BASE)

        assert registered is True
        assert returned_id == new_company_id
        # Exactly two source POSTs are expected despite each one failing.
        assert len(source_attempts) == 2
# ---------------------------------------------------------------------------
# 5. Watchlist failure tolerance
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestWatchlistFailureTolerance:
    """Watchlist failures are best-effort — should not block registration."""

    async def test_watchlist_get_fails_still_succeeds(self):
        """GET /watchlists raises a connection error → still returns (True, company_id)."""
        company_id = "22222222-aaaa-bbbb-cccc-dddddddddddd"

        def handler(request: httpx.Request) -> httpx.Response:
            url = str(request.url)
            method = request.method
            if method == "GET" and "/companies" in url and "/watchlists" not in url:
                return _json_response(200, [])
            if method == "POST" and url.endswith("/companies"):
                return _json_response(201, {"id": company_id, "ticker": "SOFI"})
            if method == "POST" and "/sources" in url:
                return _json_response(201, {"id": "src-1"})
            if method == "GET" and "/watchlists" in url:
                # Simulate the registry being unreachable for the watchlist listing.
                raise httpx.ConnectError("connection refused")
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("SOFI", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == company_id

    async def test_watchlist_member_add_fails_still_succeeds(self):
        """POST /watchlists/{id}/members/{cid} returns 500 → still succeeds."""
        company_id = "33333333-aaaa-bbbb-cccc-dddddddddddd"

        def handler(request: httpx.Request) -> httpx.Response:
            url = str(request.url)
            method = request.method
            if method == "GET" and "/companies" in url and "/watchlists" not in url:
                return _json_response(200, [])
            if method == "POST" and url.endswith("/companies"):
                return _json_response(201, {"id": company_id, "ticker": "COIN"})
            if method == "POST" and "/sources" in url:
                return _json_response(201, {"id": "src-1"})
            if method == "GET" and "/watchlists" in url:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if method == "POST" and "/members/" in url:
                return _text_response(500, "Internal Server Error")
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("COIN", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == company_id

    # NOTE(review): this test covers ticker normalization, not watchlist failure
    # tolerance. It is kept in this class so its test ID does not change.
    async def test_ticker_is_uppercased(self):
        """Lowercase ticker input is normalized to uppercase in the POST /companies body."""
        company_id = "44444444-aaaa-bbbb-cccc-dddddddddddd"
        received_tickers: list[str] = []

        def handler(request: httpx.Request) -> httpx.Response:
            url = str(request.url)
            method = request.method
            if method == "GET" and "/companies" in url and "/watchlists" not in url:
                return _json_response(200, [])
            if method == "POST" and url.endswith("/companies"):
                # Capture the ticker exactly as the code under test sent it.
                body = json.loads(request.content)
                received_tickers.append(body.get("ticker", ""))
                return _json_response(201, {"id": company_id, "ticker": body.get("ticker", "")})
            if method == "POST" and "/sources" in url:
                return _json_response(201, {"id": "src-1"})
            if method == "GET" and "/watchlists" in url:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if method == "POST" and "/members/" in url:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("tsla", REGISTRY_BASE)

        assert auto_registered is True
        # Previously unpacked but never checked: the id from the 201 reply must be returned.
        assert cid == company_id
        assert received_tickers == ["TSLA"]
# ===========================================================================
# Override Endpoint Tests (Task 3.1)
# ===========================================================================
#
# Unit tests for POST /api/trading/override/order endpoint.
# Requirements: 3.1, 3.2, 3.4, 3.5, 9.1
from unittest.mock import AsyncMock, MagicMock, patch as _patch
from starlette.testclient import TestClient
import services.trading.app as _app_module
from services.trading.app import app
def _make_fake_engine(redis_rpush_side_effect=None):
    """Build a stand-in engine whose ``redis.rpush`` is an ``AsyncMock``.

    ``redis_rpush_side_effect`` is forwarded as the ``side_effect`` of the
    ``rpush`` mock. The returned object is a ``MagicMock`` with ``running``
    set to ``True`` and ``redis`` set to an ``AsyncMock``.
    """
    rpush_mock = AsyncMock(side_effect=redis_rpush_side_effect)
    redis_mock = AsyncMock()
    redis_mock.rpush = rpush_mock

    engine_stub = MagicMock()
    engine_stub.running = True
    engine_stub.redis = redis_mock
    return engine_stub
def _override_client(fake_engine):
    """Install ``fake_engine`` as the app module's ``engine`` and build a TestClient.

    Returns ``(client, previous_engine)``. The caller must hand
    ``previous_engine`` back to ``_restore_engine`` when done. The client is
    created with ``raise_server_exceptions=False``.
    """
    previous_engine = _app_module.engine
    _app_module.engine = fake_engine
    return TestClient(app, raise_server_exceptions=False), previous_engine
def _restore_engine(original):
    """Put ``original`` back as the app module's ``engine`` (undoes ``_override_client``)."""
    setattr(_app_module, "engine", original)
# ---------------------------------------------------------------------------
# 6. Valid order returns 202 with correct response shape
# ---------------------------------------------------------------------------
class TestOverrideEndpointValid:
    """Requirement 3.1, 3.4: a valid order is answered with 202 and the expected response shape."""

    @staticmethod
    def _post_order(order, companies):
        """POST ``order`` with a fake engine and a patched ``auto_register_symbol``.

        ``auto_register_symbol`` is replaced by an ``AsyncMock`` returning
        ``(False, "comp-1")``; any ``httpx.AsyncClient`` created meanwhile
        answers 200 with ``companies``. The real engine is always restored.
        """
        client, original = _override_client(_make_fake_engine())
        try:
            register_patch = _patch.object(
                _app_module,
                "auto_register_symbol",
                new_callable=AsyncMock,
                return_value=(False, "comp-1"),
            )
            registry_patch = _patched_client(lambda req: _json_response(200, companies))
            with register_patch, registry_patch:
                return client.post("/api/trading/override/order", json=order)
        finally:
            _restore_engine(original)

    def test_valid_market_order_returns_202(self):
        """A market order yields 202 with job_id, status, ticker, side, quantity, auto_registered."""
        market_order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "market",
        }
        resp = self._post_order(market_order, [{"id": "comp-1", "ticker": "AAPL"}])

        assert resp.status_code == 202
        body = resp.json()
        assert "job_id" in body
        assert body["status"] == "queued"
        assert body["ticker"] == "AAPL"
        assert body["side"] == "buy"
        assert body["quantity"] == 10.0
        assert isinstance(body["auto_registered"], bool)
        assert body["job_id"].startswith("override-")

    def test_valid_limit_order_returns_202(self):
        """A limit order that carries limit_price yields 202."""
        limit_order = {
            "ticker": "TSLA",
            "side": "sell",
            "quantity": 5.0,
            "order_type": "limit",
            "limit_price": 150.0,
        }
        resp = self._post_order(limit_order, [{"id": "comp-1", "ticker": "TSLA"}])

        assert resp.status_code == 202
        body = resp.json()
        assert body["ticker"] == "TSLA"
        assert body["side"] == "sell"
# ---------------------------------------------------------------------------
# 7. Invalid ticker returns 422
# ---------------------------------------------------------------------------
class TestOverrideEndpointInvalidTicker:
    """Requirement 3.5: a malformed ticker is rejected with 422."""

    @staticmethod
    def _status_for_ticker(ticker):
        """POST a one-share buy order for ``ticker`` and return the HTTP status code."""
        client, original = _override_client(_make_fake_engine())
        try:
            resp = client.post(
                "/api/trading/override/order",
                json={"ticker": ticker, "side": "buy", "quantity": 1.0},
            )
            return resp.status_code
        finally:
            _restore_engine(original)

    def test_numeric_ticker_returns_422(self):
        """A digits-only ticker is rejected."""
        assert self._status_for_ticker("123") == 422

    def test_empty_ticker_returns_422(self):
        """An empty ticker is rejected."""
        assert self._status_for_ticker("") == 422

    def test_ticker_with_spaces_returns_422(self):
        """A ticker containing a space is rejected."""
        assert self._status_for_ticker("AA BB") == 422

    def test_ticker_too_long_returns_422(self):
        """An eleven-character ticker is rejected."""
        assert self._status_for_ticker("ABCDEFGHIJK") == 422
# ---------------------------------------------------------------------------
# 8. Missing limit_price for limit order returns 422
# ---------------------------------------------------------------------------
class TestOverrideEndpointMissingLimitPrice:
    """Requirement 3.5: limit/stop_limit orders without limit_price are rejected with 422."""

    @staticmethod
    def _status_for_order(order):
        """POST ``order`` against a fake engine and return the HTTP status code."""
        client, original = _override_client(_make_fake_engine())
        try:
            return client.post("/api/trading/override/order", json=order).status_code
        finally:
            _restore_engine(original)

    def test_limit_order_without_limit_price_returns_422(self):
        """A limit order with no limit_price is rejected."""
        order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "limit",
        }
        assert self._status_for_order(order) == 422

    def test_stop_limit_order_without_limit_price_returns_422(self):
        """A stop_limit order that only carries stop_price is rejected."""
        order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "stop_limit",
            "stop_price": 100.0,
        }
        assert self._status_for_order(order) == 422
# ---------------------------------------------------------------------------
# 9. Missing stop_price for stop order returns 422
# ---------------------------------------------------------------------------
class TestOverrideEndpointMissingStopPrice:
    """Requirement 3.5: stop/stop_limit orders without stop_price are rejected with 422."""

    @staticmethod
    def _status_for_order(order):
        """POST ``order`` against a fake engine and return the HTTP status code."""
        client, original = _override_client(_make_fake_engine())
        try:
            return client.post("/api/trading/override/order", json=order).status_code
        finally:
            _restore_engine(original)

    def test_stop_order_without_stop_price_returns_422(self):
        """A stop order with no stop_price is rejected."""
        order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "stop",
        }
        assert self._status_for_order(order) == 422

    def test_stop_limit_order_without_stop_price_returns_422(self):
        """A stop_limit order that only carries limit_price is rejected."""
        order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "stop_limit",
            "limit_price": 150.0,
        }
        assert self._status_for_order(order) == 422
# ---------------------------------------------------------------------------
# 10. Non-positive quantity returns 422
# ---------------------------------------------------------------------------
class TestOverrideEndpointNonPositiveQuantity:
    """Requirement 3.5: a quantity that is zero or negative is rejected with 422."""

    @staticmethod
    def _status_for_quantity(quantity):
        """POST an AAPL buy order with ``quantity`` and return the HTTP status code."""
        client, original = _override_client(_make_fake_engine())
        try:
            resp = client.post(
                "/api/trading/override/order",
                json={"ticker": "AAPL", "side": "buy", "quantity": quantity},
            )
            return resp.status_code
        finally:
            _restore_engine(original)

    def test_zero_quantity_returns_422(self):
        """A quantity of 0.0 is rejected."""
        assert self._status_for_quantity(0.0) == 422

    def test_negative_quantity_returns_422(self):
        """A quantity of -5.0 is rejected."""
        assert self._status_for_quantity(-5.0) == 422
# ---------------------------------------------------------------------------
# 11. Enqueued job has correct structure and source: "manual_override"
# ---------------------------------------------------------------------------
class TestOverrideEndpointJobStructure:
    """Requirement 3.2, 9.1: Enqueued job has correct structure and source marker."""

    @staticmethod
    def _submit_and_capture(order, register_result, companies):
        """POST ``order`` and return ``(response, captured_payloads)``.

        ``captured_payloads`` collects the value argument of every
        ``redis.rpush`` call made on the fake engine. ``register_result`` is the
        return value of the patched ``auto_register_symbol``; ``companies`` is
        the JSON body served (status 200) to any ``httpx.AsyncClient`` created
        while the request runs. The real engine is always restored.
        """
        captured_payloads: list[str] = []

        async def capture_rpush(key, value):
            captured_payloads.append(value)

        # Use the helper's side-effect hook instead of overwriting rpush by hand.
        fake_engine = _make_fake_engine(redis_rpush_side_effect=capture_rpush)
        client, original = _override_client(fake_engine)
        try:
            register_patch = _patch.object(
                _app_module,
                "auto_register_symbol",
                new_callable=AsyncMock,
                return_value=register_result,
            )
            registry_patch = _patched_client(lambda req: _json_response(200, companies))
            with register_patch, registry_patch:
                resp = client.post("/api/trading/override/order", json=order)
        finally:
            _restore_engine(original)
        return resp, captured_payloads

    def test_enqueued_job_has_correct_fields(self):
        """The JSON payload RPUSH'd to Redis contains all required fields."""
        resp, captured_payloads = self._submit_and_capture(
            {
                "ticker": "MSFT",
                "side": "buy",
                "quantity": 25.0,
                "order_type": "market",
            },
            register_result=(False, "comp-1"),
            companies=[{"id": "comp-1", "ticker": "MSFT"}],
        )

        assert resp.status_code == 202
        assert len(captured_payloads) == 1
        payload = json.loads(captured_payloads[0])
        assert payload["ticker"] == "MSFT"
        assert payload["side"] == "buy"
        assert payload["quantity"] == 25.0
        assert payload["order_type"] == "market"
        assert payload["source"] == "manual_override"
        assert payload["idempotency_key"].startswith("override-")
        assert payload["limit_price"] is None
        assert payload["stop_price"] is None

    def test_enqueued_job_includes_limit_and_stop_prices(self):
        """When limit_price and stop_price are provided, they appear in the payload."""
        resp, captured_payloads = self._submit_and_capture(
            {
                "ticker": "GOOG",
                "side": "sell",
                "quantity": 3.0,
                "order_type": "stop_limit",
                "limit_price": 140.0,
                "stop_price": 135.0,
            },
            register_result=(True, "comp-2"),
            companies=[],
        )

        assert resp.status_code == 202
        assert len(captured_payloads) == 1
        payload = json.loads(captured_payloads[0])
        assert payload["source"] == "manual_override"
        assert payload["limit_price"] == 140.0
        assert payload["stop_price"] == 135.0

    def test_job_id_matches_response(self):
        """The job_id in the response matches the idempotency_key in the enqueued payload."""
        resp, captured_payloads = self._submit_and_capture(
            {
                "ticker": "AMD",
                "side": "buy",
                "quantity": 1.0,
                "order_type": "market",
            },
            register_result=(False, "comp-1"),
            companies=[{"id": "comp-1", "ticker": "AMD"}],
        )

        assert resp.status_code == 202
        data = resp.json()
        # Fail with a clear assertion, not an IndexError, if nothing was enqueued.
        assert len(captured_payloads) == 1
        payload = json.loads(captured_payloads[0])
        assert data["job_id"] == payload["idempotency_key"]