"""Tests for the vLLM client, health check, config, and LLM factory."""

import json
import logging
from unittest.mock import patch

import httpx
import pytest

from services.extractor.client import OllamaClient
from services.extractor.llm_factory import build_llm_client
from services.extractor.vllm_client import VLLMClient, check_vllm_health
from services.shared.agent_config import ResolvedAgentConfig
from services.shared.config import AppConfig, OllamaConfig, VLLMConfig, load_config

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _valid_extraction_json() -> str:
    """Return a minimal extraction result serialized as a JSON string.

    The tests below expect this payload to pass the client's validation
    (they assert ``attempt.validation.valid`` for it).
    """
    return json.dumps({
        "summary": "Apple beat earnings expectations.",
        "companies": [
            {
                "ticker": "AAPL",
                "company_name": "Apple Inc.",
                "relevance": 0.95,
                "sentiment": "positive",
                "impact_score": 0.7,
                "impact_horizon": "1d_30d",
                "catalyst_type": "earnings",
                "key_facts": ["Revenue up 12%"],
                "risks": [],
                "evidence_spans": ["Apple beat expectations"],
            }
        ],
        "macro_themes": ["ai_capex"],
        "novelty_score": 0.6,
        "confidence": 0.85,
        "extraction_warnings": [],
    })


def _openai_response(content: str, status: int = 200) -> httpx.Response:
    """Build a fake OpenAI-compatible /v1/chat/completions response.

    Args:
        content: Text placed in ``choices[0].message.content``.
        status: HTTP status code of the fake response.

    Returns:
        An ``httpx.Response`` whose JSON body holds a single assistant choice.
    """
    body = {
        "choices": [
            {"message": {"role": "assistant", "content": content}}
        ],
    }
    return httpx.Response(status, json=body)


def _make_vllm_config() -> VLLMConfig:
    """Return a VLLMConfig pointing at a fake host, with zero retry delays."""
    return VLLMConfig(
        base_url="http://test-vllm:8000",
        model="test-vllm-model",
        timeout=10,
        max_retries=2,
        # Zero delays: presumably so that any retry never sleeps in tests.
        retry_base_delay=0.0,
        retry_max_delay=0.0,
        retry_backoff_multiplier=2.0,
        max_tokens=4096,
        temperature=0.7,
        api_key="",
    )


def _make_ollama_config() -> OllamaConfig:
    """Return an OllamaConfig pointing at a fake host, with zero retry delays."""
    return OllamaConfig(
        base_url="http://test-ollama:11434",
        model="test-ollama-model",
        timeout=10,
        retry_base_delay=0.0,
        retry_max_delay=0.0,
        retry_backoff_multiplier=2.0,
    )


def _make_prompts() -> dict[str, str]:
    """Return the system/user prompt pair passed to ``call_llm`` in the tests."""
    return {"system": "You are a helpful assistant.", "user": "Extract info."}


def _make_resolved(provider: str = "vllm") -> ResolvedAgentConfig:
    """Return a ResolvedAgentConfig whose ``model_provider`` is *provider*.

    Args:
        provider: Value stored in ``model_provider``; the factory tests vary it.
    """
    return ResolvedAgentConfig(
        agent_id="agent-1",
        variant_id=None,
        model_provider=provider,
        model_name="resolved-model",
        system_prompt="sys",
        user_prompt_template="usr",
        prompt_version="v1",
        temperature=0.5,
        max_tokens=8192,
        context_window=0,
        input_token_limit=0,
        token_budget=0,
        timeout_seconds=60,
        max_retries=2,
    )


# ===================================================================
# Task 7: Unit Tests for VLLMClient
# ===================================================================


# 7.1 — VLLMClient sends correct payload to /v1/chat/completions
@pytest.mark.asyncio
async def test_vllm_sends_correct_payload():
    """VLLMClient sends POST to /v1/chat/completions with correct OpenAI payload."""
    captured: dict = {}

    def handler(request: httpx.Request) -> httpx.Response:
        # Record what the client actually sent so it can be inspected below.
        captured["url"] = str(request.url)
        captured["payload"] = json.loads(request.content)
        return _openai_response(_valid_extraction_json())

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    config = _make_vllm_config()
    client = VLLMClient(config, http_client=http)
    try:
        await client.call_llm(_make_prompts(), {})

        assert captured["url"] == "http://test-vllm:8000/v1/chat/completions"
        payload = captured["payload"]
        assert payload["model"] == "test-vllm-model"
        assert len(payload["messages"]) == 2
        assert payload["messages"][0]["role"] == "system"
        assert payload["messages"][1]["role"] == "user"
        assert payload["max_tokens"] == 4096
        assert payload["temperature"] == 0.7
    finally:
        # Close even when an assertion fails so the client is not leaked.
        await client.close()


# 7.2 — VLLMClient extracts content from choices[0].message.content
@pytest.mark.asyncio
async def test_vllm_extracts_content_from_choices():
    """VLLMClient extracts content from choices[0].message.content."""
    transport = httpx.MockTransport(
        lambda req: _openai_response(_valid_extraction_json())
    )
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.raw_output == _valid_extraction_json()
        assert attempt.error is None
        assert attempt.validation is not None
        assert attempt.validation.valid
    finally:
        await client.close()


# 7.3 — VLLMClient handles empty choices array → empty_model_response
@pytest.mark.asyncio
async def test_vllm_empty_choices():
    """Empty choices array returns empty_model_response error."""
    body = {"choices": []}
    transport = httpx.MockTransport(
        lambda req: httpx.Response(200, json=body)
    )
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.error == "empty_model_response"
    finally:
        await client.close()


# 7.4 — VLLMClient handles HTTP timeout → timeout error
@pytest.mark.asyncio
async def test_vllm_timeout():
    """HTTP timeout returns 'timeout' error."""

    def handler(request: httpx.Request) -> httpx.Response:
        raise httpx.ReadTimeout("timed out")

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.error == "timeout"
        assert attempt.duration_ms >= 0
    finally:
        await client.close()


# 7.5 — VLLMClient handles HTTP 500 → http_500 retryable error
@pytest.mark.asyncio
async def test_vllm_http_500():
    """HTTP 500 returns 'http_500' error marked as retryable."""
    transport = httpx.MockTransport(
        lambda req: httpx.Response(500, text="Internal Server Error")
    )
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.error == "http_500"
        assert attempt.retryable is True
    finally:
        await client.close()


# 7.6 — VLLMClient handles HTTP 400 → http_400 non-retryable error
@pytest.mark.asyncio
async def test_vllm_http_400():
    """HTTP 400 returns 'http_400' error marked as non-retryable."""
    transport = httpx.MockTransport(
        lambda req: httpx.Response(400, text="Bad Request")
    )
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.error == "http_400"
        assert attempt.retryable is False
    finally:
        await client.close()


# 7.7 — VLLMClient handles connection error → connection_error: ...
@pytest.mark.asyncio
async def test_vllm_connection_error():
    """Connection error returns 'connection_error: ...' error string."""

    def handler(request: httpx.Request) -> httpx.Response:
        raise httpx.ConnectError("Connection refused")

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.error is not None
        assert attempt.error.startswith("connection_error:")
    finally:
        await client.close()


# 7.8 — VLLMClient applies markdown fence stripping and JSON repair
@pytest.mark.asyncio
async def test_vllm_markdown_fence_stripping_and_json_repair():
    """VLLMClient strips markdown fences and repairs JSON."""
    # Wrap valid JSON in markdown fences
    fenced = f"```json\n{_valid_extraction_json()}\n```"
    transport = httpx.MockTransport(
        lambda req: _openai_response(fenced)
    )
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        # Should succeed after stripping fences
        assert attempt.error is None
        assert attempt.validation is not None
        assert attempt.validation.valid
    finally:
        await client.close()


# 7.9 — VLLMClient includes temperature and response_format in payload
@pytest.mark.asyncio
async def test_vllm_payload_includes_temperature_and_response_format():
    """Payload includes temperature and response_format fields."""
    captured: dict = {}

    def handler(request: httpx.Request) -> httpx.Response:
        captured["payload"] = json.loads(request.content)
        return _openai_response(_valid_extraction_json())

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    config = _make_vllm_config()
    # Override the helper's default (0.7) to prove the config value is used.
    config.temperature = 0.3
    client = VLLMClient(config, http_client=http)
    try:
        await client.call_llm(_make_prompts(), {})

        assert captured["payload"]["temperature"] == 0.3
        assert captured["payload"]["response_format"] == {"type": "json_object"}
    finally:
        await client.close()


# 7.10 — Health check success returns True and logs INFO
@pytest.mark.asyncio
async def test_health_check_success(caplog):
    """check_vllm_health returns True and logs INFO on success."""
    transport = httpx.MockTransport(
        lambda req: httpx.Response(200, json={"data": [{"id": "model-1"}]})
    )
    # Build the fake client BEFORE patching: once the patch is active,
    # httpx.AsyncClient is a mock and can no longer construct a real client.
    fake_client = httpx.AsyncClient(transport=transport)
    with patch(
        "services.extractor.vllm_client.httpx.AsyncClient",
        return_value=fake_client,
    ):
        with caplog.at_level(logging.INFO, logger="vllm_client"):
            result = await check_vllm_health("http://test-vllm:8000")

    assert result is True
    assert any("health check passed" in r.message for r in caplog.records)


# 7.11 — Health check failure returns False and logs WARNING
@pytest.mark.asyncio
async def test_health_check_failure(caplog):
    """check_vllm_health returns False and logs WARNING on failure."""

    def handler(request: httpx.Request) -> httpx.Response:
        raise httpx.ConnectError("Connection refused")

    transport = httpx.MockTransport(handler)
    # Build the fake client BEFORE patching (see test_health_check_success).
    fake_client = httpx.AsyncClient(transport=transport)
    with patch(
        "services.extractor.vllm_client.httpx.AsyncClient",
        return_value=fake_client,
    ):
        with caplog.at_level(logging.WARNING, logger="vllm_client"):
            result = await check_vllm_health("http://unreachable:8000")

    assert result is False
    assert any("health check failed" in r.message for r in caplog.records)


# 7.12 — OllamaClient.call_llm() delegates to _call_ollama()
@pytest.mark.asyncio
async def test_ollama_call_llm_delegates():
    """OllamaClient.call_llm() returns a valid attempt for an Ollama-style reply.

    NOTE(review): delegation to ``_call_ollama()`` is only checked indirectly.
    This test never calls or spies on ``_call_ollama``; it asserts that
    ``call_llm`` succeeds and reports the configured model.
    """
    transport = httpx.MockTransport(
        lambda req: httpx.Response(
            200,
            json={"message": {"role": "assistant", "content": _valid_extraction_json()}},
        )
    )
    http = httpx.AsyncClient(transport=transport)
    config = _make_ollama_config()
    client = OllamaClient(config, http_client=http)
    prompts = _make_prompts()
    schema = {}
    try:
        result_llm = await client.call_llm(prompts, schema)

        # The attempt should succeed with valid extraction JSON.
        assert result_llm.error is None
        assert result_llm.validation is not None
        assert result_llm.validation.valid
        assert result_llm.model == config.model
    finally:
        await client.close()


# 7.13 — VLLMConfig loading from environment variables
def test_vllm_config_from_env(monkeypatch):
    """VLLMConfig fields are loaded from VLLM_* environment variables."""
    monkeypatch.setenv("VLLM_BASE_URL", "http://custom:9000")
    monkeypatch.setenv("VLLM_MODEL", "custom-model")
    monkeypatch.setenv("VLLM_TIMEOUT", "300")
    monkeypatch.setenv("VLLM_MAX_RETRIES", "5")
    monkeypatch.setenv("VLLM_TEMPERATURE", "0.9")
    monkeypatch.setenv("VLLM_API_KEY", "secret-key")
    monkeypatch.setenv("VLLM_MAX_TOKENS", "16384")

    cfg = load_config()

    # Numeric settings are expected back as int/float, not as raw strings.
    assert cfg.vllm.base_url == "http://custom:9000"
    assert cfg.vllm.model == "custom-model"
    assert cfg.vllm.timeout == 300
    assert cfg.vllm.max_retries == 5
    assert cfg.vllm.temperature == 0.9
    assert cfg.vllm.api_key == "secret-key"
    assert cfg.vllm.max_tokens == 16384


# 7.14 — AppConfig includes vllm field with correct defaults
def test_appconfig_vllm_defaults():
    """AppConfig includes a vllm field with VLLMConfig defaults."""
    # NOTE(review): assumes AppConfig() defaults do not read VLLM_* env vars;
    # confirm against services.shared.config.
    cfg = AppConfig()

    assert hasattr(cfg, "vllm")
    assert isinstance(cfg.vllm, VLLMConfig)
    assert cfg.vllm.base_url == "http://192.168.42.254:8000"
    assert cfg.vllm.model == "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
    assert cfg.vllm.timeout == 120
    assert cfg.vllm.max_retries == 2
    assert cfg.vllm.temperature == 0.7
    assert cfg.vllm.max_tokens == 32768
    assert cfg.vllm.api_key == ""


# ===================================================================
# Task 8: Unit Tests for LLM Factory
# ===================================================================


# 8.1 — Factory returns OllamaClient when provider is "ollama"
@pytest.mark.asyncio
async def test_factory_ollama_provider():
    """build_llm_client returns OllamaClient when provider is 'ollama'."""
    resolved = _make_resolved(provider="ollama")
    # No request is made by this test; the transport only keeps the injected
    # HTTP client off the real network.
    transport = httpx.MockTransport(lambda req: httpx.Response(200))
    http = httpx.AsyncClient(transport=transport)
    client = build_llm_client(
        resolved, _make_ollama_config(), _make_vllm_config(), http_client=http
    )
    try:
        assert isinstance(client, OllamaClient)
    finally:
        # Close even when the assertion fails so the client is not leaked.
        await client.close()


# 8.2 — Factory returns VLLMClient when provider is "vllm"
@pytest.mark.asyncio
async def test_factory_vllm_provider():
    """build_llm_client returns VLLMClient when provider is 'vllm'."""
    resolved = _make_resolved(provider="vllm")
    transport = httpx.MockTransport(lambda req: httpx.Response(200))
    http = httpx.AsyncClient(transport=transport)
    client = build_llm_client(
        resolved, _make_ollama_config(), _make_vllm_config(), http_client=http
    )
    try:
        assert isinstance(client, VLLMClient)
    finally:
        await client.close()


# 8.3 — Factory returns OllamaClient when provider is empty string (default)
@pytest.mark.asyncio
async def test_factory_empty_provider_defaults_to_ollama():
    """build_llm_client returns OllamaClient when provider is empty string."""
    resolved = _make_resolved(provider="")
    transport = httpx.MockTransport(lambda req: httpx.Response(200))
    http = httpx.AsyncClient(transport=transport)
    client = build_llm_client(
        resolved, _make_ollama_config(), _make_vllm_config(), http_client=http
    )
    try:
        assert isinstance(client, OllamaClient)
    finally:
        await client.close()


# 8.4 — Factory returns OllamaClient with warning when provider is unknown
@pytest.mark.asyncio
async def test_factory_unknown_provider_warns_and_falls_back(caplog):
    """build_llm_client logs warning and returns OllamaClient for unknown provider."""
    resolved = _make_resolved(provider="unknown-provider")
    transport = httpx.MockTransport(lambda req: httpx.Response(200))
    http = httpx.AsyncClient(transport=transport)
    # Only the factory call runs under the capture level; caplog keeps the
    # records afterwards, so the assertions can live outside the block.
    with caplog.at_level(logging.WARNING):
        client = build_llm_client(
            resolved, _make_ollama_config(), _make_vllm_config(), http_client=http
        )
    try:
        assert isinstance(client, OllamaClient)
        assert any("unknown" in r.message.lower() for r in caplog.records)
    finally:
        await client.close()