c85c0068a2
- Replace all datetime.utcnow() with datetime.now(tz=timezone.utc) across 8 files - Fix 12 failing tests to match current implementation behavior - Fix pytest_plugins in non-top-level conftest (moved to root conftest.py) - Auto-fix 189 lint issues (import sorting, unused imports) - Add CI/CD pipeline infrastructure (ARC, ArgoCD, Kargo manifests) - Add values-beta.yaml and values-paper.yaml for staged deployments - Update GitHub Actions workflow to use self-hosted-gremlin runners - Add integration-test job to CI pipeline Result: 1596 passed, 0 failed, 0 warnings
813 lines
30 KiB
Python
813 lines
30 KiB
Python
"""Unit tests for auto_register_symbol() in services/trading/override.py.
|
|
|
|
Validates:
|
|
- New symbol registration (company + sources + watchlist)
|
|
- Existing symbol skip
|
|
- 409 conflict handling
|
|
- Source/watchlist failure tolerance
|
|
|
|
Requirements: 4.1, 4.2, 4.3, 4.5, 4.6
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from unittest.mock import patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from services.trading.override import auto_register_symbol
|
|
|
|
# Base URL passed to auto_register_symbol() in the tests below. The mock
# handlers match requests on URL path fragments only, never on this host.
REGISTRY_BASE: str = "http://symbol-registry:8000"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _json_response(status_code: int, json_data) -> httpx.Response:
    """Return a MockTransport-ready response whose body is *json_data* as JSON.

    The payload is serialized with the stdlib ``json`` module (default
    separators), encoded to bytes, and tagged ``application/json``.
    """
    body = json.dumps(json_data).encode()
    headers = {"content-type": "application/json"}
    return httpx.Response(status_code, content=body, headers=headers)
|
|
|
|
|
|
def _text_response(status_code: int, text: str = "") -> httpx.Response:
    """Return an ``httpx.Response`` carrying *text* as its body, for MockTransport handlers.

    With the default empty ``text`` this is effectively a status-only response.
    """
    return httpx.Response(status_code=status_code, text=text)
|
|
|
|
|
|
def _patched_client(handler):
    """Patch ``httpx.AsyncClient.__init__`` so each new client talks to *handler*.

    *handler* is a synchronous callable ``(request: httpx.Request) -> httpx.Response``.
    While the returned patch is active, every ``httpx.AsyncClient`` built gets a
    fresh ``httpx.MockTransport(handler)`` as its ``transport`` keyword. Any
    ``timeout`` keyword passed by the caller is dropped, and any caller-supplied
    ``transport`` is overridden.

    Returns the (not yet entered) ``patch.object`` context manager.
    """
    real_init = httpx.AsyncClient.__init__

    def _init_with_mock_transport(self, *args, **kwargs):
        kwargs.pop("timeout", None)
        kwargs["transport"] = httpx.MockTransport(handler)
        real_init(self, *args, **kwargs)

    return patch.object(httpx.AsyncClient, "__init__", _init_with_mock_transport)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 1. New symbol registration — full happy path
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
class TestAutoRegisterNewSymbol:
    """Requirement 4.1, 4.2, 4.3: New symbol creates company, sources, watchlist membership."""

    async def test_new_symbol_full_registration(self):
        """GET /companies returns empty → POST /companies 201 → sources → watchlist → (True, id)."""
        company_id = "aaaaaaaa-1111-2222-3333-444444444444"

        def handler(request: httpx.Request) -> httpx.Response:
            # Route on method + URL fragment. The company lookup explicitly
            # excludes watchlist URLs so the later watchlist branch is reachable.
            url = str(request.url)
            method = request.method

            if method == "GET" and "/companies" in url and "/watchlists" not in url:
                return _json_response(200, [])  # ticker not registered yet
            if method == "POST" and url.endswith("/companies"):
                return _json_response(201, {"id": company_id, "ticker": "TSLA"})
            if method == "POST" and "/sources" in url:
                return _json_response(201, {"id": "src-1"})
            if method == "GET" and "/watchlists" in url:
                return _json_response(200, [{"id": "wl-1", "name": "Default", "active": True}])
            if method == "POST" and "/members/" in url:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("TSLA", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == company_id

    async def test_new_symbol_creates_watchlist_when_none_exist(self):
        """When no active watchlists exist, creates 'Manual Overrides' watchlist."""
        company_id = "bbbbbbbb-1111-2222-3333-444444444444"
        watchlist_id = "wl-new-1"
        # Record (method, url) pairs so the assertion can compare the method
        # exactly instead of substring-searching a combined "METHOD url" string.
        call_log: list[tuple[str, str]] = []

        def handler(request: httpx.Request) -> httpx.Response:
            url = str(request.url)
            method = request.method
            call_log.append((method, url))

            if method == "GET" and "/companies" in url and "/watchlists" not in url:
                return _json_response(200, [])
            if method == "POST" and url.endswith("/companies"):
                return _json_response(201, {"id": company_id, "ticker": "NVDA"})
            if method == "POST" and "/sources" in url:
                return _json_response(201, {"id": "src-1"})
            if method == "GET" and "/watchlists" in url:
                return _json_response(200, [])  # No active watchlists
            if method == "POST" and url.endswith("/watchlists"):
                return _json_response(201, {"id": watchlist_id, "name": "Manual Overrides"})
            if method == "POST" and "/members/" in url:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("NVDA", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == company_id
        # Use the same predicate the handler uses to serve watchlist creation,
        # so a creation request that fell through to the 404 branch is caught.
        assert any(
            req_method == "POST" and req_url.endswith("/watchlists")
            for req_method, req_url in call_log
        ), f"No watchlist-creation POST recorded: {call_log}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 2. Existing symbol skip
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
class TestAutoRegisterExistingSymbol:
    """Requirement 4.5: a ticker already known to the registry is not re-registered."""

    async def test_existing_symbol_returns_false(self):
        """The company lookup already lists the ticker → (False, existing company id)."""
        existing_id = "cccccccc-1111-2222-3333-444444444444"
        known_companies = [
            {"id": existing_id, "ticker": "AAPL", "legal_name": "Apple Inc"},
            {"id": "other-id", "ticker": "MSFT", "legal_name": "Microsoft"},
        ]

        def handler(request: httpx.Request) -> httpx.Response:
            is_company_lookup = request.method == "GET" and "/companies" in str(request.url)
            if not is_company_lookup:
                return _text_response(404)
            return _json_response(200, known_companies)

        with _patched_client(handler):
            outcome = await auto_register_symbol("AAPL", REGISTRY_BASE)

        auto_registered, cid = outcome
        assert auto_registered is False
        assert cid == existing_id

    async def test_existing_symbol_no_post_calls(self):
        """An already-registered ticker must not trigger any POST request."""
        existing_id = "dddddddd-1111-2222-3333-444444444444"
        post_calls: list[str] = []

        def handler(request: httpx.Request) -> httpx.Response:
            url = str(request.url)
            method = request.method
            if method == "POST":
                post_calls.append(url)
            elif method == "GET" and "/companies" in url:
                return _json_response(200, [{"id": existing_id, "ticker": "GOOG"}])
            return _text_response(404)

        with _patched_client(handler):
            await auto_register_symbol("GOOG", REGISTRY_BASE)

        assert post_calls == [], f"Unexpected POST calls: {post_calls}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 3. 409 conflict handling
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
class TestAutoRegister409Conflict:
    """Requirement 4.6: a 409 from POST /companies is treated as success."""

    async def test_409_conflict_fetches_existing_company(self):
        """POST /companies → 409, then the re-fetch finds the company → (True, existing id)."""
        existing_id = "eeeeeeee-1111-2222-3333-444444444444"
        lookups = 0  # company-lookup GETs served so far

        def handler(request: httpx.Request) -> httpx.Response:
            nonlocal lookups
            method, url = request.method, str(request.url)

            if method == "GET" and "/companies" in url and "/watchlists" not in url:
                lookups += 1
                # The first lookup reports the ticker as unknown so registration
                # proceeds to POST; every later lookup (the post-409 re-fetch)
                # returns the existing company.
                companies = [] if lookups == 1 else [{"id": existing_id, "ticker": "AMD"}]
                return _json_response(200, companies)
            if method == "POST" and url.endswith("/companies"):
                return _json_response(409, {"detail": "Company already exists"})
            if method == "POST" and "/sources" in url:
                return _json_response(201, {"id": "src-1"})
            if method == "GET" and "/watchlists" in url:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if method == "POST" and "/members/" in url:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            outcome = await auto_register_symbol("AMD", REGISTRY_BASE)

        auto_registered, cid = outcome
        assert auto_registered is True
        assert cid == existing_id

    async def test_409_conflict_but_cannot_find_company_returns_false(self):
        """POST /companies → 409 but the re-fetch still finds nothing → (False, '')."""

        def handler(request: httpx.Request) -> httpx.Response:
            method, url = request.method, str(request.url)
            if "/companies" in url and method == "GET":
                return _json_response(200, [])
            if url.endswith("/companies") and method == "POST":
                return _json_response(409, {"detail": "Conflict"})
            return _text_response(404)

        with _patched_client(handler):
            outcome = await auto_register_symbol("XYZ", REGISTRY_BASE)

        auto_registered, cid = outcome
        assert auto_registered is False
        assert cid == ""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 4. Source creation failure tolerance
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
class TestSourceFailureTolerance:
    """Source creation is best-effort: its failures must not block registration."""

    async def test_source_creation_500_still_succeeds(self):
        """A 500 from POST /companies/{id}/sources still yields (True, company_id)."""
        company_id = "ffffffff-1111-2222-3333-444444444444"

        def handler(request: httpx.Request) -> httpx.Response:
            method, url = request.method, str(request.url)

            if method == "GET" and "/companies" in url and "/watchlists" not in url:
                return _json_response(200, [])
            if method == "POST" and url.endswith("/companies"):
                return _json_response(201, {"id": company_id, "ticker": "PLTR"})
            if method == "POST" and "/sources" in url:
                # Every source-creation call fails server-side.
                return _text_response(500, "Internal Server Error")
            if method == "GET" and "/watchlists" in url:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if method == "POST" and "/members/" in url:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            outcome = await auto_register_symbol("PLTR", REGISTRY_BASE)

        auto_registered, cid = outcome
        assert auto_registered is True
        assert cid == company_id

    async def test_source_creation_network_error_still_succeeds(self):
        """A connection error during source creation does not abort registration."""
        company_id = "11111111-aaaa-bbbb-cccc-dddddddddddd"
        source_attempts = 0

        def handler(request: httpx.Request) -> httpx.Response:
            nonlocal source_attempts
            method, url = request.method, str(request.url)

            if method == "GET" and "/companies" in url and "/watchlists" not in url:
                return _json_response(200, [])
            if method == "POST" and url.endswith("/companies"):
                return _json_response(201, {"id": company_id, "ticker": "RIVN"})
            if method == "POST" and "/sources" in url:
                source_attempts += 1
                # MockTransport propagates this, simulating a network failure.
                raise httpx.ConnectError("connection refused")
            if method == "GET" and "/watchlists" in url:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if method == "POST" and "/members/" in url:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            outcome = await auto_register_symbol("RIVN", REGISTRY_BASE)

        auto_registered, cid = outcome
        assert auto_registered is True
        assert cid == company_id
        # NOTE(review): 2 is the number of source-creation POSTs the current
        # implementation attempts per new company; confirm against
        # services/trading/override.py if that number changes.
        assert source_attempts == 2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 5. Watchlist failure tolerance
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
class TestWatchlistFailureTolerance:
    """Watchlist failures are best-effort — should not block registration."""

    async def test_watchlist_get_fails_still_succeeds(self):
        """GET /watchlists raises exception → still returns (True, company_id)."""
        company_id = "22222222-aaaa-bbbb-cccc-dddddddddddd"

        def handler(request: httpx.Request) -> httpx.Response:
            url = str(request.url)
            method = request.method

            if method == "GET" and "/companies" in url and "/watchlists" not in url:
                return _json_response(200, [])
            if method == "POST" and url.endswith("/companies"):
                return _json_response(201, {"id": company_id, "ticker": "SOFI"})
            if method == "POST" and "/sources" in url:
                return _json_response(201, {"id": "src-1"})
            if method == "GET" and "/watchlists" in url:
                # MockTransport propagates this, simulating a network failure.
                raise httpx.ConnectError("connection refused")
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("SOFI", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == company_id

    async def test_watchlist_member_add_fails_still_succeeds(self):
        """POST /watchlists/{id}/members/{cid} returns 500 → still succeeds."""
        company_id = "33333333-aaaa-bbbb-cccc-dddddddddddd"

        def handler(request: httpx.Request) -> httpx.Response:
            url = str(request.url)
            method = request.method

            if method == "GET" and "/companies" in url and "/watchlists" not in url:
                return _json_response(200, [])
            if method == "POST" and url.endswith("/companies"):
                return _json_response(201, {"id": company_id, "ticker": "COIN"})
            if method == "POST" and "/sources" in url:
                return _json_response(201, {"id": "src-1"})
            if method == "GET" and "/watchlists" in url:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if method == "POST" and "/members/" in url:
                return _text_response(500, "Internal Server Error")
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("COIN", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == company_id

    # NOTE(review): this test covers ticker normalization, not watchlist
    # failure tolerance; consider moving it to a dedicated test class.
    async def test_ticker_is_uppercased(self):
        """Lowercase ticker input is normalized to uppercase."""
        company_id = "44444444-aaaa-bbbb-cccc-dddddddddddd"
        received_tickers: list[str] = []

        def handler(request: httpx.Request) -> httpx.Response:
            url = str(request.url)
            method = request.method

            if method == "GET" and "/companies" in url and "/watchlists" not in url:
                return _json_response(200, [])
            if method == "POST" and url.endswith("/companies"):
                # Capture the ticker exactly as sent in the create request.
                body = json.loads(request.content)
                received_tickers.append(body.get("ticker", ""))
                return _json_response(201, {"id": company_id, "ticker": body.get("ticker", "")})
            if method == "POST" and "/sources" in url:
                return _json_response(201, {"id": "src-1"})
            if method == "GET" and "/watchlists" in url:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if method == "POST" and "/members/" in url:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("tsla", REGISTRY_BASE)

        assert auto_registered is True
        # The company id from the create response must be returned for the normalized ticker too.
        assert cid == company_id
        assert received_tickers == ["TSLA"]
|
|
|
|
|
|
# ===========================================================================
|
|
# Override Endpoint Tests (Task 3.1)
|
|
# ===========================================================================
|
|
#
|
|
# Unit tests for POST /api/trading/override/order endpoint.
|
|
# Requirements: 3.1, 3.2, 3.4, 3.5, 9.1
|
|
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
from unittest.mock import patch as _patch
|
|
|
|
from starlette.testclient import TestClient
|
|
|
|
import services.trading.app as _app_module
|
|
from services.trading.app import app
|
|
|
|
|
|
def _make_fake_engine(redis_rpush_side_effect=None):
    """Build a stand-in trading engine: a running ``MagicMock`` with a mocked async Redis.

    ``redis_rpush_side_effect`` is installed as the ``side_effect`` of the
    awaitable ``engine.redis.rpush`` mock; ``None`` means no side effect.
    """
    redis_client = AsyncMock()
    redis_client.rpush = AsyncMock(side_effect=redis_rpush_side_effect)

    engine = MagicMock()
    engine.running = True
    engine.redis = redis_client
    return engine
|
|
|
|
|
|
def _override_client(fake_engine):
    """Install *fake_engine* as ``services.trading.app.engine`` and return a test client.

    Returns ``(client, original)``. ``original`` is the engine that was
    installed before the swap; callers pass it to ``_restore_engine``
    (typically in a ``finally``) to undo the swap.

    The ``TestClient`` is not entered as a context manager, so Starlette does
    not run the app's lifespan handlers. ``raise_server_exceptions=False``
    makes unhandled server errors come back as 500 responses instead of
    raising inside the test.
    """
    # Build the client before touching module state. If construction raised
    # after the swap, the caller would never receive ``original`` and could
    # not restore it, leaking the fake engine into later tests.
    client = TestClient(app, raise_server_exceptions=False)
    original = _app_module.engine
    _app_module.engine = fake_engine
    return client, original
|
|
|
|
|
|
def _restore_engine(original):
    """Put *original* back as ``services.trading.app.engine``, undoing ``_override_client``."""
    setattr(_app_module, "engine", original)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 6. Valid order returns 202 with correct response shape
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestOverrideEndpointValid:
    """Requirement 3.1, 3.4: well-formed orders are accepted with 202 and a queued-job body."""

    @staticmethod
    def _submit(order: dict):
        """POST *order* to the override endpoint and return the response.

        While the request runs, the module-level engine is a fake,
        ``auto_register_symbol`` is an ``AsyncMock`` returning ``(False, "comp-1")``,
        and any ``httpx.AsyncClient`` the app builds answers every request with a
        one-company list for ``order["ticker"]``. The engine is restored afterwards.
        """
        listing = [{"id": "comp-1", "ticker": order["ticker"]}]
        client, original = _override_client(_make_fake_engine())
        try:
            with _patch.object(
                _app_module, "auto_register_symbol", new_callable=AsyncMock, return_value=(False, "comp-1")
            ), _patched_client(lambda req: _json_response(200, listing)):
                return client.post("/api/trading/override/order", json=order)
        finally:
            _restore_engine(original)

    def test_valid_market_order_returns_202(self):
        """A valid market order returns 202 with job_id, status, ticker, side, quantity, auto_registered."""
        resp = self._submit(
            {"ticker": "AAPL", "side": "buy", "quantity": 10.0, "order_type": "market"}
        )

        assert resp.status_code == 202
        data = resp.json()
        assert "job_id" in data
        assert data["job_id"].startswith("override-")
        assert data["status"] == "queued"
        assert data["ticker"] == "AAPL"
        assert data["side"] == "buy"
        assert data["quantity"] == 10.0
        assert isinstance(data["auto_registered"], bool)

    def test_valid_limit_order_returns_202(self):
        """A valid limit order with limit_price returns 202."""
        resp = self._submit(
            {
                "ticker": "TSLA",
                "side": "sell",
                "quantity": 5.0,
                "order_type": "limit",
                "limit_price": 150.0,
            }
        )

        assert resp.status_code == 202
        data = resp.json()
        assert data["ticker"] == "TSLA"
        assert data["side"] == "sell"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 7. Invalid ticker returns 422
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestOverrideEndpointInvalidTicker:
    """Requirement 3.5: a malformed ticker is rejected with 422."""

    @staticmethod
    def _status_for(order: dict) -> int:
        """POST *order* with a fake engine installed and return the HTTP status code."""
        client, original = _override_client(_make_fake_engine())
        try:
            return client.post("/api/trading/override/order", json=order).status_code
        finally:
            _restore_engine(original)

    def test_numeric_ticker_returns_422(self):
        assert self._status_for({"ticker": "123", "side": "buy", "quantity": 1.0}) == 422

    def test_empty_ticker_returns_422(self):
        assert self._status_for({"ticker": "", "side": "buy", "quantity": 1.0}) == 422

    def test_ticker_with_spaces_returns_422(self):
        assert self._status_for({"ticker": "AA BB", "side": "buy", "quantity": 1.0}) == 422

    def test_ticker_too_long_returns_422(self):
        assert self._status_for({"ticker": "ABCDEFGHIJK", "side": "buy", "quantity": 1.0}) == 422
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 8. Missing limit_price for limit order returns 422
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestOverrideEndpointMissingLimitPrice:
    """Requirement 3.5: limit and stop_limit orders without limit_price are rejected with 422."""

    @staticmethod
    def _status_for(order: dict) -> int:
        """POST *order* with a fake engine installed and return the HTTP status code."""
        client, original = _override_client(_make_fake_engine())
        try:
            return client.post("/api/trading/override/order", json=order).status_code
        finally:
            _restore_engine(original)

    def test_limit_order_without_limit_price_returns_422(self):
        order = {"ticker": "AAPL", "side": "buy", "quantity": 10.0, "order_type": "limit"}
        assert self._status_for(order) == 422

    def test_stop_limit_order_without_limit_price_returns_422(self):
        order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "stop_limit",
            "stop_price": 100.0,
        }
        assert self._status_for(order) == 422
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 9. Missing stop_price for stop order returns 422
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestOverrideEndpointMissingStopPrice:
    """Requirement 3.5: stop and stop_limit orders without stop_price are rejected with 422."""

    @staticmethod
    def _status_for(order: dict) -> int:
        """POST *order* with a fake engine installed and return the HTTP status code."""
        client, original = _override_client(_make_fake_engine())
        try:
            return client.post("/api/trading/override/order", json=order).status_code
        finally:
            _restore_engine(original)

    def test_stop_order_without_stop_price_returns_422(self):
        order = {"ticker": "AAPL", "side": "buy", "quantity": 10.0, "order_type": "stop"}
        assert self._status_for(order) == 422

    def test_stop_limit_order_without_stop_price_returns_422(self):
        order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "stop_limit",
            "limit_price": 150.0,
        }
        assert self._status_for(order) == 422
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 10. Non-positive quantity returns 422
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestOverrideEndpointNonPositiveQuantity:
    """Requirement 3.5: a zero or negative quantity is rejected with 422."""

    @staticmethod
    def _status_for(order: dict) -> int:
        """POST *order* with a fake engine installed and return the HTTP status code."""
        client, original = _override_client(_make_fake_engine())
        try:
            return client.post("/api/trading/override/order", json=order).status_code
        finally:
            _restore_engine(original)

    def test_zero_quantity_returns_422(self):
        assert self._status_for({"ticker": "AAPL", "side": "buy", "quantity": 0.0}) == 422

    def test_negative_quantity_returns_422(self):
        assert self._status_for({"ticker": "AAPL", "side": "buy", "quantity": -5.0}) == 422
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# 11. Enqueued job has correct structure and source: "manual_override"
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestOverrideEndpointJobStructure:
    """Requirement 3.2, 9.1: Enqueued job has correct structure and source marker."""

    def test_enqueued_job_has_correct_fields(self):
        """The JSON payload RPUSH'd to Redis contains all required fields."""
        captured_payloads: list[str] = []

        async def capture_rpush(key, *values):
            # Keep every pushed value so the single-payload assertion below
            # also catches an RPUSH carrying more than one value.
            captured_payloads.extend(values)

        fake_engine = _make_fake_engine(redis_rpush_side_effect=capture_rpush)
        client, original = _override_client(fake_engine)
        try:
            with _patch.object(_app_module, "auto_register_symbol", new_callable=AsyncMock, return_value=(False, "comp-1")):
                with _patched_client(lambda req: _json_response(200, [{"id": "comp-1", "ticker": "MSFT"}])):
                    resp = client.post(
                        "/api/trading/override/order",
                        json={
                            "ticker": "MSFT",
                            "side": "buy",
                            "quantity": 25.0,
                            "order_type": "market",
                        },
                    )
            assert resp.status_code == 202
            assert len(captured_payloads) == 1

            payload = json.loads(captured_payloads[0])
            assert payload["ticker"] == "MSFT"
            assert payload["side"] == "buy"
            assert payload["quantity"] == 25.0
            assert payload["order_type"] == "market"
            assert payload["source"] == "manual_override"
            assert payload["idempotency_key"].startswith("override-")
            assert payload["limit_price"] is None
            assert payload["stop_price"] is None
        finally:
            _restore_engine(original)

    def test_enqueued_job_includes_limit_and_stop_prices(self):
        """When limit_price and stop_price are provided, they appear in the payload."""
        captured_payloads: list[str] = []

        async def capture_rpush(key, *values):
            captured_payloads.extend(values)

        fake_engine = _make_fake_engine(redis_rpush_side_effect=capture_rpush)
        client, original = _override_client(fake_engine)
        try:
            with _patch.object(_app_module, "auto_register_symbol", new_callable=AsyncMock, return_value=(True, "comp-2")):
                with _patched_client(lambda req: _json_response(200, [])):
                    resp = client.post(
                        "/api/trading/override/order",
                        json={
                            "ticker": "GOOG",
                            "side": "sell",
                            "quantity": 3.0,
                            "order_type": "stop_limit",
                            "limit_price": 140.0,
                            "stop_price": 135.0,
                        },
                    )
            assert resp.status_code == 202
            assert len(captured_payloads) == 1

            payload = json.loads(captured_payloads[0])
            assert payload["source"] == "manual_override"
            assert payload["limit_price"] == 140.0
            assert payload["stop_price"] == 135.0
        finally:
            _restore_engine(original)

    def test_job_id_matches_response(self):
        """The job_id in the response matches the idempotency_key in the enqueued payload."""
        captured_payloads: list[str] = []

        async def capture_rpush(key, *values):
            captured_payloads.extend(values)

        fake_engine = _make_fake_engine(redis_rpush_side_effect=capture_rpush)
        client, original = _override_client(fake_engine)
        try:
            with _patch.object(_app_module, "auto_register_symbol", new_callable=AsyncMock, return_value=(False, "comp-1")):
                with _patched_client(lambda req: _json_response(200, [{"id": "comp-1", "ticker": "AMD"}])):
                    resp = client.post(
                        "/api/trading/override/order",
                        json={
                            "ticker": "AMD",
                            "side": "buy",
                            "quantity": 1.0,
                            "order_type": "market",
                        },
                    )
            assert resp.status_code == 202
            # Guard the index below: fail with a clear message, not an IndexError.
            assert len(captured_payloads) == 1, f"Expected one enqueued job, got {captured_payloads}"

            data = resp.json()
            payload = json.loads(captured_payloads[0])
            assert data["job_id"] == payload["idempotency_key"]
        finally:
            _restore_engine(original)
|