feat: add remote vLLM support with provider abstraction layer

- LLMClient Protocol for provider-agnostic inference
- VLLMClient for OpenAI-compatible /v1/chat/completions API
- LLM client factory with provider routing (ollama/vllm)
- VLLMConfig with VLLM_* environment variable loading
- Updated extractor worker with health check and provider switching
- Updated event classifier to use LLMClient protocol
- Helm values for vLLM configuration
- 18 unit tests + 6 property-based tests
- Full backward compatibility preserved
This commit is contained in:
Celes Renata
2026-04-23 08:17:23 +00:00
parent 63e4fb96ea
commit 117b693b19
15 changed files with 1876 additions and 77 deletions
+296
View File
@@ -0,0 +1,296 @@
"""Property-based tests for the LLM provider abstraction layer.
Feature: remote-vllm-support
Uses Hypothesis to validate correctness properties of the provider
abstraction: factory routing, error classification consistency,
VLLMClient payload structure, JSON repair idempotence, markdown
fence stripping round-trip, and VLLMConfig default invariants.
Requirements: 2.1, 2.3, 2.4, 3.1, 3.4, 3.5, 5.6, 8.1, 9.5
Design: Correctness Properties P1P6
"""
from __future__ import annotations
import asyncio
import json
import httpx
from hypothesis import given, settings
from hypothesis import strategies as st
from services.extractor.client import (
OllamaClient,
_is_retryable,
_repair_json,
_strip_markdown_fences,
)
from services.extractor.llm_factory import build_llm_client
from services.extractor.vllm_client import VLLMClient
from services.shared.agent_config import ResolvedAgentConfig
from services.shared.config import OllamaConfig, VLLMConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ollama_config() -> OllamaConfig:
return OllamaConfig(
base_url="http://test-ollama:11434",
model="test-ollama-model",
timeout=10,
retry_base_delay=0.0,
retry_max_delay=0.0,
retry_backoff_multiplier=2.0,
)
def _make_vllm_config() -> VLLMConfig:
return VLLMConfig(
base_url="http://test-vllm:8000",
model="test-vllm-model",
timeout=10,
max_retries=2,
retry_base_delay=0.0,
retry_max_delay=0.0,
retry_backoff_multiplier=2.0,
max_tokens=4096,
temperature=0.7,
api_key="",
)
def _make_resolved(provider: str | None) -> ResolvedAgentConfig:
return ResolvedAgentConfig(
agent_id="agent-1",
variant_id=None,
model_provider=provider or "",
model_name="resolved-model",
system_prompt="sys",
user_prompt_template="usr",
prompt_version="v1",
temperature=0.5,
max_tokens=8192,
context_window=0,
input_token_limit=0,
token_budget=0,
timeout_seconds=60,
max_retries=2,
)
# ===================================================================
# 9.1 — Factory routing property
# **Validates: Requirements 3.4, 3.5, 9.5**
# ===================================================================
@given(st.sampled_from(["ollama", "vllm", "", None]))
@settings(max_examples=100)
def test_factory_routing_property(provider: str | None):
"""For all model_provider in {"ollama", "vllm", "", None}, factory returns correct client type.
**Validates: Requirements 3.4, 3.5, 9.5**
"""
resolved = _make_resolved(provider)
transport = httpx.MockTransport(lambda req: httpx.Response(200))
http = httpx.AsyncClient(transport=transport)
client = build_llm_client(
resolved, _make_ollama_config(), _make_vllm_config(), http_client=http
)
if provider == "vllm":
assert isinstance(client, VLLMClient), (
f"Expected VLLMClient for provider={provider!r}, got {type(client).__name__}"
)
else:
# "ollama", "", None all map to OllamaClient
assert isinstance(client, OllamaClient), (
f"Expected OllamaClient for provider={provider!r}, got {type(client).__name__}"
)
# ===================================================================
# 9.2 — Error string format consistency property
# **Validates: Requirements 5.6**
# ===================================================================
@given(st.integers(min_value=100, max_value=599))
@settings(max_examples=100)
def test_is_retryable_consistency_property(status_code: int):
"""For all HTTP status codes (100-599), _is_retryable() classifies them consistently.
Non-retryable: 400, 401, 403, 404, 422.
All other http_{code} errors are retryable.
**Validates: Requirements 5.6**
"""
error_str = f"http_{status_code}"
result = _is_retryable(error_str)
non_retryable_codes = {400, 401, 403, 404, 422}
if status_code in non_retryable_codes:
assert result is False, (
f"http_{status_code} should be non-retryable but _is_retryable returned True"
)
else:
assert result is True, (
f"http_{status_code} should be retryable but _is_retryable returned False"
)
# ===================================================================
# 9.3 — VLLMClient request payload structure property
# **Validates: Requirements 2.1, 8.1**
# ===================================================================
@given(
system=st.text(min_size=1),
user=st.text(min_size=1),
)
@settings(max_examples=100)
def test_vllm_payload_structure_property(system: str, user: str):
"""For all generated prompt dicts, payload contains required OpenAI fields and excludes Ollama-specific fields.
**Validates: Requirements 2.1, 8.1**
"""
prompts = {"system": system, "user": user}
captured: dict = {}
def handler(request: httpx.Request) -> httpx.Response:
captured["payload"] = json.loads(request.content)
body = {
"choices": [
{"message": {"role": "assistant", "content": "{}"}}
],
}
return httpx.Response(200, json=body)
transport = httpx.MockTransport(handler)
http = httpx.AsyncClient(transport=transport)
config = _make_vllm_config()
client = VLLMClient(config, http_client=http)
asyncio.run(client.call_llm(prompts, {}))
payload = captured["payload"]
# Required OpenAI fields must be present
assert "model" in payload, "Payload missing 'model' field"
assert "messages" in payload, "Payload missing 'messages' field"
assert "max_tokens" in payload, "Payload missing 'max_tokens' field"
assert "temperature" in payload, "Payload missing 'temperature' field"
# Messages must have system and user roles
roles = [m["role"] for m in payload["messages"]]
assert "system" in roles, "Messages missing 'system' role"
assert "user" in roles, "Messages missing 'user' role"
# Ollama-specific fields must NOT be present
assert "think" not in payload, "Payload contains Ollama-specific 'think' field"
assert "stream" not in payload, "Payload contains Ollama-specific 'stream' field"
assert "options" not in payload, "Payload contains Ollama-specific 'options' field"
# No nested Ollama options
for key in ("num_ctx", "num_predict"):
assert key not in payload, f"Payload contains Ollama-specific '{key}' field"
# ===================================================================
# 9.4 — JSON repair idempotence property
# **Validates: Requirements 2.4**
# ===================================================================
@given(
st.one_of(
st.dictionaries(st.text(max_size=20), st.text(max_size=50), max_size=5),
st.lists(st.integers(), max_size=10),
st.text(max_size=50),
st.integers(),
st.floats(allow_nan=False, allow_infinity=False),
st.booleans(),
st.none(),
)
)
@settings(max_examples=100)
def test_json_repair_idempotence_property(value):
"""For all valid JSON strings, _repair_json() is idempotent.
_repair_json(_repair_json(json_str)) == _repair_json(json_str)
**Validates: Requirements 2.4**
"""
json_str = json.dumps(value)
repaired_once = _repair_json(json_str)
repaired_twice = _repair_json(repaired_once)
assert repaired_once == repaired_twice, (
f"_repair_json is not idempotent: "
f"first={repaired_once!r}, second={repaired_twice!r}"
)
# The repaired output should be valid JSON
json.loads(repaired_once)
# ===================================================================
# 9.5 — Markdown fence stripping round-trip property
# **Validates: Requirements 2.3**
# ===================================================================
@given(st.text())
@settings(max_examples=100)
def test_markdown_fence_stripping_roundtrip_property(s: str):
"""For all strings, wrapping in fences then stripping recovers the original.
The regex trims leading/trailing whitespace around the content inside
fences, so the round-trip recovers ``s.strip()``.
**Validates: Requirements 2.3**
"""
fenced = f"```json\n{s}\n```"
stripped = _strip_markdown_fences(fenced)
assert stripped == s.strip(), (
f"Round-trip failed: original={s!r}, stripped={stripped!r}, expected={s.strip()!r}"
)
# Identity: when no fences are present, the string is returned as-is
# (only test strings that don't look like fenced blocks themselves)
if not s.strip().startswith("```"):
assert _strip_markdown_fences(s) == s
# ===================================================================
# 9.6 — VLLMConfig defaults property
# **Validates: Requirements 3.1**
# ===================================================================
@settings(max_examples=100)
@given(st.just(None))
def test_vllm_config_defaults_property(_):
"""For all default-constructed instances, invariants hold.
timeout > 0, max_retries >= 0, 0 <= temperature <= 2, max_tokens > 0.
**Validates: Requirements 3.1**
"""
config = VLLMConfig()
assert config.timeout > 0, f"timeout must be > 0, got {config.timeout}"
assert config.max_retries >= 0, f"max_retries must be >= 0, got {config.max_retries}"
assert 0 <= config.temperature <= 2, (
f"temperature must be in [0, 2], got {config.temperature}"
)
assert config.max_tokens > 0, f"max_tokens must be > 0, got {config.max_tokens}"
assert config.base_url, "base_url must be non-empty"
assert config.model, "model must be non-empty"