feat: add remote vLLM support with provider abstraction layer

- LLMClient Protocol for provider-agnostic inference
- VLLMClient for OpenAI-compatible /v1/chat/completions API
- LLM client factory with provider routing (ollama/vllm)
- VLLMConfig with VLLM_* environment variable loading
- Updated extractor worker with health check and provider switching
- Updated event classifier to use LLMClient protocol
- Helm values for vLLM configuration
- 18 unit tests + 6 property-based tests
- Full backward compatibility preserved
This commit is contained in:
Celes Renata
2026-04-23 08:17:23 +00:00
parent 63e4fb96ea
commit 117b693b19
15 changed files with 1876 additions and 77 deletions
+461
View File
@@ -0,0 +1,461 @@
"""Tests for the vLLM client, health check, config, and LLM factory."""
import json
import logging
from unittest.mock import patch
import httpx
import pytest
from services.extractor.client import OllamaClient
from services.extractor.llm_factory import build_llm_client
from services.extractor.vllm_client import VLLMClient, check_vllm_health
from services.shared.agent_config import ResolvedAgentConfig
from services.shared.config import AppConfig, OllamaConfig, VLLMConfig, load_config
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _valid_extraction_json() -> str:
"""Minimal valid extraction result as JSON string."""
return json.dumps({
"summary": "Apple beat earnings expectations.",
"companies": [
{
"ticker": "AAPL",
"company_name": "Apple Inc.",
"relevance": 0.95,
"sentiment": "positive",
"impact_score": 0.7,
"impact_horizon": "1d_30d",
"catalyst_type": "earnings",
"key_facts": ["Revenue up 12%"],
"risks": [],
"evidence_spans": ["Apple beat expectations"],
}
],
"macro_themes": ["ai_capex"],
"novelty_score": 0.6,
"confidence": 0.85,
"extraction_warnings": [],
})
def _openai_response(content: str, status: int = 200) -> httpx.Response:
"""Build a fake OpenAI-compatible /v1/chat/completions response."""
body = {
"choices": [
{"message": {"role": "assistant", "content": content}}
],
}
return httpx.Response(status, json=body)
def _make_vllm_config() -> VLLMConfig:
return VLLMConfig(
base_url="http://test-vllm:8000",
model="test-vllm-model",
timeout=10,
max_retries=2,
retry_base_delay=0.0,
retry_max_delay=0.0,
retry_backoff_multiplier=2.0,
max_tokens=4096,
temperature=0.7,
api_key="",
)
def _make_ollama_config() -> OllamaConfig:
return OllamaConfig(
base_url="http://test-ollama:11434",
model="test-ollama-model",
timeout=10,
retry_base_delay=0.0,
retry_max_delay=0.0,
retry_backoff_multiplier=2.0,
)
def _make_prompts() -> dict[str, str]:
return {"system": "You are a helpful assistant.", "user": "Extract info."}
def _make_resolved(provider: str = "vllm") -> ResolvedAgentConfig:
return ResolvedAgentConfig(
agent_id="agent-1",
variant_id=None,
model_provider=provider,
model_name="resolved-model",
system_prompt="sys",
user_prompt_template="usr",
prompt_version="v1",
temperature=0.5,
max_tokens=8192,
context_window=0,
input_token_limit=0,
token_budget=0,
timeout_seconds=60,
max_retries=2,
)
# ===================================================================
# Task 7: Unit Tests for VLLMClient
# ===================================================================
# 7.1 — VLLMClient sends correct payload to /v1/chat/completions
@pytest.mark.asyncio
async def test_vllm_sends_correct_payload():
"""VLLMClient sends POST to /v1/chat/completions with correct OpenAI payload."""
captured: dict = {}
def handler(request: httpx.Request) -> httpx.Response:
captured["url"] = str(request.url)
captured["payload"] = json.loads(request.content)
return _openai_response(_valid_extraction_json())
transport = httpx.MockTransport(handler)
http = httpx.AsyncClient(transport=transport)
config = _make_vllm_config()
client = VLLMClient(config, http_client=http)
await client.call_llm(_make_prompts(), {})
assert captured["url"] == "http://test-vllm:8000/v1/chat/completions"
payload = captured["payload"]
assert payload["model"] == "test-vllm-model"
assert len(payload["messages"]) == 2
assert payload["messages"][0]["role"] == "system"
assert payload["messages"][1]["role"] == "user"
assert payload["max_tokens"] == 4096
assert payload["temperature"] == 0.7
await client.close()
# 7.2 — VLLMClient extracts content from choices[0].message.content
@pytest.mark.asyncio
async def test_vllm_extracts_content_from_choices():
"""VLLMClient extracts content from choices[0].message.content."""
transport = httpx.MockTransport(
lambda req: _openai_response(_valid_extraction_json())
)
http = httpx.AsyncClient(transport=transport)
client = VLLMClient(_make_vllm_config(), http_client=http)
attempt = await client.call_llm(_make_prompts(), {})
assert attempt.raw_output == _valid_extraction_json()
assert attempt.error is None
assert attempt.validation is not None
assert attempt.validation.valid
await client.close()
# 7.3 — VLLMClient handles empty choices array → empty_model_response
@pytest.mark.asyncio
async def test_vllm_empty_choices():
"""Empty choices array returns empty_model_response error."""
body = {"choices": []}
transport = httpx.MockTransport(
lambda req: httpx.Response(200, json=body)
)
http = httpx.AsyncClient(transport=transport)
client = VLLMClient(_make_vllm_config(), http_client=http)
attempt = await client.call_llm(_make_prompts(), {})
assert attempt.error == "empty_model_response"
await client.close()
# 7.4 — VLLMClient handles HTTP timeout → timeout error
@pytest.mark.asyncio
async def test_vllm_timeout():
"""HTTP timeout returns 'timeout' error."""
def handler(request: httpx.Request) -> httpx.Response:
raise httpx.ReadTimeout("timed out")
transport = httpx.MockTransport(handler)
http = httpx.AsyncClient(transport=transport)
client = VLLMClient(_make_vllm_config(), http_client=http)
attempt = await client.call_llm(_make_prompts(), {})
assert attempt.error == "timeout"
assert attempt.duration_ms >= 0
await client.close()
# 7.5 — VLLMClient handles HTTP 500 → http_500 retryable error
@pytest.mark.asyncio
async def test_vllm_http_500():
"""HTTP 500 returns 'http_500' error marked as retryable."""
transport = httpx.MockTransport(
lambda req: httpx.Response(500, text="Internal Server Error")
)
http = httpx.AsyncClient(transport=transport)
client = VLLMClient(_make_vllm_config(), http_client=http)
attempt = await client.call_llm(_make_prompts(), {})
assert attempt.error == "http_500"
assert attempt.retryable is True
await client.close()
# 7.6 — VLLMClient handles HTTP 400 → http_400 non-retryable error
@pytest.mark.asyncio
async def test_vllm_http_400():
"""HTTP 400 returns 'http_400' error marked as non-retryable."""
transport = httpx.MockTransport(
lambda req: httpx.Response(400, text="Bad Request")
)
http = httpx.AsyncClient(transport=transport)
client = VLLMClient(_make_vllm_config(), http_client=http)
attempt = await client.call_llm(_make_prompts(), {})
assert attempt.error == "http_400"
assert attempt.retryable is False
await client.close()
# 7.7 — VLLMClient handles connection error → connection_error: ...
@pytest.mark.asyncio
async def test_vllm_connection_error():
"""Connection error returns 'connection_error: ...' error string."""
def handler(request: httpx.Request) -> httpx.Response:
raise httpx.ConnectError("Connection refused")
transport = httpx.MockTransport(handler)
http = httpx.AsyncClient(transport=transport)
client = VLLMClient(_make_vllm_config(), http_client=http)
attempt = await client.call_llm(_make_prompts(), {})
assert attempt.error is not None
assert attempt.error.startswith("connection_error:")
await client.close()
# 7.8 — VLLMClient applies markdown fence stripping and JSON repair
@pytest.mark.asyncio
async def test_vllm_markdown_fence_stripping_and_json_repair():
"""VLLMClient strips markdown fences and repairs JSON."""
# Wrap valid JSON in markdown fences
fenced = f"```json\n{_valid_extraction_json()}\n```"
transport = httpx.MockTransport(
lambda req: _openai_response(fenced)
)
http = httpx.AsyncClient(transport=transport)
client = VLLMClient(_make_vllm_config(), http_client=http)
attempt = await client.call_llm(_make_prompts(), {})
# Should succeed after stripping fences
assert attempt.error is None
assert attempt.validation is not None
assert attempt.validation.valid
await client.close()
# 7.9 — VLLMClient includes temperature and response_format in payload
@pytest.mark.asyncio
async def test_vllm_payload_includes_temperature_and_response_format():
"""Payload includes temperature and response_format fields."""
captured: dict = {}
def handler(request: httpx.Request) -> httpx.Response:
captured["payload"] = json.loads(request.content)
return _openai_response(_valid_extraction_json())
transport = httpx.MockTransport(handler)
http = httpx.AsyncClient(transport=transport)
config = _make_vllm_config()
config.temperature = 0.3
client = VLLMClient(config, http_client=http)
await client.call_llm(_make_prompts(), {})
assert captured["payload"]["temperature"] == 0.3
assert captured["payload"]["response_format"] == {"type": "json_object"}
await client.close()
# 7.10 — Health check success returns True and logs INFO
@pytest.mark.asyncio
async def test_health_check_success(caplog):
"""check_vllm_health returns True and logs INFO on success."""
transport = httpx.MockTransport(
lambda req: httpx.Response(200, json={"data": [{"id": "model-1"}]})
)
with patch("services.extractor.vllm_client.httpx.AsyncClient", return_value=httpx.AsyncClient(transport=transport)):
with caplog.at_level(logging.INFO, logger="vllm_client"):
result = await check_vllm_health("http://test-vllm:8000")
assert result is True
assert any("health check passed" in r.message for r in caplog.records)
# 7.11 — Health check failure returns False and logs WARNING
@pytest.mark.asyncio
async def test_health_check_failure(caplog):
"""check_vllm_health returns False and logs WARNING on failure."""
def handler(request: httpx.Request) -> httpx.Response:
raise httpx.ConnectError("Connection refused")
transport = httpx.MockTransport(handler)
with patch("services.extractor.vllm_client.httpx.AsyncClient", return_value=httpx.AsyncClient(transport=transport)):
with caplog.at_level(logging.WARNING, logger="vllm_client"):
result = await check_vllm_health("http://unreachable:8000")
assert result is False
assert any("health check failed" in r.message for r in caplog.records)
# 7.12 — OllamaClient.call_llm() delegates to _call_ollama()
@pytest.mark.asyncio
async def test_ollama_call_llm_delegates():
"""OllamaClient.call_llm() delegates to _call_ollama()."""
transport = httpx.MockTransport(
lambda req: httpx.Response(
200,
json={"message": {"role": "assistant", "content": _valid_extraction_json()}},
)
)
http = httpx.AsyncClient(transport=transport)
config = _make_ollama_config()
client = OllamaClient(config, http_client=http)
prompts = _make_prompts()
schema = {}
# call_llm should produce the same result as _call_ollama
result_llm = await client.call_llm(prompts, schema)
# Both should succeed with valid extraction JSON
assert result_llm.error is None
assert result_llm.validation is not None
assert result_llm.validation.valid
assert result_llm.model == config.model
await client.close()
# 7.13 — VLLMConfig loading from environment variables
def test_vllm_config_from_env(monkeypatch):
"""VLLMConfig fields are loaded from VLLM_* environment variables."""
monkeypatch.setenv("VLLM_BASE_URL", "http://custom:9000")
monkeypatch.setenv("VLLM_MODEL", "custom-model")
monkeypatch.setenv("VLLM_TIMEOUT", "300")
monkeypatch.setenv("VLLM_MAX_RETRIES", "5")
monkeypatch.setenv("VLLM_TEMPERATURE", "0.9")
monkeypatch.setenv("VLLM_API_KEY", "secret-key")
monkeypatch.setenv("VLLM_MAX_TOKENS", "16384")
cfg = load_config()
assert cfg.vllm.base_url == "http://custom:9000"
assert cfg.vllm.model == "custom-model"
assert cfg.vllm.timeout == 300
assert cfg.vllm.max_retries == 5
assert cfg.vllm.temperature == 0.9
assert cfg.vllm.api_key == "secret-key"
assert cfg.vllm.max_tokens == 16384
# 7.14 — AppConfig includes vllm field with correct defaults
def test_appconfig_vllm_defaults():
"""AppConfig includes a vllm field with VLLMConfig defaults."""
cfg = AppConfig()
assert hasattr(cfg, "vllm")
assert isinstance(cfg.vllm, VLLMConfig)
assert cfg.vllm.base_url == "http://192.168.42.254:8000"
assert cfg.vllm.model == "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
assert cfg.vllm.timeout == 120
assert cfg.vllm.max_retries == 2
assert cfg.vllm.temperature == 0.7
assert cfg.vllm.max_tokens == 32768
assert cfg.vllm.api_key == ""
# ===================================================================
# Task 8: Unit Tests for LLM Factory
# ===================================================================
# 8.1 — Factory returns OllamaClient when provider is "ollama"
@pytest.mark.asyncio
async def test_factory_ollama_provider():
"""build_llm_client returns OllamaClient when provider is 'ollama'."""
resolved = _make_resolved(provider="ollama")
transport = httpx.MockTransport(lambda req: httpx.Response(200))
http = httpx.AsyncClient(transport=transport)
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
assert isinstance(client, OllamaClient)
await client.close()
# 8.2 — Factory returns VLLMClient when provider is "vllm"
@pytest.mark.asyncio
async def test_factory_vllm_provider():
"""build_llm_client returns VLLMClient when provider is 'vllm'."""
resolved = _make_resolved(provider="vllm")
transport = httpx.MockTransport(lambda req: httpx.Response(200))
http = httpx.AsyncClient(transport=transport)
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
assert isinstance(client, VLLMClient)
await client.close()
# 8.3 — Factory returns OllamaClient when provider is empty string (default)
@pytest.mark.asyncio
async def test_factory_empty_provider_defaults_to_ollama():
"""build_llm_client returns OllamaClient when provider is empty string."""
resolved = _make_resolved(provider="")
transport = httpx.MockTransport(lambda req: httpx.Response(200))
http = httpx.AsyncClient(transport=transport)
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
assert isinstance(client, OllamaClient)
await client.close()
# 8.4 — Factory returns OllamaClient with warning when provider is unknown
@pytest.mark.asyncio
async def test_factory_unknown_provider_warns_and_falls_back(caplog):
"""build_llm_client logs warning and returns OllamaClient for unknown provider."""
resolved = _make_resolved(provider="unknown-provider")
transport = httpx.MockTransport(lambda req: httpx.Response(200))
http = httpx.AsyncClient(transport=transport)
with caplog.at_level(logging.WARNING):
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
assert isinstance(client, OllamaClient)
assert any("unknown" in r.message.lower() for r in caplog.records)
await client.close()