feat: add remote vLLM support with provider abstraction layer
- LLMClient Protocol for provider-agnostic inference - VLLMClient for OpenAI-compatible /v1/chat/completions API - LLM client factory with provider routing (ollama/vllm) - VLLMConfig with VLLM_* environment variable loading - Updated extractor worker with health check and provider switching - Updated event classifier to use LLMClient protocol - Helm values for vLLM configuration - 18 unit tests + 6 property-based tests - Full backward compatibility preserved
This commit is contained in:
@@ -274,19 +274,19 @@ class TestParseClassificationResponse:
|
||||
|
||||
class TestClassifyGlobalEvent:
|
||||
def _make_mock_client(self, raw_output: str, error: str | None = None):
|
||||
"""Create a mock OllamaClient with configurable response."""
|
||||
"""Create a mock LLMClient with configurable response."""
|
||||
client = MagicMock()
|
||||
client._config = MagicMock()
|
||||
client._config.model = "llama3.1:8b"
|
||||
client._max_retries = 2
|
||||
client._base_delay = 0.01
|
||||
client._max_delay = 0.1
|
||||
client._backoff_multiplier = 2.0
|
||||
client._config.max_retries = 2
|
||||
client._config.retry_base_delay = 0.01
|
||||
client._config.retry_max_delay = 0.1
|
||||
client._config.retry_backoff_multiplier = 2.0
|
||||
|
||||
attempt = MagicMock()
|
||||
attempt.raw_output = raw_output
|
||||
attempt.error = error
|
||||
client._call_ollama = AsyncMock(return_value=attempt)
|
||||
client.call_llm = AsyncMock(return_value=attempt)
|
||||
return client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -314,7 +314,7 @@ class TestClassifyGlobalEvent:
|
||||
assert event.severity == "critical"
|
||||
assert event.confidence == 0.9
|
||||
assert event.source_document_id == "doc-123"
|
||||
client._call_ollama.assert_called_once()
|
||||
client.call_llm.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_error(self):
|
||||
@@ -340,11 +340,11 @@ class TestClassifyGlobalEvent:
|
||||
success_attempt.error = None
|
||||
|
||||
client = self._make_mock_client("")
|
||||
client._call_ollama = AsyncMock(side_effect=[fail_attempt, success_attempt])
|
||||
client.call_llm = AsyncMock(side_effect=[fail_attempt, success_attempt])
|
||||
|
||||
event = await classify_global_event("text", "doc-456", client)
|
||||
assert event.severity == "high"
|
||||
assert client._call_ollama.call_count == 2
|
||||
assert client.call_llm.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_after_exhausted_retries(self):
|
||||
@@ -353,12 +353,12 @@ class TestClassifyGlobalEvent:
|
||||
fail_attempt.error = "timeout"
|
||||
|
||||
client = self._make_mock_client("")
|
||||
client._call_ollama = AsyncMock(return_value=fail_attempt)
|
||||
client.call_llm = AsyncMock(return_value=fail_attempt)
|
||||
|
||||
with pytest.raises(ValueError, match="Event classification failed"):
|
||||
await classify_global_event("text", "doc-789", client)
|
||||
|
||||
assert client._call_ollama.call_count == 3 # initial + 2 retries
|
||||
assert client.call_llm.call_count == 3 # initial + 2 retries
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_minio_persistence_called(self):
|
||||
|
||||
@@ -0,0 +1,296 @@
|
||||
"""Property-based tests for the LLM provider abstraction layer.
|
||||
|
||||
Feature: remote-vllm-support
|
||||
|
||||
Uses Hypothesis to validate correctness properties of the provider
|
||||
abstraction: factory routing, error classification consistency,
|
||||
VLLMClient payload structure, JSON repair idempotence, markdown
|
||||
fence stripping round-trip, and VLLMConfig default invariants.
|
||||
|
||||
Requirements: 2.1, 2.3, 2.4, 3.1, 3.4, 3.5, 5.6, 8.1, 9.5
|
||||
Design: Correctness Properties P1–P6
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from services.extractor.client import (
|
||||
OllamaClient,
|
||||
_is_retryable,
|
||||
_repair_json,
|
||||
_strip_markdown_fences,
|
||||
)
|
||||
from services.extractor.llm_factory import build_llm_client
|
||||
from services.extractor.vllm_client import VLLMClient
|
||||
from services.shared.agent_config import ResolvedAgentConfig
|
||||
from services.shared.config import OllamaConfig, VLLMConfig
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ollama_config() -> OllamaConfig:
|
||||
return OllamaConfig(
|
||||
base_url="http://test-ollama:11434",
|
||||
model="test-ollama-model",
|
||||
timeout=10,
|
||||
retry_base_delay=0.0,
|
||||
retry_max_delay=0.0,
|
||||
retry_backoff_multiplier=2.0,
|
||||
)
|
||||
|
||||
|
||||
def _make_vllm_config() -> VLLMConfig:
|
||||
return VLLMConfig(
|
||||
base_url="http://test-vllm:8000",
|
||||
model="test-vllm-model",
|
||||
timeout=10,
|
||||
max_retries=2,
|
||||
retry_base_delay=0.0,
|
||||
retry_max_delay=0.0,
|
||||
retry_backoff_multiplier=2.0,
|
||||
max_tokens=4096,
|
||||
temperature=0.7,
|
||||
api_key="",
|
||||
)
|
||||
|
||||
|
||||
def _make_resolved(provider: str | None) -> ResolvedAgentConfig:
|
||||
return ResolvedAgentConfig(
|
||||
agent_id="agent-1",
|
||||
variant_id=None,
|
||||
model_provider=provider or "",
|
||||
model_name="resolved-model",
|
||||
system_prompt="sys",
|
||||
user_prompt_template="usr",
|
||||
prompt_version="v1",
|
||||
temperature=0.5,
|
||||
max_tokens=8192,
|
||||
context_window=0,
|
||||
input_token_limit=0,
|
||||
token_budget=0,
|
||||
timeout_seconds=60,
|
||||
max_retries=2,
|
||||
)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# 9.1 — Factory routing property
|
||||
# **Validates: Requirements 3.4, 3.5, 9.5**
|
||||
# ===================================================================
|
||||
|
||||
|
||||
@given(st.sampled_from(["ollama", "vllm", "", None]))
|
||||
@settings(max_examples=100)
|
||||
def test_factory_routing_property(provider: str | None):
|
||||
"""For all model_provider in {"ollama", "vllm", "", None}, factory returns correct client type.
|
||||
|
||||
**Validates: Requirements 3.4, 3.5, 9.5**
|
||||
"""
|
||||
resolved = _make_resolved(provider)
|
||||
transport = httpx.MockTransport(lambda req: httpx.Response(200))
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
|
||||
client = build_llm_client(
|
||||
resolved, _make_ollama_config(), _make_vllm_config(), http_client=http
|
||||
)
|
||||
|
||||
if provider == "vllm":
|
||||
assert isinstance(client, VLLMClient), (
|
||||
f"Expected VLLMClient for provider={provider!r}, got {type(client).__name__}"
|
||||
)
|
||||
else:
|
||||
# "ollama", "", None all map to OllamaClient
|
||||
assert isinstance(client, OllamaClient), (
|
||||
f"Expected OllamaClient for provider={provider!r}, got {type(client).__name__}"
|
||||
)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# 9.2 — Error string format consistency property
|
||||
# **Validates: Requirements 5.6**
|
||||
# ===================================================================
|
||||
|
||||
|
||||
@given(st.integers(min_value=100, max_value=599))
|
||||
@settings(max_examples=100)
|
||||
def test_is_retryable_consistency_property(status_code: int):
|
||||
"""For all HTTP status codes (100-599), _is_retryable() classifies them consistently.
|
||||
|
||||
Non-retryable: 400, 401, 403, 404, 422.
|
||||
All other http_{code} errors are retryable.
|
||||
|
||||
**Validates: Requirements 5.6**
|
||||
"""
|
||||
error_str = f"http_{status_code}"
|
||||
result = _is_retryable(error_str)
|
||||
|
||||
non_retryable_codes = {400, 401, 403, 404, 422}
|
||||
|
||||
if status_code in non_retryable_codes:
|
||||
assert result is False, (
|
||||
f"http_{status_code} should be non-retryable but _is_retryable returned True"
|
||||
)
|
||||
else:
|
||||
assert result is True, (
|
||||
f"http_{status_code} should be retryable but _is_retryable returned False"
|
||||
)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# 9.3 — VLLMClient request payload structure property
|
||||
# **Validates: Requirements 2.1, 8.1**
|
||||
# ===================================================================
|
||||
|
||||
|
||||
@given(
|
||||
system=st.text(min_size=1),
|
||||
user=st.text(min_size=1),
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_vllm_payload_structure_property(system: str, user: str):
|
||||
"""For all generated prompt dicts, payload contains required OpenAI fields and excludes Ollama-specific fields.
|
||||
|
||||
**Validates: Requirements 2.1, 8.1**
|
||||
"""
|
||||
prompts = {"system": system, "user": user}
|
||||
captured: dict = {}
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
captured["payload"] = json.loads(request.content)
|
||||
body = {
|
||||
"choices": [
|
||||
{"message": {"role": "assistant", "content": "{}"}}
|
||||
],
|
||||
}
|
||||
return httpx.Response(200, json=body)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
config = _make_vllm_config()
|
||||
client = VLLMClient(config, http_client=http)
|
||||
|
||||
asyncio.run(client.call_llm(prompts, {}))
|
||||
|
||||
payload = captured["payload"]
|
||||
|
||||
# Required OpenAI fields must be present
|
||||
assert "model" in payload, "Payload missing 'model' field"
|
||||
assert "messages" in payload, "Payload missing 'messages' field"
|
||||
assert "max_tokens" in payload, "Payload missing 'max_tokens' field"
|
||||
assert "temperature" in payload, "Payload missing 'temperature' field"
|
||||
|
||||
# Messages must have system and user roles
|
||||
roles = [m["role"] for m in payload["messages"]]
|
||||
assert "system" in roles, "Messages missing 'system' role"
|
||||
assert "user" in roles, "Messages missing 'user' role"
|
||||
|
||||
# Ollama-specific fields must NOT be present
|
||||
assert "think" not in payload, "Payload contains Ollama-specific 'think' field"
|
||||
assert "stream" not in payload, "Payload contains Ollama-specific 'stream' field"
|
||||
assert "options" not in payload, "Payload contains Ollama-specific 'options' field"
|
||||
|
||||
# No nested Ollama options
|
||||
for key in ("num_ctx", "num_predict"):
|
||||
assert key not in payload, f"Payload contains Ollama-specific '{key}' field"
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# 9.4 — JSON repair idempotence property
|
||||
# **Validates: Requirements 2.4**
|
||||
# ===================================================================
|
||||
|
||||
|
||||
@given(
|
||||
st.one_of(
|
||||
st.dictionaries(st.text(max_size=20), st.text(max_size=50), max_size=5),
|
||||
st.lists(st.integers(), max_size=10),
|
||||
st.text(max_size=50),
|
||||
st.integers(),
|
||||
st.floats(allow_nan=False, allow_infinity=False),
|
||||
st.booleans(),
|
||||
st.none(),
|
||||
)
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_json_repair_idempotence_property(value):
|
||||
"""For all valid JSON strings, _repair_json() is idempotent.
|
||||
|
||||
_repair_json(_repair_json(json_str)) == _repair_json(json_str)
|
||||
|
||||
**Validates: Requirements 2.4**
|
||||
"""
|
||||
json_str = json.dumps(value)
|
||||
|
||||
repaired_once = _repair_json(json_str)
|
||||
repaired_twice = _repair_json(repaired_once)
|
||||
|
||||
assert repaired_once == repaired_twice, (
|
||||
f"_repair_json is not idempotent: "
|
||||
f"first={repaired_once!r}, second={repaired_twice!r}"
|
||||
)
|
||||
|
||||
# The repaired output should be valid JSON
|
||||
json.loads(repaired_once)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# 9.5 — Markdown fence stripping round-trip property
|
||||
# **Validates: Requirements 2.3**
|
||||
# ===================================================================
|
||||
|
||||
|
||||
@given(st.text())
|
||||
@settings(max_examples=100)
|
||||
def test_markdown_fence_stripping_roundtrip_property(s: str):
|
||||
"""For all strings, wrapping in fences then stripping recovers the original.
|
||||
|
||||
The regex trims leading/trailing whitespace around the content inside
|
||||
fences, so the round-trip recovers ``s.strip()``.
|
||||
|
||||
**Validates: Requirements 2.3**
|
||||
"""
|
||||
fenced = f"```json\n{s}\n```"
|
||||
stripped = _strip_markdown_fences(fenced)
|
||||
|
||||
assert stripped == s.strip(), (
|
||||
f"Round-trip failed: original={s!r}, stripped={stripped!r}, expected={s.strip()!r}"
|
||||
)
|
||||
|
||||
# Identity: when no fences are present, the string is returned as-is
|
||||
# (only test strings that don't look like fenced blocks themselves)
|
||||
if not s.strip().startswith("```"):
|
||||
assert _strip_markdown_fences(s) == s
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# 9.6 — VLLMConfig defaults property
|
||||
# **Validates: Requirements 3.1**
|
||||
# ===================================================================
|
||||
|
||||
|
||||
@settings(max_examples=100)
|
||||
@given(st.just(None))
|
||||
def test_vllm_config_defaults_property(_):
|
||||
"""For all default-constructed instances, invariants hold.
|
||||
|
||||
timeout > 0, max_retries >= 0, 0 <= temperature <= 2, max_tokens > 0.
|
||||
|
||||
**Validates: Requirements 3.1**
|
||||
"""
|
||||
config = VLLMConfig()
|
||||
|
||||
assert config.timeout > 0, f"timeout must be > 0, got {config.timeout}"
|
||||
assert config.max_retries >= 0, f"max_retries must be >= 0, got {config.max_retries}"
|
||||
assert 0 <= config.temperature <= 2, (
|
||||
f"temperature must be in [0, 2], got {config.temperature}"
|
||||
)
|
||||
assert config.max_tokens > 0, f"max_tokens must be > 0, got {config.max_tokens}"
|
||||
assert config.base_url, "base_url must be non-empty"
|
||||
assert config.model, "model must be non-empty"
|
||||
@@ -0,0 +1,461 @@
|
||||
"""Tests for the vLLM client, health check, config, and LLM factory."""
|
||||
import json
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from services.extractor.client import OllamaClient
|
||||
from services.extractor.llm_factory import build_llm_client
|
||||
from services.extractor.vllm_client import VLLMClient, check_vllm_health
|
||||
from services.shared.agent_config import ResolvedAgentConfig
|
||||
from services.shared.config import AppConfig, OllamaConfig, VLLMConfig, load_config
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _valid_extraction_json() -> str:
|
||||
"""Minimal valid extraction result as JSON string."""
|
||||
return json.dumps({
|
||||
"summary": "Apple beat earnings expectations.",
|
||||
"companies": [
|
||||
{
|
||||
"ticker": "AAPL",
|
||||
"company_name": "Apple Inc.",
|
||||
"relevance": 0.95,
|
||||
"sentiment": "positive",
|
||||
"impact_score": 0.7,
|
||||
"impact_horizon": "1d_30d",
|
||||
"catalyst_type": "earnings",
|
||||
"key_facts": ["Revenue up 12%"],
|
||||
"risks": [],
|
||||
"evidence_spans": ["Apple beat expectations"],
|
||||
}
|
||||
],
|
||||
"macro_themes": ["ai_capex"],
|
||||
"novelty_score": 0.6,
|
||||
"confidence": 0.85,
|
||||
"extraction_warnings": [],
|
||||
})
|
||||
|
||||
|
||||
def _openai_response(content: str, status: int = 200) -> httpx.Response:
|
||||
"""Build a fake OpenAI-compatible /v1/chat/completions response."""
|
||||
body = {
|
||||
"choices": [
|
||||
{"message": {"role": "assistant", "content": content}}
|
||||
],
|
||||
}
|
||||
return httpx.Response(status, json=body)
|
||||
|
||||
|
||||
def _make_vllm_config() -> VLLMConfig:
|
||||
return VLLMConfig(
|
||||
base_url="http://test-vllm:8000",
|
||||
model="test-vllm-model",
|
||||
timeout=10,
|
||||
max_retries=2,
|
||||
retry_base_delay=0.0,
|
||||
retry_max_delay=0.0,
|
||||
retry_backoff_multiplier=2.0,
|
||||
max_tokens=4096,
|
||||
temperature=0.7,
|
||||
api_key="",
|
||||
)
|
||||
|
||||
|
||||
def _make_ollama_config() -> OllamaConfig:
|
||||
return OllamaConfig(
|
||||
base_url="http://test-ollama:11434",
|
||||
model="test-ollama-model",
|
||||
timeout=10,
|
||||
retry_base_delay=0.0,
|
||||
retry_max_delay=0.0,
|
||||
retry_backoff_multiplier=2.0,
|
||||
)
|
||||
|
||||
|
||||
def _make_prompts() -> dict[str, str]:
|
||||
return {"system": "You are a helpful assistant.", "user": "Extract info."}
|
||||
|
||||
|
||||
def _make_resolved(provider: str = "vllm") -> ResolvedAgentConfig:
|
||||
return ResolvedAgentConfig(
|
||||
agent_id="agent-1",
|
||||
variant_id=None,
|
||||
model_provider=provider,
|
||||
model_name="resolved-model",
|
||||
system_prompt="sys",
|
||||
user_prompt_template="usr",
|
||||
prompt_version="v1",
|
||||
temperature=0.5,
|
||||
max_tokens=8192,
|
||||
context_window=0,
|
||||
input_token_limit=0,
|
||||
token_budget=0,
|
||||
timeout_seconds=60,
|
||||
max_retries=2,
|
||||
)
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Task 7: Unit Tests for VLLMClient
|
||||
# ===================================================================
|
||||
|
||||
|
||||
# 7.1 — VLLMClient sends correct payload to /v1/chat/completions
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_sends_correct_payload():
|
||||
"""VLLMClient sends POST to /v1/chat/completions with correct OpenAI payload."""
|
||||
captured: dict = {}
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
captured["url"] = str(request.url)
|
||||
captured["payload"] = json.loads(request.content)
|
||||
return _openai_response(_valid_extraction_json())
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
config = _make_vllm_config()
|
||||
client = VLLMClient(config, http_client=http)
|
||||
|
||||
await client.call_llm(_make_prompts(), {})
|
||||
|
||||
assert captured["url"] == "http://test-vllm:8000/v1/chat/completions"
|
||||
payload = captured["payload"]
|
||||
assert payload["model"] == "test-vllm-model"
|
||||
assert len(payload["messages"]) == 2
|
||||
assert payload["messages"][0]["role"] == "system"
|
||||
assert payload["messages"][1]["role"] == "user"
|
||||
assert payload["max_tokens"] == 4096
|
||||
assert payload["temperature"] == 0.7
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 7.2 — VLLMClient extracts content from choices[0].message.content
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_extracts_content_from_choices():
|
||||
"""VLLMClient extracts content from choices[0].message.content."""
|
||||
transport = httpx.MockTransport(
|
||||
lambda req: _openai_response(_valid_extraction_json())
|
||||
)
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||
|
||||
attempt = await client.call_llm(_make_prompts(), {})
|
||||
|
||||
assert attempt.raw_output == _valid_extraction_json()
|
||||
assert attempt.error is None
|
||||
assert attempt.validation is not None
|
||||
assert attempt.validation.valid
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 7.3 — VLLMClient handles empty choices array → empty_model_response
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_empty_choices():
|
||||
"""Empty choices array returns empty_model_response error."""
|
||||
body = {"choices": []}
|
||||
|
||||
transport = httpx.MockTransport(
|
||||
lambda req: httpx.Response(200, json=body)
|
||||
)
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||
|
||||
attempt = await client.call_llm(_make_prompts(), {})
|
||||
|
||||
assert attempt.error == "empty_model_response"
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 7.4 — VLLMClient handles HTTP timeout → timeout error
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_timeout():
|
||||
"""HTTP timeout returns 'timeout' error."""
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
raise httpx.ReadTimeout("timed out")
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||
|
||||
attempt = await client.call_llm(_make_prompts(), {})
|
||||
|
||||
assert attempt.error == "timeout"
|
||||
assert attempt.duration_ms >= 0
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 7.5 — VLLMClient handles HTTP 500 → http_500 retryable error
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_http_500():
|
||||
"""HTTP 500 returns 'http_500' error marked as retryable."""
|
||||
transport = httpx.MockTransport(
|
||||
lambda req: httpx.Response(500, text="Internal Server Error")
|
||||
)
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||
|
||||
attempt = await client.call_llm(_make_prompts(), {})
|
||||
|
||||
assert attempt.error == "http_500"
|
||||
assert attempt.retryable is True
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 7.6 — VLLMClient handles HTTP 400 → http_400 non-retryable error
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_http_400():
|
||||
"""HTTP 400 returns 'http_400' error marked as non-retryable."""
|
||||
transport = httpx.MockTransport(
|
||||
lambda req: httpx.Response(400, text="Bad Request")
|
||||
)
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||
|
||||
attempt = await client.call_llm(_make_prompts(), {})
|
||||
|
||||
assert attempt.error == "http_400"
|
||||
assert attempt.retryable is False
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 7.7 — VLLMClient handles connection error → connection_error: ...
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_connection_error():
|
||||
"""Connection error returns 'connection_error: ...' error string."""
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
raise httpx.ConnectError("Connection refused")
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||
|
||||
attempt = await client.call_llm(_make_prompts(), {})
|
||||
|
||||
assert attempt.error is not None
|
||||
assert attempt.error.startswith("connection_error:")
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 7.8 — VLLMClient applies markdown fence stripping and JSON repair
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_markdown_fence_stripping_and_json_repair():
|
||||
"""VLLMClient strips markdown fences and repairs JSON."""
|
||||
# Wrap valid JSON in markdown fences
|
||||
fenced = f"```json\n{_valid_extraction_json()}\n```"
|
||||
transport = httpx.MockTransport(
|
||||
lambda req: _openai_response(fenced)
|
||||
)
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||
|
||||
attempt = await client.call_llm(_make_prompts(), {})
|
||||
|
||||
# Should succeed after stripping fences
|
||||
assert attempt.error is None
|
||||
assert attempt.validation is not None
|
||||
assert attempt.validation.valid
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 7.9 — VLLMClient includes temperature and response_format in payload
|
||||
@pytest.mark.asyncio
|
||||
async def test_vllm_payload_includes_temperature_and_response_format():
|
||||
"""Payload includes temperature and response_format fields."""
|
||||
captured: dict = {}
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
captured["payload"] = json.loads(request.content)
|
||||
return _openai_response(_valid_extraction_json())
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
config = _make_vllm_config()
|
||||
config.temperature = 0.3
|
||||
client = VLLMClient(config, http_client=http)
|
||||
|
||||
await client.call_llm(_make_prompts(), {})
|
||||
|
||||
assert captured["payload"]["temperature"] == 0.3
|
||||
assert captured["payload"]["response_format"] == {"type": "json_object"}
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 7.10 — Health check success returns True and logs INFO
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_success(caplog):
|
||||
"""check_vllm_health returns True and logs INFO on success."""
|
||||
transport = httpx.MockTransport(
|
||||
lambda req: httpx.Response(200, json={"data": [{"id": "model-1"}]})
|
||||
)
|
||||
|
||||
with patch("services.extractor.vllm_client.httpx.AsyncClient", return_value=httpx.AsyncClient(transport=transport)):
|
||||
with caplog.at_level(logging.INFO, logger="vllm_client"):
|
||||
result = await check_vllm_health("http://test-vllm:8000")
|
||||
|
||||
assert result is True
|
||||
assert any("health check passed" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
# 7.11 — Health check failure returns False and logs WARNING
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_failure(caplog):
|
||||
"""check_vllm_health returns False and logs WARNING on failure."""
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
raise httpx.ConnectError("Connection refused")
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
|
||||
with patch("services.extractor.vllm_client.httpx.AsyncClient", return_value=httpx.AsyncClient(transport=transport)):
|
||||
with caplog.at_level(logging.WARNING, logger="vllm_client"):
|
||||
result = await check_vllm_health("http://unreachable:8000")
|
||||
|
||||
assert result is False
|
||||
assert any("health check failed" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
# 7.12 — OllamaClient.call_llm() delegates to _call_ollama()
|
||||
@pytest.mark.asyncio
|
||||
async def test_ollama_call_llm_delegates():
|
||||
"""OllamaClient.call_llm() delegates to _call_ollama()."""
|
||||
transport = httpx.MockTransport(
|
||||
lambda req: httpx.Response(
|
||||
200,
|
||||
json={"message": {"role": "assistant", "content": _valid_extraction_json()}},
|
||||
)
|
||||
)
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
config = _make_ollama_config()
|
||||
client = OllamaClient(config, http_client=http)
|
||||
|
||||
prompts = _make_prompts()
|
||||
schema = {}
|
||||
|
||||
# call_llm should produce the same result as _call_ollama
|
||||
result_llm = await client.call_llm(prompts, schema)
|
||||
# Both should succeed with valid extraction JSON
|
||||
assert result_llm.error is None
|
||||
assert result_llm.validation is not None
|
||||
assert result_llm.validation.valid
|
||||
assert result_llm.model == config.model
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 7.13 — VLLMConfig loading from environment variables
|
||||
def test_vllm_config_from_env(monkeypatch):
|
||||
"""VLLMConfig fields are loaded from VLLM_* environment variables."""
|
||||
monkeypatch.setenv("VLLM_BASE_URL", "http://custom:9000")
|
||||
monkeypatch.setenv("VLLM_MODEL", "custom-model")
|
||||
monkeypatch.setenv("VLLM_TIMEOUT", "300")
|
||||
monkeypatch.setenv("VLLM_MAX_RETRIES", "5")
|
||||
monkeypatch.setenv("VLLM_TEMPERATURE", "0.9")
|
||||
monkeypatch.setenv("VLLM_API_KEY", "secret-key")
|
||||
monkeypatch.setenv("VLLM_MAX_TOKENS", "16384")
|
||||
|
||||
cfg = load_config()
|
||||
|
||||
assert cfg.vllm.base_url == "http://custom:9000"
|
||||
assert cfg.vllm.model == "custom-model"
|
||||
assert cfg.vllm.timeout == 300
|
||||
assert cfg.vllm.max_retries == 5
|
||||
assert cfg.vllm.temperature == 0.9
|
||||
assert cfg.vllm.api_key == "secret-key"
|
||||
assert cfg.vllm.max_tokens == 16384
|
||||
|
||||
|
||||
# 7.14 — AppConfig includes vllm field with correct defaults
|
||||
def test_appconfig_vllm_defaults():
|
||||
"""AppConfig includes a vllm field with VLLMConfig defaults."""
|
||||
cfg = AppConfig()
|
||||
|
||||
assert hasattr(cfg, "vllm")
|
||||
assert isinstance(cfg.vllm, VLLMConfig)
|
||||
assert cfg.vllm.base_url == "http://192.168.42.254:8000"
|
||||
assert cfg.vllm.model == "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
|
||||
assert cfg.vllm.timeout == 120
|
||||
assert cfg.vllm.max_retries == 2
|
||||
assert cfg.vllm.temperature == 0.7
|
||||
assert cfg.vllm.max_tokens == 32768
|
||||
assert cfg.vllm.api_key == ""
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Task 8: Unit Tests for LLM Factory
|
||||
# ===================================================================
|
||||
|
||||
|
||||
# 8.1 — Factory returns OllamaClient when provider is "ollama"
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_ollama_provider():
|
||||
"""build_llm_client returns OllamaClient when provider is 'ollama'."""
|
||||
resolved = _make_resolved(provider="ollama")
|
||||
transport = httpx.MockTransport(lambda req: httpx.Response(200))
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
|
||||
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
|
||||
|
||||
assert isinstance(client, OllamaClient)
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 8.2 — Factory returns VLLMClient when provider is "vllm"
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_vllm_provider():
|
||||
"""build_llm_client returns VLLMClient when provider is 'vllm'."""
|
||||
resolved = _make_resolved(provider="vllm")
|
||||
transport = httpx.MockTransport(lambda req: httpx.Response(200))
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
|
||||
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
|
||||
|
||||
assert isinstance(client, VLLMClient)
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 8.3 — Factory returns OllamaClient when provider is empty string (default)
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_empty_provider_defaults_to_ollama():
|
||||
"""build_llm_client returns OllamaClient when provider is empty string."""
|
||||
resolved = _make_resolved(provider="")
|
||||
transport = httpx.MockTransport(lambda req: httpx.Response(200))
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
|
||||
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
|
||||
|
||||
assert isinstance(client, OllamaClient)
|
||||
|
||||
await client.close()
|
||||
|
||||
|
||||
# 8.4 — Factory returns OllamaClient with warning when provider is unknown
|
||||
@pytest.mark.asyncio
|
||||
async def test_factory_unknown_provider_warns_and_falls_back(caplog):
|
||||
"""build_llm_client logs warning and returns OllamaClient for unknown provider."""
|
||||
resolved = _make_resolved(provider="unknown-provider")
|
||||
transport = httpx.MockTransport(lambda req: httpx.Response(200))
|
||||
http = httpx.AsyncClient(transport=transport)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
|
||||
|
||||
assert isinstance(client, OllamaClient)
|
||||
assert any("unknown" in r.message.lower() for r in caplog.records)
|
||||
|
||||
await client.close()
|
||||
Reference in New Issue
Block a user