Files
stonks-oracle/services/extractor/vllm_client.py
T

253 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""vLLM client for OpenAI-compatible chat completions.
Sends structured extraction requests to a remote vLLM server via the
``/v1/chat/completions`` endpoint. Reuses the same markdown-fence
stripping, JSON repair, and error-string conventions as OllamaClient
so that ``_is_retryable()`` works without modification.
Requirements: 2.12.10, 7.17.4
"""
from __future__ import annotations
import asyncio
import logging
import time
import httpx
from services.extractor.client import (
ExtractionAttempt,
ExtractionResponse,
_compute_backoff,
_is_retryable,
_repair_json,
_strip_markdown_fences,
)
from services.extractor.prompts import (
build_extraction_prompt,
get_json_schema,
get_prompt_metadata,
)
from services.extractor.schemas import validate_extraction
from services.shared.config import VLLMConfig
# Module-level logger. It is named explicitly (not ``__name__``), so handlers
# and filters must target the "vllm_client" logger name.
logger: logging.Logger = logging.getLogger(name="vllm_client")
class VLLMClient:
    """Async client for vLLM OpenAI-compatible chat completions.

    Satisfies the ``LLMClient`` protocol defined in
    ``services.shared.llm_protocol``. Instances may also be used as an async
    context manager (``async with VLLMClient(cfg) as client: ...``), which
    calls :meth:`close` on exit.
    """

    _config: VLLMConfig
    _http: httpx.AsyncClient
    _owns_client: bool

    def __init__(
        self,
        config: VLLMConfig,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a client bound to ``config``.

        Args:
            config: vLLM connection and generation settings (base URL, model,
                API key, sampling limits, timeout and retry policy).
            http_client: Optional shared ``httpx.AsyncClient``. When omitted,
                a private client using ``config.timeout`` is created and is
                closed by :meth:`close`; a caller-supplied client is never
                closed by this instance.
        """
        self._config = config
        self._owns_client = http_client is None
        if http_client is None:
            http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(config.timeout, read=config.timeout),
            )
        self._http = http_client

    async def __aenter__(self) -> VLLMClient:
        """Return ``self`` so the client can be used with ``async with``."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close the client on leaving an ``async with`` block."""
        await self.close()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _endpoint(self, path: str) -> str:
        """Join ``path`` (which starts with ``/``) onto the configured base URL.

        A trailing ``/`` on ``base_url`` is stripped so the result never
        contains ``//`` between the host part and ``path``.
        """
        # str() so this also works if base_url is a URL object rather than str.
        return f"{str(self._config.base_url).rstrip('/')}{path}"

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since ``start`` (a monotonic time)."""
        return int((time.monotonic() - start) * 1000)

    def _fail_transport(
        self,
        attempt: ExtractionAttempt,
        error: str,
        start: float,
    ) -> ExtractionAttempt:
        """Record a request-level failure on ``attempt`` and return it.

        Sets ``error``, ``retryable`` (via ``_is_retryable``) and
        ``duration_ms`` so every transport failure is recorded the same way.
        """
        attempt.error = error
        attempt.retryable = _is_retryable(error)
        attempt.duration_ms = self._elapsed_ms(start)
        return attempt

    @staticmethod
    def _first_choice_content(data: object) -> str | None:
        """Extract ``choices[0].message.content`` from a chat-completion body.

        Returns:
            ``None`` when the body has no usable ``choices`` list (the body is
            not a JSON object, or ``choices`` is missing, empty or not a list).
            Otherwise returns the first choice's content string, or ``""``
            when the choice, its ``message``, or its ``content`` is absent,
            ``null`` or not of the expected type.
        """
        if not isinstance(data, dict):
            return None
        choices = data.get("choices")
        if not isinstance(choices, list) or not choices:
            return None
        first = choices[0]
        if not isinstance(first, dict):
            return ""
        message = first.get("message")
        if not isinstance(message, dict):
            return ""
        content = message.get("content")
        return content if isinstance(content, str) else ""

    # ------------------------------------------------------------------
    # LLMClient protocol
    # ------------------------------------------------------------------

    async def call_llm(
        self,
        prompts: dict[str, str],
        json_schema: dict[str, object],
        document_text: str = "",
    ) -> ExtractionAttempt:
        """Send one chat-completion request to the vLLM server.

        Builds an OpenAI-compatible payload, posts it to
        ``/v1/chat/completions``, and runs the reply through the same
        markdown-fence / JSON-repair / schema-validation pipeline used by
        OllamaClient. This method never raises for network or parse failures;
        they are reported via ``attempt.error``.

        Args:
            prompts: Must contain ``"system"`` and ``"user"`` messages.
            json_schema: Accepted for ``LLMClient`` protocol compatibility.
                It is not sent to the server; JSON output is requested via
                ``response_format={"type": "json_object"}`` instead.
            document_text: Source document, passed to ``validate_extraction``.

        Returns:
            The ``ExtractionAttempt`` with ``raw_output``, ``duration_ms``,
            ``validation`` and/or ``error`` populated.
        """
        attempt = ExtractionAttempt(model=self._config.model)
        start = time.monotonic()

        headers: dict[str, str] = {}
        if self._config.api_key:
            headers["Authorization"] = f"Bearer {self._config.api_key}"

        payload: dict[str, object] = {
            "model": self._config.model,
            "messages": [
                {"role": "system", "content": prompts["system"]},
                {"role": "user", "content": prompts["user"]},
            ],
            "max_tokens": self._config.max_tokens,
            "temperature": self._config.temperature,
            "response_format": {"type": "json_object"},
        }
        url = self._endpoint("/v1/chat/completions")
        logger.info(
            "vLLM POST %s model=%s input_chars=%d",
            url,
            self._config.model,
            len(prompts["user"]),
        )

        try:
            resp = await self._http.post(url, json=payload, headers=headers)
            resp.raise_for_status()
        # TimeoutException and HTTPStatusError both subclass HTTPError, so
        # they must be handled before the generic HTTPError branch.
        except httpx.TimeoutException:
            return self._fail_transport(attempt, "timeout", start)
        except httpx.HTTPStatusError as exc:
            return self._fail_transport(
                attempt, f"http_{exc.response.status_code}", start
            )
        except httpx.HTTPError as exc:
            return self._fail_transport(attempt, f"connection_error: {exc}", start)
        attempt.duration_ms = self._elapsed_ms(start)

        # --- Parse the OpenAI-compatible response ---
        try:
            data = resp.json()
        except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
            attempt.error = "invalid_response_json"
            attempt.raw_output = resp.text[:2000]
            return attempt

        content = self._first_choice_content(data)
        if content is None:
            attempt.error = "empty_model_response"
            return attempt
        attempt.raw_output = content
        if not content:
            attempt.error = "empty_model_response"
            return attempt

        # Strip markdown fences, then repair malformed JSON before validating.
        content = _repair_json(_strip_markdown_fences(content))

        attempt.validation = validate_extraction(content, document_text=document_text)
        if not attempt.validation.valid:
            attempt.error = "; ".join(attempt.validation.errors)
        return attempt

    async def extract(
        self,
        document_text: str,
        document_type: str = "article",
        document_id: str = "",
        known_tickers: list[str] | None = None,
    ) -> ExtractionResponse:
        """Send a document to vLLM for structured intelligence extraction.

        Mirrors ``OllamaClient.extract()``. Makes up to ``max_retries + 1``
        attempts, sleeping with exponential backoff (``_compute_backoff``)
        between them, and stops early on success or on an error that
        ``_is_retryable`` rejects. Every attempt is kept on the response for
        audit.

        Returns:
            An ``ExtractionResponse`` whose ``success``/``result`` are set
            from the first valid attempt, if any.
        """
        prompts = build_extraction_prompt(
            document_text=document_text,
            document_type=document_type,
            document_id=document_id,
            known_tickers=known_tickers,
        )
        json_schema = get_json_schema()
        response = ExtractionResponse(
            prompt_metadata=get_prompt_metadata(),
            model=self._config.model,
        )
        max_retries = self._config.max_retries
        doc_label = document_id or "unknown"
        total_start = time.monotonic()

        for attempt_num in range(max_retries + 1):
            attempt = await self.call_llm(prompts, json_schema, document_text)
            response.attempts.append(attempt)

            if attempt.error is None and attempt.validation and attempt.validation.valid:
                response.success = True
                response.result = attempt.validation.parsed
                break

            if not _is_retryable(attempt.error):
                attempt.retryable = False
                logger.warning(
                    "Non-retryable error for doc %s: %s — stopping retries",
                    doc_label,
                    attempt.error,
                )
                break

            if attempt_num < max_retries:
                delay = _compute_backoff(
                    attempt_num,
                    self._config.retry_base_delay,
                    self._config.retry_max_delay,
                    self._config.retry_backoff_multiplier,
                )
                logger.warning(
                    "Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
                    attempt_num + 1,
                    max_retries + 1,
                    doc_label,
                    attempt.error or "validation failed",
                    delay,
                )
                await asyncio.sleep(delay)

        response.total_duration_ms = int((time.monotonic() - total_start) * 1000)
        return response

    async def close(self) -> None:
        """Close the underlying ``httpx.AsyncClient`` if this instance created it.

        A client supplied via ``http_client`` is left open for its owner.
        """
        if self._owns_client:
            await self._http.aclose()
# ------------------------------------------------------------------
# Standalone health check
# ------------------------------------------------------------------


async def check_vllm_health(base_url: str, timeout: float = 10.0) -> bool:
    """Verify the vLLM server is reachable by querying ``/v1/models``.

    Args:
        base_url: Server root URL, e.g. ``http://host:8000``. A trailing
            ``/`` is tolerated and does not produce a ``//`` in the path.
        timeout: Request timeout in seconds for a short-lived client.

    Returns:
        ``True`` when ``GET /v1/models`` completes with a successful (2xx)
        status, ``False`` on any failure (connection error, timeout, error
        status, ...). Logs INFO on success and WARNING on failure.

    Requirements: 7.17.4
    """
    url = f"{base_url.rstrip('/')}/v1/models"
    try:
        async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
            resp = await client.get(url)
            resp.raise_for_status()
    # Deliberately broad: this is a best-effort probe, and any failure means
    # "not healthy" rather than an exception for the caller.
    except Exception as exc:
        logger.warning("vLLM health check failed for %s: %s", url, exc)
        return False
    else:
        logger.info("vLLM health check passed: %s", url)
        return True