359 lines
14 KiB
Python
359 lines
14 KiB
Python
"""Unit tests for competitive API endpoints.
|
|
|
|
Tests competitor CRUD endpoints, pattern query endpoints, competitive toggle,
|
|
and auto-inference endpoint return correct data and error codes.
|
|
|
|
Requirements: 1.4, 2.5, 6.5, 8.1, 8.2, 8.5, 10.1, 10.4
|
|
"""
|
|
from __future__ import annotations

import json
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest
from httpx import ASGITransport, AsyncClient

from services.api.app import _row_to_dict, app

# Single timezone-aware reference instant shared by every fixture row, so the
# timestamps placed in fake database rows are identical on every test run.
NOW = datetime(2026, 6, 10, 12, tzinfo=timezone.utc)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


class FakeRecord(dict):
    """Plain-dict stand-in for ``asyncpg.Record`` in tests.

    Rows are built from a mapping and get key lookup and ``items()`` directly
    from :class:`dict`. That is presumably all the app's row handling (e.g.
    ``_row_to_dict``) needs; confirm if the app starts using other Record
    features. NOTE(review): a real ``asyncpg.Record`` also supports positional
    indexing (``record[0]``), which this fake does not.
    """


def _make_pattern_row(ticker: str = "AAPL") -> FakeRecord:
    """Build a fake pattern row whose source and target ticker are both *ticker*.

    ``id`` and ``source_document_id`` are fresh random UUID strings on every
    call. The remaining fields are fixed sample values, and ``computed_at``
    is pinned to ``NOW``.
    """
    return FakeRecord(
        id=str(uuid4()),
        source_document_id=str(uuid4()),
        source_ticker=ticker,
        target_ticker=ticker,
        catalyst_type="earnings",
        pattern_confidence=0.65,
        signal_direction="bullish",
        signal_strength=0.5,
        relationship_strength=0.8,
        computed_at=NOW,
    )


def _make_competitive_signal_row(
    source_ticker: str = "MSFT",
    target_ticker: str = "AAPL",
) -> FakeRecord:
    """Build a fake cross-company signal row from *source_ticker* to *target_ticker*.

    The row describes a bearish ``product_launch`` signal. ``id`` and
    ``source_document_id`` are fresh random UUID strings on every call, and
    ``computed_at`` is pinned to ``NOW``.
    """
    return FakeRecord(
        id=str(uuid4()),
        source_document_id=str(uuid4()),
        source_ticker=source_ticker,
        target_ticker=target_ticker,
        catalyst_type="product_launch",
        pattern_confidence=0.55,
        signal_direction="bearish",
        signal_strength=0.4,
        relationship_strength=0.7,
        computed_at=NOW,
    )


def _make_decision_row(ticker: str = "AAPL") -> FakeRecord:
    """Build a fake major-decision row for *ticker*.

    The row is an ``m_and_a`` catalyst created at ``NOW`` and published five
    days before ``NOW``. ``id`` and ``document_id`` are fresh random UUID
    strings on every call.
    """
    return FakeRecord({
        "id": str(uuid4()),
        "document_id": str(uuid4()),
        "ticker": ticker,
        "catalyst_type": "m_and_a",
        "summary": "Acquisition of XYZ Corp",
        "impact_score": 0.8,
        "created_at": NOW,
        "published_at": NOW - timedelta(days=5),
    })


# ---------------------------------------------------------------------------
# Route structure tests
# ---------------------------------------------------------------------------


class TestCompetitiveRouteStructure:
    """Verify all competitive-related routes are registered on ``app``."""

    @staticmethod
    def _registered_paths() -> set[str]:
        """Return the path template of every route registered on ``app``."""
        return {route.path for route in app.routes}

    def test_competitive_status_route_exists(self):
        assert "/api/admin/competitive/status" in self._registered_paths()

    def test_competitive_toggle_route_exists(self):
        assert "/api/admin/competitive/toggle" in self._registered_paths()

    def test_patterns_route_exists(self):
        assert "/api/patterns/{ticker}" in self._registered_paths()

    def test_competitor_patterns_route_exists(self):
        assert "/api/patterns/{ticker}/competitors" in self._registered_paths()

    def test_competitive_signals_route_exists(self):
        assert "/api/patterns/{ticker}/competitive-signals" in self._registered_paths()

    def test_decisions_route_exists(self):
        assert "/api/patterns/{ticker}/decisions" in self._registered_paths()


# ---------------------------------------------------------------------------
# Competitive toggle endpoint (Requirement: 6.5)
# ---------------------------------------------------------------------------


class TestCompetitiveToggleEndpoint:
    """Test competitive toggle endpoint persists state and records audit event."""

    @staticmethod
    async def _request(method: str, url: str, **kwargs):
        """Send one request to ``app`` through an in-process ASGI transport.

        Must be called while the relevant ``patch`` contexts are active.
        """
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            return await client.request(method, url, **kwargs)

    @pytest.mark.asyncio
    async def test_get_competitive_status_returns_default(self):
        """GET /api/admin/competitive/status should return default enabled state."""
        mock_pool = AsyncMock()
        # No config row found: the endpoint is expected to report its default.
        mock_pool.fetchrow = AsyncMock(return_value=None)

        with patch("services.api.app.pool", mock_pool):
            resp = await self._request("GET", "/api/admin/competitive/status")

        assert resp.status_code == 200
        data = resp.json()
        assert data["competitive_enabled"] is True
        assert data["source"] == "default"

    @pytest.mark.asyncio
    async def test_get_competitive_status_from_config(self):
        """GET /api/admin/competitive/status should read from risk_configs."""
        mock_pool = AsyncMock()
        mock_pool.fetchrow = AsyncMock(return_value=FakeRecord({
            "competitive_enabled": "false",
        }))

        with patch("services.api.app.pool", mock_pool):
            resp = await self._request("GET", "/api/admin/competitive/status")

        assert resp.status_code == 200
        data = resp.json()
        assert data["competitive_enabled"] is False
        assert data["source"] == "risk_configs"

    @pytest.mark.asyncio
    async def test_toggle_competitive_layer(self):
        """PUT /api/admin/competitive/toggle should persist state and record audit."""
        config_id = str(uuid4())
        mock_pool = AsyncMock()
        mock_pool.fetchrow = AsyncMock(return_value=FakeRecord({
            "id": config_id,
            "competitive_enabled": "true",
        }))
        mock_pool.execute = AsyncMock()

        with patch("services.api.app.pool", mock_pool), \
                patch("services.api.app.record_audit_event", new_callable=AsyncMock) as mock_audit:
            resp = await self._request(
                "PUT",
                "/api/admin/competitive/toggle",
                json={"enabled": False, "operator": "test_user"},
            )

        assert resp.status_code == 200
        data = resp.json()
        assert data["competitive_enabled"] is False
        assert data["previous_enabled"] is True
        assert data["toggled_by"] == "test_user"

        # Verify audit event was recorded with the expected event type.
        mock_audit.assert_called_once()
        audit_call = mock_audit.call_args
        # Accept event_type by keyword or as the second positional argument,
        # but always compare its value. `a or b == c` parses as `a or (b == c)`,
        # which would pass for any truthy keyword value.
        if "event_type" in audit_call.kwargs:
            event_type = audit_call.kwargs["event_type"]
        else:
            event_type = audit_call.args[1]
        assert event_type == "competitive.layer_toggled"

    @pytest.mark.asyncio
    async def test_toggle_competitive_layer_enable(self):
        """PUT /api/admin/competitive/toggle should enable the layer."""
        config_id = str(uuid4())
        mock_pool = AsyncMock()
        mock_pool.fetchrow = AsyncMock(return_value=FakeRecord({
            "id": config_id,
            "competitive_enabled": "false",
        }))
        mock_pool.execute = AsyncMock()

        with patch("services.api.app.pool", mock_pool), \
                patch("services.api.app.record_audit_event", new_callable=AsyncMock):
            resp = await self._request(
                "PUT",
                "/api/admin/competitive/toggle",
                json={"enabled": True, "operator": "admin"},
            )

        assert resp.status_code == 200
        data = resp.json()
        assert data["competitive_enabled"] is True
        assert data["previous_enabled"] is False


# ---------------------------------------------------------------------------
# Pattern query endpoints (Requirements: 10.1, 10.4)
# ---------------------------------------------------------------------------


class TestPatternQueryEndpoints:
    """Exercise the ``/api/patterns/{ticker}*`` endpoints against mocked data.

    Each test replaces ``services.api.app.pool`` (and, where relevant, the
    pattern-finder helpers) with AsyncMocks. It then issues a single GET
    while the patches are active and checks the JSON payload.
    """

    @staticmethod
    async def _get(url: str):
        """GET *url* from ``app`` through an in-process ASGI transport."""
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            return await client.get(url)

    @pytest.mark.asyncio
    async def test_get_competitive_signals_for_ticker(self):
        """GET /api/patterns/{ticker}/competitive-signals should return signals."""
        pool_stub = AsyncMock()
        pool_stub.fetch = AsyncMock(return_value=[_make_competitive_signal_row()])

        with patch("services.api.app.pool", pool_stub):
            response = await self._get("/api/patterns/AAPL/competitive-signals")

        assert response.status_code == 200
        body = response.json()
        assert body["ticker"] == "AAPL"
        assert body["count"] == 1
        signals = body["competitive_signals"]
        assert len(signals) == 1
        first = signals[0]
        assert first["source_ticker"] == "MSFT"
        assert first["target_ticker"] == "AAPL"
        assert first["signal_direction"] == "bearish"

    @pytest.mark.asyncio
    async def test_get_competitive_signals_empty(self):
        """GET /api/patterns/{ticker}/competitive-signals with no data returns empty."""
        pool_stub = AsyncMock()
        pool_stub.fetch = AsyncMock(return_value=[])

        with patch("services.api.app.pool", pool_stub):
            response = await self._get("/api/patterns/UNKNOWN/competitive-signals")

        assert response.status_code == 200
        body = response.json()
        assert body["count"] == 0
        assert body["competitive_signals"] == []

    @pytest.mark.asyncio
    async def test_get_patterns_with_catalyst_filter(self):
        """GET /api/patterns/{ticker}?catalyst_type=earnings should filter."""
        pool_stub = AsyncMock()

        # The endpoint hands the pool to find_self_patterns, so patch that
        # helper where the app module looks it up.
        with patch("services.api.app.pool", pool_stub), \
                patch("services.api.app.find_self_patterns", new_callable=AsyncMock) as finder:
            finder.return_value = []
            response = await self._get("/api/patterns/AAPL?catalyst_type=earnings")

        assert response.status_code == 200
        body = response.json()
        assert body["ticker"] == "AAPL"
        assert body["count"] == 0

        # The ticker and catalyst filter must reach find_self_patterns positionally.
        finder.assert_called_once()
        positional = finder.call_args.args
        assert positional[1] == "AAPL"
        assert positional[2] == "earnings"

    @pytest.mark.asyncio
    async def test_get_patterns_without_filter_queries_all_catalysts(self):
        """GET /api/patterns/{ticker} without filter queries all catalyst types."""
        pool_stub = AsyncMock()
        # The distinct-catalyst query yields a single catalyst type.
        pool_stub.fetch = AsyncMock(return_value=[FakeRecord({"catalyst_type": "earnings"})])

        with patch("services.api.app.pool", pool_stub), \
                patch("services.api.app.find_self_patterns", new_callable=AsyncMock) as finder:
            finder.return_value = []
            response = await self._get("/api/patterns/AAPL")

        assert response.status_code == 200
        assert response.json()["ticker"] == "AAPL"

    @pytest.mark.asyncio
    async def test_get_decisions_returns_major_decisions(self):
        """GET /api/patterns/{ticker}/decisions should return major decisions."""
        decision = _make_decision_row()
        pool_stub = AsyncMock()
        pool_stub.fetch = AsyncMock(return_value=[decision])

        with patch("services.api.app.pool", pool_stub), \
                patch("services.api.app.find_self_patterns", new_callable=AsyncMock) as finder:
            finder.return_value = []
            response = await self._get("/api/patterns/AAPL/decisions")

        assert response.status_code == 200
        body = response.json()
        assert body["ticker"] == "AAPL"
        assert body["count"] == 1
        top = body["decisions"][0]
        assert top["catalyst_type"] == "m_and_a"
        assert "pattern_statistics" in top

    @pytest.mark.asyncio
    async def test_get_decisions_empty(self):
        """GET /api/patterns/{ticker}/decisions with no data returns empty."""
        pool_stub = AsyncMock()
        pool_stub.fetch = AsyncMock(return_value=[])

        with patch("services.api.app.pool", pool_stub):
            response = await self._get("/api/patterns/UNKNOWN/decisions")

        assert response.status_code == 200
        body = response.json()
        assert body["count"] == 0
        assert body["decisions"] == []

    @pytest.mark.asyncio
    async def test_get_competitor_patterns(self):
        """GET /api/patterns/{ticker}/competitors should return cross-company patterns."""
        pool_stub = AsyncMock()
        # Successive pool.fetch calls: competitor tickers first, then catalyst types.
        pool_stub.fetch = AsyncMock(side_effect=[
            [FakeRecord({"competitor_ticker": "MSFT"})],
            [FakeRecord({"catalyst_type": "earnings"})],
        ])

        with patch("services.api.app.pool", pool_stub), \
                patch("services.api.app.find_cross_company_patterns", new_callable=AsyncMock) as cross_finder:
            cross_finder.return_value = []
            response = await self._get("/api/patterns/AAPL/competitors")

        assert response.status_code == 200
        body = response.json()
        assert body["ticker"] == "AAPL"
        assert "cross_company_patterns" in body