Files
stonks-oracle/services/extractor/vllm_client.py
T
Celes Renata 117b693b19 feat: add remote vLLM support with provider abstraction layer
- LLMClient Protocol for provider-agnostic inference
- VLLMClient for OpenAI-compatible /v1/chat/completions API
- LLM client factory with provider routing (ollama/vllm)
- VLLMConfig with VLLM_* environment variable loading
- Updated extractor worker with health check and provider switching
- Updated event classifier to use LLMClient protocol
- Helm values for vLLM configuration
- 18 unit tests + 6 property-based tests
- Full backward compatibility preserved
2026-04-23 08:17:23 +00:00

178 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""vLLM client for OpenAI-compatible chat completions.
Sends structured extraction requests to a remote vLLM server via the
``/v1/chat/completions`` endpoint. Reuses the same markdown-fence
stripping, JSON repair, and error-string conventions as OllamaClient
so that ``_is_retryable()`` works without modification.
Requirements: 2.1-2.10, 7.1-7.4
"""
from __future__ import annotations
import logging
import time
import httpx
from services.extractor.client import (
ExtractionAttempt,
_is_retryable,
_repair_json,
_strip_markdown_fences,
)
from services.extractor.schemas import validate_extraction
from services.shared.config import VLLMConfig
# Module-level logger used by VLLMClient and check_vllm_health. It is
# registered under the fixed name "vllm_client" rather than ``__name__``.
logger = logging.getLogger("vllm_client")
class VLLMClient:
    """Async client for vLLM OpenAI-compatible chat completions.

    Satisfies the ``LLMClient`` protocol defined in
    ``services.shared.llm_protocol``.

    The client either wraps a caller-supplied ``httpx.AsyncClient`` (which
    is never closed by this class) or creates its own, which :meth:`close`
    releases.
    """

    _config: VLLMConfig
    _http: httpx.AsyncClient
    _owns_client: bool

    def __init__(
        self,
        config: VLLMConfig,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Store *config* and pick the HTTP client to use.

        Args:
            config: vLLM connection settings (base URL, model, API key,
                timeout, max tokens, temperature).
            http_client: Optional pre-built client. When omitted, a new
                ``httpx.AsyncClient`` is created with ``config.timeout``
                and is owned (and later closed) by this instance.
        """
        self._config = config
        self._owns_client = http_client is None
        # Use the same ``is None`` test as the ownership flag so the two can
        # never disagree about who created the client.
        if http_client is None:
            http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(config.timeout, read=config.timeout),
            )
        self._http = http_client

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _chat_completions_url(base_url: str) -> str:
        """Join *base_url* with the chat-completions path.

        Trailing slashes on *base_url* are dropped so that a configured
        ``http://host:8000/`` does not produce ``//v1/chat/completions``.
        """
        return f"{base_url.rstrip('/')}/v1/chat/completions"

    @staticmethod
    def _extract_content(data: object) -> str:
        """Return the first choice's message content from a decoded response.

        Returns ``""`` whenever the expected
        ``{"choices": [{"message": {"content": <str>}}]}`` shape is not
        present (wrong container types, missing keys, ``null`` values, or
        non-string content), so callers can treat every malformed body as
        an empty model response instead of crashing.
        """
        if not isinstance(data, dict):
            return ""
        choices = data.get("choices")
        if not isinstance(choices, list) or not choices:
            return ""
        first = choices[0]
        if not isinstance(first, dict):
            return ""
        message = first.get("message")
        if not isinstance(message, dict):
            return ""
        content = message.get("content")
        return content if isinstance(content, str) else ""

    # ------------------------------------------------------------------
    # LLMClient protocol
    # ------------------------------------------------------------------
    async def call_llm(
        self,
        prompts: dict[str, str],
        json_schema: dict[str, object],
        document_text: str = "",
    ) -> ExtractionAttempt:
        """Send a chat completion request to the vLLM server.

        Builds an OpenAI-compatible payload, posts to
        ``/v1/chat/completions``, and parses the response through the
        same markdown-fence / JSON-repair pipeline used by OllamaClient.

        Args:
            prompts: Must contain the ``"system"`` and ``"user"`` prompt
                strings.
            json_schema: Accepted for protocol compatibility; it is not
                sent to the server (the request only asks for
                ``{"type": "json_object"}`` output).
            document_text: Forwarded to ``validate_extraction``.

        Returns:
            An ``ExtractionAttempt``. Failures are reported through
            ``attempt.error`` rather than raised: ``"timeout"``,
            ``"http_<status>"``, ``"connection_error: ..."``,
            ``"invalid_response_json"``, ``"empty_model_response"``, or the
            joined validation errors.

        Raises:
            KeyError: If *prompts* lacks ``"system"`` or ``"user"``.
        """
        attempt = ExtractionAttempt(model=self._config.model)
        started = time.monotonic()

        headers: dict[str, str] = {}
        if self._config.api_key:
            headers["Authorization"] = f"Bearer {self._config.api_key}"

        payload: dict[str, object] = {
            "model": self._config.model,
            "messages": [
                {"role": "system", "content": prompts["system"]},
                {"role": "user", "content": prompts["user"]},
            ],
            "max_tokens": self._config.max_tokens,
            "temperature": self._config.temperature,
            "response_format": {"type": "json_object"},
        }

        url = self._chat_completions_url(self._config.base_url)
        logger.info(
            "vLLM POST %s model=%s input_chars=%d",
            url,
            self._config.model,
            len(prompts.get("user", "")),
        )

        transport_error: str | None = None
        try:
            resp = await self._http.post(url, json=payload, headers=headers)
            resp.raise_for_status()
        except httpx.TimeoutException:
            # Must precede HTTPError: TimeoutException is a subclass of it.
            transport_error = "timeout"
        except httpx.HTTPStatusError as exc:
            transport_error = f"http_{exc.response.status_code}"
        except httpx.HTTPError as exc:
            transport_error = f"connection_error: {exc}"

        # Measured once, covering both the success and the failure paths.
        attempt.duration_ms = int((time.monotonic() - started) * 1000)

        if transport_error is not None:
            attempt.error = transport_error
            # Classify every transport failure, not only HTTP status errors,
            # so ``attempt.retryable`` always agrees with ``_is_retryable``.
            attempt.retryable = _is_retryable(transport_error)
            return attempt

        # --- Parse the OpenAI-compatible response ---
        try:
            data = resp.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            attempt.error = "invalid_response_json"
            attempt.raw_output = resp.text[:2000]
            return attempt

        content = self._extract_content(data)
        if not content:
            attempt.error = "empty_model_response"
            return attempt
        attempt.raw_output = content

        # Strip markdown fences, then repair malformed JSON, before validating.
        cleaned = _repair_json(_strip_markdown_fences(content))

        # Validate against extraction schema
        attempt.validation = validate_extraction(cleaned, document_text=document_text)
        if not attempt.validation.valid:
            attempt.error = "; ".join(attempt.validation.errors)
        return attempt

    async def close(self) -> None:
        """Release the underlying ``httpx.AsyncClient`` if we own it."""
        if self._owns_client:
            await self._http.aclose()
# ------------------------------------------------------------------
# Standalone health check
# ------------------------------------------------------------------
async def check_vllm_health(base_url: str, timeout: float = 10.0) -> bool:
    """Verify the vLLM server is reachable by querying ``/v1/models``.

    Args:
        base_url: Root URL of the vLLM server. Trailing slashes are
            ignored so ``http://host:8000/`` and ``http://host:8000``
            probe the same endpoint.
        timeout: Timeout in seconds applied to the probe request.

    Returns:
        ``True`` when the request completes and ``raise_for_status()``
        accepts the response; ``False`` on any error (connection failure,
        timeout, error status, invalid URL). Logs INFO on success and
        WARNING on failure. Never raises.

    Requirements: 7.1-7.4
    """
    url = f"{base_url.rstrip('/')}/v1/models"
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
            resp = await client.get(url)
            resp.raise_for_status()
    except Exception as exc:
        # Deliberately broad: a health probe must report failure rather than
        # propagate, whatever went wrong. The cause is logged below.
        logger.warning("vLLM health check failed for %s: %s", url, exc)
        return False
    logger.info("vLLM health check passed: %s", url)
    return True