"""Unit tests for auto_register_symbol() in services/trading/override.py.

Validates:
- New symbol registration (company + sources + watchlist)
- Existing symbol skip
- 409 conflict handling
- Source/watchlist failure tolerance

Requirements: 4.1, 4.2, 4.3, 4.5, 4.6
"""

from __future__ import annotations

import json
from unittest.mock import patch

import httpx
import pytest

from services.trading.override import auto_register_symbol

REGISTRY_BASE = "http://symbol-registry:8000"


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _json_response(status_code: int, json_data) -> httpx.Response:
    """Return an httpx.Response carrying *json_data* serialized as a JSON body.

    Intended as the return value of a MockTransport handler.
    """
    encoded_body = json.dumps(json_data).encode()
    json_headers = {"content-type": "application/json"}
    return httpx.Response(status_code, content=encoded_body, headers=json_headers)


def _text_response(status_code: int, text: str = "") -> httpx.Response:
    """Return an httpx.Response with a plain-text body (empty by default)."""
    return httpx.Response(status_code, text=text)


def _patched_client(handler):
    """Patch ``httpx.AsyncClient.__init__`` so every new client uses a mock transport.

    ``handler`` is a sync callable ``(request: httpx.Request) -> httpx.Response``.
    The returned object is the un-started ``patch.object`` patcher; use it as a
    context manager.
    """
    real_init = httpx.AsyncClient.__init__

    def _init_with_mock_transport(self, *args, **kwargs):
        # Any caller-supplied ``timeout`` is dropped and ``transport`` is always
        # replaced, so requests are answered by ``handler``.
        forwarded = {key: value for key, value in kwargs.items() if key != "timeout"}
        forwarded["transport"] = httpx.MockTransport(handler)
        real_init(self, *args, **forwarded)

    return patch.object(httpx.AsyncClient, "__init__", _init_with_mock_transport)


# ---------------------------------------------------------------------------
# 1.
# New symbol registration — full happy path
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestAutoRegisterNewSymbol:
    """Requirement 4.1, 4.2, 4.3: a new symbol gets a company, sources and watchlist membership."""

    async def test_new_symbol_full_registration(self):
        """Empty GET /companies → POST /companies 201 → sources → watchlist → (True, id)."""
        expected_id = "aaaaaaaa-1111-2222-3333-444444444444"

        def handler(request: httpx.Request) -> httpx.Response:
            verb, target = request.method, str(request.url)
            # Branch order matters: the URL checks below overlap.
            if verb == "GET" and "/companies" in target and "/watchlists" not in target:
                return _json_response(200, [])
            if verb == "POST" and target.endswith("/companies"):
                return _json_response(201, {"id": expected_id, "ticker": "TSLA"})
            if verb == "POST" and "/sources" in target:
                return _json_response(201, {"id": "src-1"})
            if verb == "GET" and "/watchlists" in target:
                return _json_response(200, [{"id": "wl-1", "name": "Default", "active": True}])
            if verb == "POST" and "/members/" in target:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("TSLA", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == expected_id

    async def test_new_symbol_creates_watchlist_when_none_exist(self):
        """With an empty watchlist listing, a POST to /watchlists (not a member add) is made."""
        expected_id = "bbbbbbbb-1111-2222-3333-444444444444"
        created_watchlist_id = "wl-new-1"
        seen_requests: list[str] = []

        def handler(request: httpx.Request) -> httpx.Response:
            verb, target = request.method, str(request.url)
            seen_requests.append(f"{verb} {target}")
            if verb == "GET" and "/companies" in target and "/watchlists" not in target:
                return _json_response(200, [])
            if verb == "POST" and target.endswith("/companies"):
                return _json_response(201, {"id": expected_id, "ticker": "NVDA"})
            if verb == "POST" and "/sources" in target:
                return _json_response(201, {"id": "src-1"})
            if verb == "GET" and "/watchlists" in target:
                return _json_response(200, [])  # No active watchlists
            if verb == "POST" and target.endswith("/watchlists"):
                return _json_response(
                    201, {"id": created_watchlist_id, "name": "Manual Overrides"}
                )
            if verb == "POST" and "/members/" in target:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("NVDA", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == expected_id
        watchlist_creations = [
            entry
            for entry in seen_requests
            if "POST" in entry and "/watchlists" in entry and "/members" not in entry
        ]
        assert watchlist_creations


# ---------------------------------------------------------------------------
# 2. Existing symbol skip
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestAutoRegisterExistingSymbol:
    """Requirement 4.5: Existing symbols skip registration."""

    async def test_existing_symbol_returns_false(self):
        """GET /companies already lists the ticker → (False, company_id)."""
        known_id = "cccccccc-1111-2222-3333-444444444444"

        def handler(request: httpx.Request) -> httpx.Response:
            target = str(request.url)
            if request.method == "GET" and "/companies" in target:
                companies = [
                    {"id": known_id, "ticker": "AAPL", "legal_name": "Apple Inc"},
                    {"id": "other-id", "ticker": "MSFT", "legal_name": "Microsoft"},
                ]
                return _json_response(200, companies)
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("AAPL", REGISTRY_BASE)

        assert auto_registered is False
        assert cid == known_id

    async def test_existing_symbol_no_post_calls(self):
        """When the ticker is already listed, no POST request reaches the registry."""
        known_id = "dddddddd-1111-2222-3333-444444444444"
        post_calls: list[str] = []

        def handler(request: httpx.Request) -> httpx.Response:
            target = str(request.url)
            if request.method == "POST":
                post_calls.append(target)
            if request.method == "GET" and "/companies" in target:
                return _json_response(200, [{"id": known_id, "ticker": "GOOG"}])
            return _text_response(404)

        with _patched_client(handler):
            await auto_register_symbol("GOOG", REGISTRY_BASE)

        assert post_calls == [], f"Unexpected POST calls: {post_calls}"


# ---------------------------------------------------------------------------
# 3. 409 conflict handling
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestAutoRegister409Conflict:
    """Requirement 4.6: 409 conflict treated as success."""

    async def test_409_conflict_fetches_existing_company(self):
        """POST /companies 409 → company found on the next lookup → (True, company_id)."""
        known_id = "eeeeeeee-1111-2222-3333-444444444444"
        company_lookups = 0

        def handler(request: httpx.Request) -> httpx.Response:
            nonlocal company_lookups
            verb, target = request.method, str(request.url)
            if verb == "GET" and "/companies" in target and "/watchlists" not in target:
                company_lookups += 1
                # First lookup finds nothing; every later one returns the company.
                found = [] if company_lookups == 1 else [{"id": known_id, "ticker": "AMD"}]
                return _json_response(200, found)
            if verb == "POST" and target.endswith("/companies"):
                return _json_response(409, {"detail": "Company already exists"})
            if verb == "POST" and "/sources" in target:
                return _json_response(201, {"id": "src-1"})
            if verb == "GET" and "/watchlists" in target:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if verb == "POST" and "/members/" in target:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("AMD", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == known_id

    async def test_409_conflict_but_cannot_find_company_returns_false(self):
        """POST /companies 409 and every lookup stays empty → (False, '')."""

        def handler(request: httpx.Request) -> httpx.Response:
            target = str(request.url)
            if request.method == "GET" and "/companies" in target:
                return _json_response(200, [])
            if request.method == "POST" and target.endswith("/companies"):
                return _json_response(409, {"detail": "Conflict"})
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("XYZ", REGISTRY_BASE)

        assert auto_registered is False
        assert cid == ""


# ---------------------------------------------------------------------------
# 4. Source creation failure tolerance
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestSourceFailureTolerance:
    """Source creation failures are best-effort — should not block registration."""

    async def test_source_creation_500_still_succeeds(self):
        """POST .../sources answers 500 → result is still (True, company_id)."""
        expected_id = "ffffffff-1111-2222-3333-444444444444"

        def handler(request: httpx.Request) -> httpx.Response:
            verb, target = request.method, str(request.url)
            if verb == "GET" and "/companies" in target and "/watchlists" not in target:
                return _json_response(200, [])
            if verb == "POST" and target.endswith("/companies"):
                return _json_response(201, {"id": expected_id, "ticker": "PLTR"})
            if verb == "POST" and "/sources" in target:
                return _text_response(500, "Internal Server Error")
            if verb == "GET" and "/watchlists" in target:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if verb == "POST" and "/members/" in target:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("PLTR", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == expected_id

    async def test_source_creation_network_error_still_succeeds(self):
        """Source POSTs raise ConnectError → still (True, company_id), after two attempts."""
        expected_id = "11111111-aaaa-bbbb-cccc-dddddddddddd"
        source_attempts = 0

        def handler(request: httpx.Request) -> httpx.Response:
            nonlocal source_attempts
            verb, target = request.method, str(request.url)
            if verb == "GET" and "/companies" in target and "/watchlists" not in target:
                return _json_response(200, [])
            if verb == "POST" and target.endswith("/companies"):
                return _json_response(201, {"id": expected_id, "ticker": "RIVN"})
            if verb == "POST" and "/sources" in target:
                source_attempts += 1
                raise httpx.ConnectError("connection refused")
            if verb == "GET" and "/watchlists" in target:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if verb == "POST" and "/members/" in target:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("RIVN", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == expected_id
        assert source_attempts == 2


# ---------------------------------------------------------------------------
# 5. Watchlist failure tolerance
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
class TestWatchlistFailureTolerance:
    """Watchlist failures are best-effort — should not block registration."""

    async def test_watchlist_get_fails_still_succeeds(self):
        """GET /watchlists raises ConnectError → result is still (True, company_id)."""
        expected_id = "22222222-aaaa-bbbb-cccc-dddddddddddd"

        def handler(request: httpx.Request) -> httpx.Response:
            verb, target = request.method, str(request.url)
            if verb == "GET" and "/companies" in target and "/watchlists" not in target:
                return _json_response(200, [])
            if verb == "POST" and target.endswith("/companies"):
                return _json_response(201, {"id": expected_id, "ticker": "SOFI"})
            if verb == "POST" and "/sources" in target:
                return _json_response(201, {"id": "src-1"})
            if verb == "GET" and "/watchlists" in target:
                raise httpx.ConnectError("connection refused")
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("SOFI", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == expected_id

    async def test_watchlist_member_add_fails_still_succeeds(self):
        """POST /watchlists/{id}/members/{cid} answers 500 → still (True, company_id)."""
        expected_id = "33333333-aaaa-bbbb-cccc-dddddddddddd"

        def handler(request: httpx.Request) -> httpx.Response:
            verb, target = request.method, str(request.url)
            if verb == "GET" and "/companies" in target and "/watchlists" not in target:
                return _json_response(200, [])
            if verb == "POST" and target.endswith("/companies"):
                return _json_response(201, {"id": expected_id, "ticker": "COIN"})
            if verb == "POST" and "/sources" in target:
                return _json_response(201, {"id": "src-1"})
            if verb == "GET" and "/watchlists" in target:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if verb == "POST" and "/members/" in target:
                return _text_response(500, "Internal Server Error")
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("COIN", REGISTRY_BASE)

        assert auto_registered is True
        assert cid == expected_id

    async def test_ticker_is_uppercased(self):
        """A lowercase ticker argument reaches POST /companies as uppercase."""
        expected_id = "44444444-aaaa-bbbb-cccc-dddddddddddd"
        received_tickers: list[str] = []

        def handler(request: httpx.Request) -> httpx.Response:
            verb, target = request.method, str(request.url)
            if verb == "GET" and "/companies" in target and "/watchlists" not in target:
                return _json_response(200, [])
            if verb == "POST" and target.endswith("/companies"):
                posted = json.loads(request.content)
                received_tickers.append(posted.get("ticker", ""))
                return _json_response(
                    201, {"id": expected_id, "ticker": posted.get("ticker", "")}
                )
            if verb == "POST" and "/sources" in target:
                return _json_response(201, {"id": "src-1"})
            if verb == "GET" and "/watchlists" in target:
                return _json_response(200, [{"id": "wl-1", "active": True}])
            if verb == "POST" and "/members/" in target:
                return _json_response(201, {})
            return _text_response(404)

        with _patched_client(handler):
            auto_registered, cid = await auto_register_symbol("tsla", REGISTRY_BASE)

        assert auto_registered is True
        assert received_tickers == ["TSLA"]


# ===========================================================================
# Override Endpoint Tests (Task 3.1)
# ===========================================================================
#
# Unit tests for
# POST /api/trading/override/order endpoint.
# Requirements: 3.1, 3.2, 3.4, 3.5, 9.1

from unittest.mock import AsyncMock, MagicMock
from unittest.mock import patch as _patch

from starlette.testclient import TestClient

import services.trading.app as _app_module
from services.trading.app import app


def _make_fake_engine(redis_rpush_side_effect=None):
    """Build a MagicMock engine with ``running=True`` and an AsyncMock ``redis``.

    ``redis_rpush_side_effect`` becomes the ``side_effect`` of ``redis.rpush``.
    """
    redis_double = AsyncMock()
    redis_double.rpush = AsyncMock(side_effect=redis_rpush_side_effect)

    engine_double = MagicMock()
    engine_double.running = True
    engine_double.redis = redis_double
    return engine_double


def _override_client(fake_engine):
    """Swap the app module's ``engine`` for ``fake_engine``.

    Returns ``(client, previous_engine)``; pass the second item to
    ``_restore_engine`` when the test is done.
    """
    previous_engine = _app_module.engine
    _app_module.engine = fake_engine
    return TestClient(app, raise_server_exceptions=False), previous_engine


def _restore_engine(original):
    """Put ``original`` back as the app module's ``engine``."""
    _app_module.engine = original


# ---------------------------------------------------------------------------
# 6. Valid order returns 202 with correct response shape
# ---------------------------------------------------------------------------


class TestOverrideEndpointValid:
    """Requirement 3.1, 3.4: Valid order returns 202 with correct response shape."""

    def test_valid_market_order_returns_202(self):
        """A market order yields 202 plus job_id, status, ticker, side, quantity, auto_registered."""
        client, original = _override_client(_make_fake_engine())
        order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "market",
        }

        def registry(request):
            return _json_response(200, [{"id": "comp-1", "ticker": "AAPL"}])

        try:
            register_patch = _patch.object(
                _app_module,
                "auto_register_symbol",
                new_callable=AsyncMock,
                return_value=(False, "comp-1"),
            )
            with register_patch, _patched_client(registry):
                resp = client.post("/api/trading/override/order", json=order)

            assert resp.status_code == 202
            data = resp.json()
            assert "job_id" in data
            assert data["status"] == "queued"
            assert data["ticker"] == "AAPL"
            assert data["side"] == "buy"
            assert data["quantity"] == 10.0
            assert isinstance(data["auto_registered"], bool)
            assert data["job_id"].startswith("override-")
        finally:
            _restore_engine(original)

    def test_valid_limit_order_returns_202(self):
        """A limit order that carries limit_price yields 202."""
        client, original = _override_client(_make_fake_engine())
        order = {
            "ticker": "TSLA",
            "side": "sell",
            "quantity": 5.0,
            "order_type": "limit",
            "limit_price": 150.0,
        }

        def registry(request):
            return _json_response(200, [{"id": "comp-1", "ticker": "TSLA"}])

        try:
            register_patch = _patch.object(
                _app_module,
                "auto_register_symbol",
                new_callable=AsyncMock,
                return_value=(False, "comp-1"),
            )
            with register_patch, _patched_client(registry):
                resp = client.post("/api/trading/override/order", json=order)

            assert resp.status_code == 202
            data = resp.json()
            assert data["ticker"] == "TSLA"
            assert data["side"] == "sell"
        finally:
            _restore_engine(original)


# ---------------------------------------------------------------------------
# 7. Invalid ticker returns 422
# ---------------------------------------------------------------------------


class TestOverrideEndpointInvalidTicker:
    """Requirement 3.5: Invalid ticker format returns 422."""

    def test_numeric_ticker_returns_422(self):
        """Ticker "123" is rejected with 422."""
        client, original = _override_client(_make_fake_engine())
        order = {"ticker": "123", "side": "buy", "quantity": 1.0}
        try:
            resp = client.post("/api/trading/override/order", json=order)
            assert resp.status_code == 422
        finally:
            _restore_engine(original)

    def test_empty_ticker_returns_422(self):
        """An empty ticker string is rejected with 422."""
        client, original = _override_client(_make_fake_engine())
        order = {"ticker": "", "side": "buy", "quantity": 1.0}
        try:
            resp = client.post("/api/trading/override/order", json=order)
            assert resp.status_code == 422
        finally:
            _restore_engine(original)

    def test_ticker_with_spaces_returns_422(self):
        """Ticker "AA BB" is rejected with 422."""
        client, original = _override_client(_make_fake_engine())
        order = {"ticker": "AA BB", "side": "buy", "quantity": 1.0}
        try:
            resp = client.post("/api/trading/override/order", json=order)
            assert resp.status_code == 422
        finally:
            _restore_engine(original)

    def test_ticker_too_long_returns_422(self):
        """The 11-character ticker "ABCDEFGHIJK" is rejected with 422."""
        client, original = _override_client(_make_fake_engine())
        order = {"ticker": "ABCDEFGHIJK", "side": "buy", "quantity": 1.0}
        try:
            resp = client.post("/api/trading/override/order", json=order)
            assert resp.status_code == 422
        finally:
            _restore_engine(original)


# ---------------------------------------------------------------------------
# 8. Missing limit_price for limit order returns 422
# ---------------------------------------------------------------------------


class TestOverrideEndpointMissingLimitPrice:
    """Requirement 3.5: Missing limit_price for limit/stop_limit orders returns 422."""

    def test_limit_order_without_limit_price_returns_422(self):
        """A limit order with no limit_price is rejected with 422."""
        client, original = _override_client(_make_fake_engine())
        order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "limit",
        }
        try:
            resp = client.post("/api/trading/override/order", json=order)
            assert resp.status_code == 422
        finally:
            _restore_engine(original)

    def test_stop_limit_order_without_limit_price_returns_422(self):
        """A stop_limit order carrying only stop_price is rejected with 422."""
        client, original = _override_client(_make_fake_engine())
        order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "stop_limit",
            "stop_price": 100.0,
        }
        try:
            resp = client.post("/api/trading/override/order", json=order)
            assert resp.status_code == 422
        finally:
            _restore_engine(original)


# ---------------------------------------------------------------------------
# 9.
# Missing stop_price for stop order returns 422
# ---------------------------------------------------------------------------


class TestOverrideEndpointMissingStopPrice:
    """Requirement 3.5: Missing stop_price for stop/stop_limit orders returns 422."""

    def test_stop_order_without_stop_price_returns_422(self):
        """A stop order with no stop_price is rejected with 422."""
        client, original = _override_client(_make_fake_engine())
        order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "stop",
        }
        try:
            resp = client.post("/api/trading/override/order", json=order)
            assert resp.status_code == 422
        finally:
            _restore_engine(original)

    def test_stop_limit_order_without_stop_price_returns_422(self):
        """A stop_limit order carrying only limit_price is rejected with 422."""
        client, original = _override_client(_make_fake_engine())
        order = {
            "ticker": "AAPL",
            "side": "buy",
            "quantity": 10.0,
            "order_type": "stop_limit",
            "limit_price": 150.0,
        }
        try:
            resp = client.post("/api/trading/override/order", json=order)
            assert resp.status_code == 422
        finally:
            _restore_engine(original)


# ---------------------------------------------------------------------------
# 10. Non-positive quantity returns 422
# ---------------------------------------------------------------------------


class TestOverrideEndpointNonPositiveQuantity:
    """Requirement 3.5: Non-positive quantity returns 422."""

    def test_zero_quantity_returns_422(self):
        """Quantity 0.0 is rejected with 422."""
        client, original = _override_client(_make_fake_engine())
        order = {"ticker": "AAPL", "side": "buy", "quantity": 0.0}
        try:
            resp = client.post("/api/trading/override/order", json=order)
            assert resp.status_code == 422
        finally:
            _restore_engine(original)

    def test_negative_quantity_returns_422(self):
        """Quantity -5.0 is rejected with 422."""
        client, original = _override_client(_make_fake_engine())
        order = {"ticker": "AAPL", "side": "buy", "quantity": -5.0}
        try:
            resp = client.post("/api/trading/override/order", json=order)
            assert resp.status_code == 422
        finally:
            _restore_engine(original)


# ---------------------------------------------------------------------------
# 11.
# Enqueued job has correct structure and source: "manual_override"
# ---------------------------------------------------------------------------


class TestOverrideEndpointJobStructure:
    """Requirement 3.2, 9.1: Enqueued job has correct structure and source marker.

    Each test records the value passed to ``engine.redis.rpush`` by wiring a
    capturing coroutine in through ``_make_fake_engine``'s
    ``redis_rpush_side_effect`` parameter.
    """

    def test_enqueued_job_has_correct_fields(self):
        """The JSON payload RPUSH'd to Redis contains all required fields."""
        captured_payloads: list[str] = []

        async def capture_rpush(key, value):
            captured_payloads.append(value)

        fake_engine = _make_fake_engine(redis_rpush_side_effect=capture_rpush)
        client, original = _override_client(fake_engine)
        try:
            with _patch.object(
                _app_module,
                "auto_register_symbol",
                new_callable=AsyncMock,
                return_value=(False, "comp-1"),
            ):
                with _patched_client(
                    lambda req: _json_response(200, [{"id": "comp-1", "ticker": "MSFT"}])
                ):
                    resp = client.post(
                        "/api/trading/override/order",
                        json={
                            "ticker": "MSFT",
                            "side": "buy",
                            "quantity": 25.0,
                            "order_type": "market",
                        },
                    )

            assert resp.status_code == 202
            assert len(captured_payloads) == 1
            payload = json.loads(captured_payloads[0])
            assert payload["ticker"] == "MSFT"
            assert payload["side"] == "buy"
            assert payload["quantity"] == 25.0
            assert payload["order_type"] == "market"
            assert payload["source"] == "manual_override"
            assert payload["idempotency_key"].startswith("override-")
            assert payload["limit_price"] is None
            assert payload["stop_price"] is None
        finally:
            _restore_engine(original)

    def test_enqueued_job_includes_limit_and_stop_prices(self):
        """When limit_price and stop_price are provided, they appear in the payload."""
        captured_payloads: list[str] = []

        async def capture_rpush(key, value):
            captured_payloads.append(value)

        fake_engine = _make_fake_engine(redis_rpush_side_effect=capture_rpush)
        client, original = _override_client(fake_engine)
        try:
            with _patch.object(
                _app_module,
                "auto_register_symbol",
                new_callable=AsyncMock,
                return_value=(True, "comp-2"),
            ):
                with _patched_client(lambda req: _json_response(200, [])):
                    resp = client.post(
                        "/api/trading/override/order",
                        json={
                            "ticker": "GOOG",
                            "side": "sell",
                            "quantity": 3.0,
                            "order_type": "stop_limit",
                            "limit_price": 140.0,
                            "stop_price": 135.0,
                        },
                    )

            assert resp.status_code == 202
            assert len(captured_payloads) == 1
            payload = json.loads(captured_payloads[0])
            assert payload["source"] == "manual_override"
            assert payload["limit_price"] == 140.0
            assert payload["stop_price"] == 135.0
        finally:
            _restore_engine(original)

    def test_job_id_matches_response(self):
        """The job_id in the response matches the idempotency_key in the enqueued payload."""
        captured_payloads: list[str] = []

        async def capture_rpush(key, value):
            captured_payloads.append(value)

        fake_engine = _make_fake_engine(redis_rpush_side_effect=capture_rpush)
        client, original = _override_client(fake_engine)
        try:
            with _patch.object(
                _app_module,
                "auto_register_symbol",
                new_callable=AsyncMock,
                return_value=(False, "comp-1"),
            ):
                with _patched_client(
                    lambda req: _json_response(200, [{"id": "comp-1", "ticker": "AMD"}])
                ):
                    resp = client.post(
                        "/api/trading/override/order",
                        json={
                            "ticker": "AMD",
                            "side": "buy",
                            "quantity": 1.0,
                            "order_type": "market",
                        },
                    )

            assert resp.status_code == 202
            data = resp.json()
            # Guard the index below: a missing enqueue should fail as an
            # assertion (not IndexError), and a duplicate enqueue should fail too.
            assert len(captured_payloads) == 1
            payload = json.loads(captured_payloads[0])
            assert data["job_id"] == payload["idempotency_key"]
        finally:
            _restore_engine(original)