Files
stonks-oracle/tests/test_vllm_client.py
T
Celes Renata 117b693b19 feat: add remote vLLM support with provider abstraction layer
- LLMClient Protocol for provider-agnostic inference
- VLLMClient for OpenAI-compatible /v1/chat/completions API
- LLM client factory with provider routing (ollama/vllm)
- VLLMConfig with VLLM_* environment variable loading
- Updated extractor worker with health check and provider switching
- Updated event classifier to use LLMClient protocol
- Helm values for vLLM configuration
- 18 unit tests + 6 property-based tests
- Full backward compatibility preserved
2026-04-23 08:17:23 +00:00

462 lines
16 KiB
Python

"""Tests for the vLLM client, health check, config, and LLM factory."""
import json
import logging
from unittest.mock import patch
import httpx
import pytest
from services.extractor.client import OllamaClient
from services.extractor.llm_factory import build_llm_client
from services.extractor.vllm_client import VLLMClient, check_vllm_health
from services.shared.agent_config import ResolvedAgentConfig
from services.shared.config import AppConfig, OllamaConfig, VLLMConfig, load_config
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _valid_extraction_json() -> str:
"""Minimal valid extraction result as JSON string."""
return json.dumps({
"summary": "Apple beat earnings expectations.",
"companies": [
{
"ticker": "AAPL",
"company_name": "Apple Inc.",
"relevance": 0.95,
"sentiment": "positive",
"impact_score": 0.7,
"impact_horizon": "1d_30d",
"catalyst_type": "earnings",
"key_facts": ["Revenue up 12%"],
"risks": [],
"evidence_spans": ["Apple beat expectations"],
}
],
"macro_themes": ["ai_capex"],
"novelty_score": 0.6,
"confidence": 0.85,
"extraction_warnings": [],
})
def _openai_response(content: str, status: int = 200) -> httpx.Response:
    """Return a fake chat-completions response carrying *content*.

    The JSON body has a single choice whose assistant message holds
    *content*; *status* becomes the HTTP status code of the response.
    """
    assistant_message = {"role": "assistant", "content": content}
    return httpx.Response(
        status,
        json={"choices": [{"message": assistant_message}]},
    )
def _make_vllm_config() -> VLLMConfig:
    """Build the VLLMConfig shared by the tests.

    Both retry delays are zero, the base URL points at a fake host, and the
    API key is empty.
    """
    settings = {
        "base_url": "http://test-vllm:8000",
        "model": "test-vllm-model",
        "timeout": 10,
        "max_retries": 2,
        "retry_base_delay": 0.0,
        "retry_max_delay": 0.0,
        "retry_backoff_multiplier": 2.0,
        "max_tokens": 4096,
        "temperature": 0.7,
        "api_key": "",
    }
    return VLLMConfig(**settings)
def _make_ollama_config() -> OllamaConfig:
    """Build the OllamaConfig shared by the tests.

    Both retry delays are zero and the base URL points at a fake host.
    """
    settings = {
        "base_url": "http://test-ollama:11434",
        "model": "test-ollama-model",
        "timeout": 10,
        "retry_base_delay": 0.0,
        "retry_max_delay": 0.0,
        "retry_backoff_multiplier": 2.0,
    }
    return OllamaConfig(**settings)
def _make_prompts() -> dict[str, str]:
return {"system": "You are a helpful assistant.", "user": "Extract info."}
def _make_resolved(provider: str = "vllm") -> ResolvedAgentConfig:
    """Build a ResolvedAgentConfig whose ``model_provider`` is *provider*.

    Every other field is a fixed placeholder value; only the provider varies
    between the factory tests.
    """
    fields = {
        "agent_id": "agent-1",
        "variant_id": None,
        "model_provider": provider,
        "model_name": "resolved-model",
        "system_prompt": "sys",
        "user_prompt_template": "usr",
        "prompt_version": "v1",
        "temperature": 0.5,
        "max_tokens": 8192,
        "context_window": 0,
        "input_token_limit": 0,
        "token_budget": 0,
        "timeout_seconds": 60,
        "max_retries": 2,
    }
    return ResolvedAgentConfig(**fields)
# ===================================================================
# Task 7: Unit Tests for VLLMClient
# ===================================================================
# 7.1 — VLLMClient sends correct payload to /v1/chat/completions
@pytest.mark.asyncio
async def test_vllm_sends_correct_payload():
    """VLLMClient sends POST to /v1/chat/completions with correct OpenAI payload.

    A mock transport records the outgoing request; the test then checks the
    HTTP method, the URL, and the model/messages/max_tokens/temperature
    fields of the JSON body against the test config.
    """
    captured: dict = {}

    def handler(request: httpx.Request) -> httpx.Response:
        # Record what the client put on the wire so it can be inspected below.
        captured["method"] = request.method
        captured["url"] = str(request.url)
        captured["payload"] = json.loads(request.content)
        return _openai_response(_valid_extraction_json())

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    config = _make_vllm_config()
    client = VLLMClient(config, http_client=http)
    try:
        await client.call_llm(_make_prompts(), {})

        assert captured["method"] == "POST"
        assert captured["url"] == "http://test-vllm:8000/v1/chat/completions"
        payload = captured["payload"]
        assert payload["model"] == "test-vllm-model"
        assert len(payload["messages"]) == 2
        assert payload["messages"][0]["role"] == "system"
        assert payload["messages"][1]["role"] == "user"
        assert payload["max_tokens"] == 4096
        assert payload["temperature"] == 0.7
    finally:
        # Close even when an assertion fails so the client is not leaked.
        await client.close()
# 7.2 — VLLMClient extracts content from choices[0].message.content
@pytest.mark.asyncio
async def test_vllm_extracts_content_from_choices():
    """VLLMClient extracts content from choices[0].message.content.

    The mocked server answers with the valid extraction JSON; the attempt
    must expose that exact string as ``raw_output``, carry no error, and
    report a valid validation result.
    """
    transport = httpx.MockTransport(
        lambda req: _openai_response(_valid_extraction_json())
    )
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.raw_output == _valid_extraction_json()
        assert attempt.error is None
        assert attempt.validation is not None
        assert attempt.validation.valid
    finally:
        # Close even when an assertion fails so the client is not leaked.
        await client.close()
# 7.3 — VLLMClient handles empty choices array → empty_model_response
@pytest.mark.asyncio
async def test_vllm_empty_choices():
    """Empty choices array returns empty_model_response error.

    The mocked server replies 200 with ``{"choices": []}``.
    """
    body = {"choices": []}
    transport = httpx.MockTransport(
        lambda req: httpx.Response(200, json=body)
    )
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.error == "empty_model_response"
    finally:
        # Close even when the assertion fails so the client is not leaked.
        await client.close()
# 7.4 — VLLMClient handles HTTP timeout → timeout error
@pytest.mark.asyncio
async def test_vllm_timeout():
    """HTTP timeout returns 'timeout' error.

    The mock transport raises ``httpx.ReadTimeout`` for every request; the
    attempt must report the ``"timeout"`` error and a non-negative duration.
    """
    def handler(request: httpx.Request) -> httpx.Response:
        # Simulate the server never answering within the read timeout.
        raise httpx.ReadTimeout("timed out")

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.error == "timeout"
        assert attempt.duration_ms >= 0
    finally:
        # Close even when an assertion fails so the client is not leaked.
        await client.close()
# 7.5 — VLLMClient handles HTTP 500 → http_500 retryable error
@pytest.mark.asyncio
async def test_vllm_http_500():
    """HTTP 500 returns 'http_500' error marked as retryable.

    The mocked server always answers with status 500 and a plain-text body.
    """
    transport = httpx.MockTransport(
        lambda req: httpx.Response(500, text="Internal Server Error")
    )
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.error == "http_500"
        assert attempt.retryable is True
    finally:
        # Close even when an assertion fails so the client is not leaked.
        await client.close()
# 7.6 — VLLMClient handles HTTP 400 → http_400 non-retryable error
@pytest.mark.asyncio
async def test_vllm_http_400():
    """HTTP 400 returns 'http_400' error marked as non-retryable.

    The mocked server always answers with status 400 and a plain-text body.
    """
    transport = httpx.MockTransport(
        lambda req: httpx.Response(400, text="Bad Request")
    )
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.error == "http_400"
        assert attempt.retryable is False
    finally:
        # Close even when an assertion fails so the client is not leaked.
        await client.close()
# 7.7 — VLLMClient handles connection error → connection_error: ...
@pytest.mark.asyncio
async def test_vllm_connection_error():
    """Connection error returns 'connection_error: ...' error string.

    The mock transport raises ``httpx.ConnectError`` for every request; only
    the ``connection_error:`` prefix of the resulting error is checked.
    """
    def handler(request: httpx.Request) -> httpx.Response:
        # Simulate the server refusing the TCP connection.
        raise httpx.ConnectError("Connection refused")

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        assert attempt.error is not None
        assert attempt.error.startswith("connection_error:")
    finally:
        # Close even when an assertion fails so the client is not leaked.
        await client.close()
# 7.8 — VLLMClient applies markdown fence stripping and JSON repair
@pytest.mark.asyncio
async def test_vllm_markdown_fence_stripping_and_json_repair():
    """VLLMClient strips markdown fences and repairs JSON.

    The mocked model output is the valid extraction JSON wrapped in a
    ```json fenced block; the attempt must still validate without error.
    """
    # Wrap valid JSON in markdown fences
    fenced = f"```json\n{_valid_extraction_json()}\n```"
    transport = httpx.MockTransport(
        lambda req: _openai_response(fenced)
    )
    http = httpx.AsyncClient(transport=transport)
    client = VLLMClient(_make_vllm_config(), http_client=http)
    try:
        attempt = await client.call_llm(_make_prompts(), {})

        # Should succeed after stripping fences
        assert attempt.error is None
        assert attempt.validation is not None
        assert attempt.validation.valid
    finally:
        # Close even when an assertion fails so the client is not leaked.
        await client.close()
# 7.9 — VLLMClient includes temperature and response_format in payload
@pytest.mark.asyncio
async def test_vllm_payload_includes_temperature_and_response_format():
    """Payload includes temperature and response_format fields.

    The config temperature is overridden to 0.3 before the client is built;
    the captured request body must carry that value and a ``json_object``
    response format.
    """
    captured: dict = {}

    def handler(request: httpx.Request) -> httpx.Response:
        # Record the JSON body the client sent so it can be inspected below.
        captured["payload"] = json.loads(request.content)
        return _openai_response(_valid_extraction_json())

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    config = _make_vllm_config()
    config.temperature = 0.3
    client = VLLMClient(config, http_client=http)
    try:
        await client.call_llm(_make_prompts(), {})

        assert captured["payload"]["temperature"] == 0.3
        assert captured["payload"]["response_format"] == {"type": "json_object"}
    finally:
        # Close even when an assertion fails so the client is not leaked.
        await client.close()
# 7.10 — Health check success returns True and logs INFO
@pytest.mark.asyncio
async def test_health_check_success(caplog):
    """check_vllm_health returns True and logs INFO on success.

    ``httpx.AsyncClient`` is patched so that constructing a client returns
    one backed by a mock transport that answers 200 with a one-model list.
    """
    def respond_ok(req: httpx.Request) -> httpx.Response:
        return httpx.Response(200, json={"data": [{"id": "model-1"}]})

    # Built before the patch is active, so this is a real AsyncClient.
    fake_client = httpx.AsyncClient(transport=httpx.MockTransport(respond_ok))
    client_patch = patch(
        "services.extractor.vllm_client.httpx.AsyncClient",
        return_value=fake_client,
    )
    with client_patch, caplog.at_level(logging.INFO, logger="vllm_client"):
        result = await check_vllm_health("http://test-vllm:8000")

    assert result is True
    logged = [record.message for record in caplog.records]
    assert any("health check passed" in text for text in logged)
# 7.11 — Health check failure returns False and logs WARNING
@pytest.mark.asyncio
async def test_health_check_failure(caplog):
    """check_vllm_health returns False and logs WARNING on failure.

    ``httpx.AsyncClient`` is patched so that constructing a client returns
    one whose mock transport raises ``httpx.ConnectError`` on every request.
    """
    def refuse(request: httpx.Request) -> httpx.Response:
        raise httpx.ConnectError("Connection refused")

    # Built before the patch is active, so this is a real AsyncClient.
    fake_client = httpx.AsyncClient(transport=httpx.MockTransport(refuse))
    client_patch = patch(
        "services.extractor.vllm_client.httpx.AsyncClient",
        return_value=fake_client,
    )
    with client_patch, caplog.at_level(logging.WARNING, logger="vllm_client"):
        result = await check_vllm_health("http://unreachable:8000")

    assert result is False
    logged = [record.message for record in caplog.records]
    assert any("health check failed" in text for text in logged)
# 7.12 — OllamaClient.call_llm() delegates to _call_ollama()
@pytest.mark.asyncio
async def test_ollama_call_llm_delegates():
    """OllamaClient.call_llm() delegates to _call_ollama().

    NOTE(review): delegation is only checked indirectly. The mocked server
    returns an Ollama-style ``{"message": {...}}`` body and the test asserts
    that ``call_llm`` yields a valid, error-free attempt tagged with the
    configured model; ``_call_ollama`` itself is never observed here.
    """
    transport = httpx.MockTransport(
        lambda req: httpx.Response(
            200,
            json={"message": {"role": "assistant", "content": _valid_extraction_json()}},
        )
    )
    http = httpx.AsyncClient(transport=transport)
    config = _make_ollama_config()
    client = OllamaClient(config, http_client=http)
    prompts = _make_prompts()
    schema = {}
    try:
        result_llm = await client.call_llm(prompts, schema)

        # The call should succeed with the valid extraction JSON.
        assert result_llm.error is None
        assert result_llm.validation is not None
        assert result_llm.validation.valid
        assert result_llm.model == config.model
    finally:
        # Close even when an assertion fails so the client is not leaked.
        await client.close()
# 7.13 — VLLMConfig loading from environment variables
def test_vllm_config_from_env(monkeypatch):
    """VLLMConfig fields are loaded from VLLM_* environment variables.

    Seven VLLM_* variables are set through ``monkeypatch``; ``load_config``
    must expose each one on ``cfg.vllm``, with the numeric fields converted
    from their string form.
    """
    env_overrides = {
        "VLLM_BASE_URL": "http://custom:9000",
        "VLLM_MODEL": "custom-model",
        "VLLM_TIMEOUT": "300",
        "VLLM_MAX_RETRIES": "5",
        "VLLM_TEMPERATURE": "0.9",
        "VLLM_API_KEY": "secret-key",
        "VLLM_MAX_TOKENS": "16384",
    }
    for name, value in env_overrides.items():
        monkeypatch.setenv(name, value)

    vllm = load_config().vllm

    assert vllm.base_url == "http://custom:9000"
    assert vllm.model == "custom-model"
    assert vllm.timeout == 300
    assert vllm.max_retries == 5
    assert vllm.temperature == 0.9
    assert vllm.api_key == "secret-key"
    assert vllm.max_tokens == 16384
# 7.14 — AppConfig includes vllm field with correct defaults
def test_appconfig_vllm_defaults():
    """AppConfig includes a vllm field with VLLMConfig defaults.

    A bare ``AppConfig()`` must carry a ``VLLMConfig`` under ``vllm`` whose
    fields equal the default values pinned below.
    """
    app_cfg = AppConfig()
    assert hasattr(app_cfg, "vllm")

    vllm = app_cfg.vllm
    assert isinstance(vllm, VLLMConfig)
    assert vllm.base_url == "http://192.168.42.254:8000"
    assert vllm.model == "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
    assert vllm.timeout == 120
    assert vllm.max_retries == 2
    assert vllm.temperature == 0.7
    assert vllm.max_tokens == 32768
    assert vllm.api_key == ""
# ===================================================================
# Task 8: Unit Tests for LLM Factory
# ===================================================================
# 8.1 — Factory returns OllamaClient when provider is "ollama"
@pytest.mark.asyncio
async def test_factory_ollama_provider():
    """build_llm_client returns OllamaClient when provider is 'ollama'.

    No request is issued; the mock transport only backs the injected
    HTTP client.
    """
    resolved = _make_resolved(provider="ollama")
    transport = httpx.MockTransport(lambda req: httpx.Response(200))
    http = httpx.AsyncClient(transport=transport)
    client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
    try:
        assert isinstance(client, OllamaClient)
    finally:
        # Close even when the assertion fails so the client is not leaked.
        await client.close()
# 8.2 — Factory returns VLLMClient when provider is "vllm"
@pytest.mark.asyncio
async def test_factory_vllm_provider():
    """build_llm_client returns VLLMClient when provider is 'vllm'.

    No request is issued; the mock transport only backs the injected
    HTTP client.
    """
    resolved = _make_resolved(provider="vllm")
    transport = httpx.MockTransport(lambda req: httpx.Response(200))
    http = httpx.AsyncClient(transport=transport)
    client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
    try:
        assert isinstance(client, VLLMClient)
    finally:
        # Close even when the assertion fails so the client is not leaked.
        await client.close()
# 8.3 — Factory returns OllamaClient when provider is empty string (default)
@pytest.mark.asyncio
async def test_factory_empty_provider_defaults_to_ollama():
    """build_llm_client returns OllamaClient when provider is empty string.

    No request is issued; the mock transport only backs the injected
    HTTP client.
    """
    resolved = _make_resolved(provider="")
    transport = httpx.MockTransport(lambda req: httpx.Response(200))
    http = httpx.AsyncClient(transport=transport)
    client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
    try:
        assert isinstance(client, OllamaClient)
    finally:
        # Close even when the assertion fails so the client is not leaked.
        await client.close()
# 8.4 — Factory returns OllamaClient with warning when provider is unknown
@pytest.mark.asyncio
async def test_factory_unknown_provider_warns_and_falls_back(caplog):
    """build_llm_client logs warning and returns OllamaClient for unknown provider.

    The factory is called while ``caplog`` captures at WARNING level; at
    least one captured message must contain "unknown" (case-insensitive).
    """
    resolved = _make_resolved(provider="unknown-provider")
    transport = httpx.MockTransport(lambda req: httpx.Response(200))
    http = httpx.AsyncClient(transport=transport)
    with caplog.at_level(logging.WARNING):
        client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
    try:
        assert isinstance(client, OllamaClient)
        assert any("unknown" in r.message.lower() for r in caplog.records)
    finally:
        # Close even when an assertion fails so the client is not leaked.
        await client.close()