feat: add remote vLLM support with provider abstraction layer

- LLMClient Protocol for provider-agnostic inference
- VLLMClient for OpenAI-compatible /v1/chat/completions API
- LLM client factory with provider routing (ollama/vllm)
- VLLMConfig with VLLM_* environment variable loading
- Updated extractor worker with health check and provider switching
- Updated event classifier to use LLMClient protocol
- Helm values for vLLM configuration
- 18 unit tests + 6 property-based tests
- Full backward compatibility preserved
This commit is contained in:
Celes Renata
2026-04-23 08:17:23 +00:00
parent 63e4fb96ea
commit 117b693b19
15 changed files with 1876 additions and 77 deletions
+13
View File
@@ -155,6 +155,19 @@ class OllamaClient:
if self._owns_client:
await self._http.aclose()
async def call_llm(
self,
prompts: dict[str, str],
json_schema: dict[str, object],
document_text: str = "",
) -> ExtractionAttempt:
"""Public LLM client interface — delegates to _call_ollama().
Satisfies the LLMClient protocol so OllamaClient can be used
interchangeably with VLLMClient.
"""
return await self._call_ollama(prompts, json_schema, document_text)
async def extract(
self,
document_text: str,
+27 -20
View File
@@ -1,13 +1,14 @@
"""Event classifier module for macro news articles.
Classifies global/geopolitical news articles into structured GlobalEvent
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
existing OllamaClient for inference and retry logic.
objects using an LLM client (Ollama or vLLM) with a dedicated prompt and
JSON schema. Uses the LLMClient protocol for provider-agnostic inference
and retry logic.
Persists classification prompts, raw outputs, and final events to MinIO
and PostgreSQL for audit and downstream interpolation.
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
Requirements: 1.4, 2.1, 2.2, 2.3, 2.4, 2.5, 6.2
"""
from __future__ import annotations
@@ -24,6 +25,8 @@ import asyncpg
from minio import Minio
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
from services.shared.config import VLLMConfig
from services.shared.llm_protocol import LLMClient
from services.shared.schemas import (
EstimatedDuration,
ImpactType,
@@ -281,6 +284,7 @@ def _parse_classification_response(
raw_json: str,
document_id: str,
model_name: str,
provider: str = "ollama",
) -> GlobalEvent:
"""Parse raw Ollama JSON output into a GlobalEvent.
@@ -345,7 +349,7 @@ def _parse_classification_response(
confidence=confidence,
source_document_id=document_id,
model_metadata=ModelMetadata(
provider="ollama",
provider=provider,
model_name=model_name,
prompt_version=PROMPT_VERSION,
schema_version=SCHEMA_VERSION,
@@ -479,21 +483,21 @@ async def persist_global_event(
async def classify_global_event(
normalized_text: str,
document_id: str,
ollama_client: Any,
client: LLMClient,
*,
pool: asyncpg.Pool | None = None,
minio_client: Minio | None = None,
) -> GlobalEvent:
"""Classify a macro news article into a GlobalEvent using Ollama.
"""Classify a macro news article into a GlobalEvent using an LLM.
Uses the existing OllamaClient's streaming infrastructure with a
dedicated event classification prompt and JSON schema. Follows the
same retry policy as document extraction.
Uses the LLMClient protocol's call_llm() method with a dedicated
event classification prompt and JSON schema. Follows the same retry
policy as document extraction.
Resolves runtime config for the "event-classifier" agent slug from
the database, preferring an active variant's model_name and
system_prompt if one exists. Falls back to the OllamaClient's
existing config if resolution fails.
system_prompt if one exists. Falls back to the client's existing
config if resolution fails.
Persists prompt, raw output, and final event to MinIO and PostgreSQL
when the respective clients are provided.
@@ -501,7 +505,7 @@ async def classify_global_event(
Args:
normalized_text: Cleaned text content of the macro article.
document_id: UUID of the source document.
ollama_client: An OllamaClient instance (from services.extractor.client).
client: An LLMClient instance (OllamaClient or VLLMClient).
pool: Optional asyncpg pool for PostgreSQL persistence.
minio_client: Optional MinIO client for artifact persistence.
@@ -528,7 +532,10 @@ async def classify_global_event(
prompts = build_event_classification_prompt(normalized_text)
json_schema = get_event_json_schema()
model_name = ollama_client._config.model
model_name = client._config.model
# Detect provider from client config type
provider = "vllm" if isinstance(client._config, VLLMConfig) else "ollama"
# Override model_name and system_prompt from resolved config
if resolved is not None:
@@ -562,16 +569,16 @@ async def classify_global_event(
except Exception:
logger.exception("Failed to upload classification prompt for doc %s", document_id)
# Call Ollama using the client's internal _call_ollama method
# Call LLM using the client's call_llm method
# We reuse the retry logic pattern from OllamaClient.extract()
max_retries = ollama_client._max_retries
max_retries = client._config.max_retries
if resolved is not None:
max_retries = resolved.max_retries
last_error: str | None = None
raw_output = ""
for attempt_num in range(max_retries + 1):
attempt = await ollama_client._call_ollama(prompts, json_schema)
attempt = await client.call_llm(prompts, json_schema)
raw_output = attempt.raw_output
# _call_ollama validates against the *extraction* schema, which
@@ -581,7 +588,7 @@ async def classify_global_event(
# Try to parse the response
try:
event = _parse_classification_response(
raw_output, document_id, model_name,
raw_output, document_id, model_name, provider=provider,
)
# Persist result to MinIO
@@ -648,10 +655,10 @@ async def classify_global_event(
# Retry with backoff
if attempt_num < max_retries:
delay = ollama_client._base_delay * (
ollama_client._backoff_multiplier ** attempt_num
delay = client._config.retry_base_delay * (
client._config.retry_backoff_multiplier ** attempt_num
)
delay = min(delay, ollama_client._max_delay)
delay = min(delay, client._config.retry_max_delay)
logger.warning(
"Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
attempt_num + 1, max_retries + 1, document_id, last_error, delay,
+135
View File
@@ -0,0 +1,135 @@
"""LLM client factory for provider-based routing.
Returns the appropriate LLM client (OllamaClient or VLLMClient) based on
the resolved ``model_provider`` from the agent config. Falls back to
OllamaClient for unknown or missing providers.
Requirements: 3.4, 3.5, 3.6, 9.5
Design: LLM Client Factory
"""
from __future__ import annotations
import logging
import httpx
from services.extractor.client import OllamaClient
from services.extractor.vllm_client import VLLMClient
from services.shared.agent_config import ResolvedAgentConfig
from services.shared.config import OllamaConfig, VLLMConfig
from services.shared.llm_protocol import LLMClient
logger = logging.getLogger(__name__)
# Providers that map to OllamaClient (including empty / None).
_OLLAMA_PROVIDERS = frozenset({"ollama", "", None})
def build_config_from_resolved(
resolved: ResolvedAgentConfig,
base_ollama: OllamaConfig,
base_vllm: VLLMConfig,
) -> OllamaConfig | VLLMConfig:
"""Build a provider-specific config from a resolved agent config.
Merges the resolved agent-level overrides (model_name, timeout, retries,
max_tokens, context_window) with the base environment config (base_url,
retry delays, provider-specific defaults).
Args:
resolved: Runtime config resolved from the database.
base_ollama: Base OllamaConfig loaded from environment variables.
base_vllm: Base VLLMConfig loaded from environment variables.
Returns:
An ``OllamaConfig`` or ``VLLMConfig`` depending on the provider.
"""
provider = (resolved.model_provider or "").strip().lower()
if provider == "vllm":
return VLLMConfig(
base_url=base_vllm.base_url,
model=resolved.model_name,
timeout=resolved.timeout_seconds,
max_retries=resolved.max_retries,
retry_base_delay=base_vllm.retry_base_delay,
retry_max_delay=base_vllm.retry_max_delay,
retry_backoff_multiplier=base_vllm.retry_backoff_multiplier,
max_tokens=resolved.max_tokens,
temperature=base_vllm.temperature,
api_key=base_vllm.api_key,
)
# Default: Ollama config (covers "ollama", "", None, and unknown)
if provider not in _OLLAMA_PROVIDERS:
logger.warning(
"Unknown model_provider %r for agent %s — treating as ollama",
resolved.model_provider,
resolved.agent_id,
)
return OllamaConfig(
base_url=base_ollama.base_url,
model=resolved.model_name,
timeout=resolved.timeout_seconds,
max_retries=resolved.max_retries,
retry_base_delay=base_ollama.retry_base_delay,
retry_max_delay=base_ollama.retry_max_delay,
retry_backoff_multiplier=base_ollama.retry_backoff_multiplier,
max_tokens=resolved.max_tokens,
stall_timeout=base_ollama.stall_timeout,
loop_window=base_ollama.loop_window,
loop_threshold=base_ollama.loop_threshold,
context_window=resolved.context_window,
)
def build_llm_client(
resolved: ResolvedAgentConfig | None,
ollama_config: OllamaConfig,
vllm_config: VLLMConfig,
http_client: httpx.AsyncClient | None = None,
) -> LLMClient:
"""Return the appropriate LLM client based on the resolved provider.
Provider routing:
- ``None`` / ``""`` / ``"ollama"`` → :class:`OllamaClient`
- ``"vllm"`` → :class:`VLLMClient`
- Unknown value → log warning, fall back to :class:`OllamaClient`
When *resolved* is ``None`` (DB lookup failed), the base
``ollama_config`` is used directly.
Args:
resolved: Resolved agent config (may be ``None``).
ollama_config: Base OllamaConfig from environment.
vllm_config: Base VLLMConfig from environment.
http_client: Optional shared httpx client for testing.
Returns:
An LLM client satisfying the :class:`LLMClient` protocol.
"""
if resolved is None:
logger.info("No resolved agent config — defaulting to OllamaClient")
return OllamaClient(ollama_config, http_client=http_client)
provider = (resolved.model_provider or "").strip().lower()
if provider == "vllm":
cfg = build_config_from_resolved(resolved, ollama_config, vllm_config)
logger.info(
"Building VLLMClient for agent %s (model=%s)",
resolved.agent_id,
cfg.model, # type: ignore[union-attr]
)
return VLLMClient(cfg, http_client=http_client) # type: ignore[arg-type]
if provider not in _OLLAMA_PROVIDERS:
logger.warning(
"Unknown model_provider %r for agent %s — falling back to OllamaClient",
resolved.model_provider,
resolved.agent_id,
)
cfg = build_config_from_resolved(resolved, ollama_config, vllm_config)
return OllamaClient(cfg, http_client=http_client) # type: ignore[arg-type]
+106 -46
View File
@@ -15,11 +15,13 @@ from services.aggregation.interpolation import (
filter_low_confidence_events,
persist_macro_impact_records,
)
from services.extractor.client import OllamaClient
from services.extractor.event_classifier import classify_global_event
from services.extractor.llm_factory import build_config_from_resolved, build_llm_client
from services.extractor.vllm_client import check_vllm_health
from services.extractor.worker import persist_extraction
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
from services.shared.config import OllamaConfig, load_config
from services.shared.llm_protocol import LLMClient
from services.shared.logging import inject_trace_context, setup_logging
from services.shared.redis_keys import (
QUEUE_AGGREGATION,
@@ -31,11 +33,22 @@ from services.shared.redis_keys import (
logger = logging.getLogger("extractor_main")
def _get_provider(resolved: ResolvedAgentConfig | None) -> str:
"""Return the normalised provider string for a resolved config."""
if resolved is None:
return "ollama"
return (resolved.model_provider or "").strip().lower() or "ollama"
def _build_ollama_config_from_resolved(
resolved: ResolvedAgentConfig,
base_config: OllamaConfig,
) -> OllamaConfig:
"""Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings."""
"""Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings.
Kept for backward compatibility — the factory's ``build_config_from_resolved``
is now the primary path.
"""
return OllamaConfig(
base_url=base_config.base_url,
model=resolved.model_name,
@@ -239,7 +252,7 @@ async def _process_macro_classification(
*,
pool: asyncpg.Pool,
minio_client: Minio,
ollama: OllamaClient,
ollama: LLMClient,
redis_client: aioredis.Redis,
document_id: str,
text: str,
@@ -258,7 +271,7 @@ async def _process_macro_classification(
event = await classify_global_event(
normalized_text=text,
document_id=document_id,
ollama_client=ollama,
client=ollama,
pool=pool,
minio_client=minio_client,
)
@@ -329,48 +342,69 @@ async def main() -> None:
# Resolve extractor config from DB (active variant override + TTL cache)
resolver = AgentConfigResolver(pool, ttl_seconds=60)
resolved_config: ResolvedAgentConfig | None = None
extractor_ollama_config = config.ollama
extractor_provider = "ollama"
try:
resolved_config = await resolver.resolve("document-extractor")
if resolved_config is not None:
extractor_ollama_config = _build_ollama_config_from_resolved(
resolved_config, config.ollama,
)
extractor_provider = _get_provider(resolved_config)
logger.info(
"Extractor using resolved config: model=%s variant=%s",
resolved_config.model_name, resolved_config.variant_id,
"Extractor using resolved config: model=%s variant=%s provider=%s",
resolved_config.model_name, resolved_config.variant_id, extractor_provider,
)
else:
logger.info("No DB config for document-extractor — using env defaults")
except Exception:
logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True)
ollama = OllamaClient(extractor_ollama_config)
# vLLM health check at startup when provider is vllm (Requirement 7.17.3)
if extractor_provider == "vllm":
healthy = await check_vllm_health(config.vllm.base_url)
if not healthy:
logger.warning(
"vLLM health check failed at startup — falling back to Ollama for extractor",
)
extractor_provider = "ollama"
# Override resolved config provider so factory builds OllamaClient
resolved_config = None
extractor_client: LLMClient = build_llm_client(
resolved_config, config.ollama, config.vllm,
)
# Resolve event classifier config separately (may use different model)
classifier_resolved: ResolvedAgentConfig | None = None
classifier_ollama_config = config.ollama
classifier_provider = "ollama"
try:
classifier_resolved = await resolver.resolve("event-classifier")
if classifier_resolved is not None:
classifier_ollama_config = _build_ollama_config_from_resolved(
classifier_resolved, config.ollama,
)
classifier_provider = _get_provider(classifier_resolved)
logger.info(
"Event classifier using resolved config: model=%s variant=%s",
classifier_resolved.model_name, classifier_resolved.variant_id,
"Event classifier using resolved config: model=%s variant=%s provider=%s",
classifier_resolved.model_name, classifier_resolved.variant_id, classifier_provider,
)
else:
logger.info("No DB config for event-classifier — using extractor config")
except Exception:
logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True)
# Use a separate OllamaClient for the classifier if it has a different model
classifier_ollama: OllamaClient
if classifier_ollama_config.model != extractor_ollama_config.model:
classifier_ollama = OllamaClient(classifier_ollama_config)
# vLLM health check for classifier if it uses vllm and extractor didn't already check
if classifier_provider == "vllm" and extractor_provider != "vllm":
healthy = await check_vllm_health(config.vllm.base_url)
if not healthy:
logger.warning(
"vLLM health check failed at startup — falling back to Ollama for classifier",
)
classifier_provider = "ollama"
classifier_resolved = None
# Build classifier client — share with extractor when configs match
classifier_client: LLMClient
if classifier_resolved is not None or classifier_provider != extractor_provider:
classifier_client = build_llm_client(
classifier_resolved, config.ollama, config.vllm,
)
else:
classifier_ollama = ollama
classifier_client = extractor_client
redis_client = aioredis.from_url(config.redis.url)
queue = queue_key(QUEUE_EXTRACTION)
@@ -441,40 +475,66 @@ async def main() -> None:
company_id_map = await _build_company_id_map(pool)
# Re-resolve extractor config (picks up active variant swaps)
try:
resolved_config = await resolver.resolve("document-extractor")
if resolved_config is not None:
new_ollama_cfg = _build_ollama_config_from_resolved(
resolved_config, config.ollama,
new_resolved = await resolver.resolve("document-extractor")
if new_resolved is not None:
new_provider = _get_provider(new_resolved)
new_cfg = build_config_from_resolved(
new_resolved, config.ollama, config.vllm,
)
if new_ollama_cfg.model != ollama._config.model:
old_provider = extractor_provider
provider_changed = new_provider != extractor_provider
model_changed = new_cfg.model != extractor_client._config.model
if provider_changed or model_changed:
logger.info(
"Extractor config changed: model=%s variant=%s",
resolved_config.model_name, resolved_config.variant_id,
"Extractor provider switch: old_provider=%s new_provider=%s "
"model=%s variant=%s",
old_provider, new_provider,
new_resolved.model_name, new_resolved.variant_id,
)
await ollama.close()
ollama = OllamaClient(new_ollama_cfg)
await extractor_client.close()
extractor_client = build_llm_client(
new_resolved, config.ollama, config.vllm,
)
extractor_provider = new_provider
else:
ollama._config = new_ollama_cfg
# Same provider and model — just update config in-place
extractor_client._config = new_cfg # type: ignore[assignment]
resolved_config = new_resolved
except Exception:
logger.warning("Failed to refresh extractor config", exc_info=True)
# Re-resolve event classifier config
try:
classifier_resolved = await resolver.resolve("event-classifier")
if classifier_resolved is not None:
new_cls_cfg = _build_ollama_config_from_resolved(
classifier_resolved, config.ollama,
new_cls_resolved = await resolver.resolve("event-classifier")
if new_cls_resolved is not None:
new_cls_provider = _get_provider(new_cls_resolved)
new_cls_cfg = build_config_from_resolved(
new_cls_resolved, config.ollama, config.vllm,
)
if new_cls_cfg.model != classifier_ollama._config.model:
old_cls_provider = classifier_provider
cls_provider_changed = new_cls_provider != classifier_provider
cls_model_changed = new_cls_cfg.model != classifier_client._config.model
if cls_provider_changed or cls_model_changed:
logger.info(
"Event classifier config changed: model=%s variant=%s",
classifier_resolved.model_name, classifier_resolved.variant_id,
"Classifier provider switch: old_provider=%s new_provider=%s "
"model=%s variant=%s",
old_cls_provider, new_cls_provider,
new_cls_resolved.model_name, new_cls_resolved.variant_id,
)
if classifier_ollama is not ollama:
await classifier_ollama.close()
classifier_ollama = OllamaClient(new_cls_cfg)
elif classifier_ollama is ollama and new_cls_cfg.model != ollama._config.model:
classifier_ollama = OllamaClient(new_cls_cfg)
if classifier_client is not extractor_client:
await classifier_client.close()
classifier_client = build_llm_client(
new_cls_resolved, config.ollama, config.vllm,
)
classifier_provider = new_cls_provider
elif classifier_client is extractor_client and new_cls_cfg.model != extractor_client._config.model:
classifier_client = build_llm_client(
new_cls_resolved, config.ollama, config.vllm,
)
classifier_provider = new_cls_provider
classifier_resolved = new_cls_resolved
except Exception:
logger.warning("Failed to refresh event-classifier config", exc_info=True)
@@ -490,7 +550,7 @@ async def main() -> None:
await _process_macro_classification(
pool=pool,
minio_client=minio_client,
ollama=classifier_ollama,
ollama=classifier_client,
redis_client=redis_client,
document_id=document_id,
text=text,
@@ -529,7 +589,7 @@ async def main() -> None:
# Pass all tracked tickers so the model can identify any mentioned companies
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
extraction_response = await ollama.extract(
extraction_response = await extractor_client.extract(
extraction_text,
document_id=document_id,
known_tickers=all_tickers,
+177
View File
@@ -0,0 +1,177 @@
"""vLLM client for OpenAI-compatible chat completions.
Sends structured extraction requests to a remote vLLM server via the
``/v1/chat/completions`` endpoint. Reuses the same markdown-fence
stripping, JSON repair, and error-string conventions as OllamaClient
so that ``_is_retryable()`` works without modification.
Requirements: 2.12.10, 7.17.4
"""
from __future__ import annotations
import logging
import time
import httpx
from services.extractor.client import (
ExtractionAttempt,
_is_retryable,
_repair_json,
_strip_markdown_fences,
)
from services.extractor.schemas import validate_extraction
from services.shared.config import VLLMConfig
logger = logging.getLogger("vllm_client")
class VLLMClient:
"""Async client for vLLM OpenAI-compatible chat completions.
Satisfies the ``LLMClient`` protocol defined in
``services.shared.llm_protocol``.
"""
_config: VLLMConfig
_http: httpx.AsyncClient
_owns_client: bool
def __init__(
self,
config: VLLMConfig,
http_client: httpx.AsyncClient | None = None,
) -> None:
self._config = config
self._owns_client = http_client is None
self._http = http_client or httpx.AsyncClient(
timeout=httpx.Timeout(config.timeout, read=config.timeout),
)
# ------------------------------------------------------------------
# LLMClient protocol
# ------------------------------------------------------------------
async def call_llm(
self,
prompts: dict[str, str],
json_schema: dict[str, object],
document_text: str = "",
) -> ExtractionAttempt:
"""Send a chat completion request to the vLLM server.
Builds an OpenAI-compatible payload, posts to
``/v1/chat/completions``, and parses the response through the
same markdown-fence / JSON-repair pipeline used by OllamaClient.
"""
attempt = ExtractionAttempt(model=self._config.model)
start = time.monotonic()
headers: dict[str, str] = {}
if self._config.api_key:
headers["Authorization"] = f"Bearer {self._config.api_key}"
payload: dict[str, object] = {
"model": self._config.model,
"messages": [
{"role": "system", "content": prompts["system"]},
{"role": "user", "content": prompts["user"]},
],
"max_tokens": self._config.max_tokens,
"temperature": self._config.temperature,
"response_format": {"type": "json_object"},
}
url = f"{self._config.base_url}/v1/chat/completions"
logger.info(
"vLLM POST %s model=%s input_chars=%d",
url,
self._config.model,
len(prompts.get("user", "")),
)
try:
resp = await self._http.post(url, json=payload, headers=headers)
resp.raise_for_status()
except httpx.TimeoutException:
attempt.error = "timeout"
attempt.duration_ms = int((time.monotonic() - start) * 1000)
return attempt
except httpx.HTTPStatusError as exc:
attempt.error = f"http_{exc.response.status_code}"
attempt.retryable = _is_retryable(attempt.error)
attempt.duration_ms = int((time.monotonic() - start) * 1000)
return attempt
except httpx.HTTPError as exc:
attempt.error = f"connection_error: {exc}"
attempt.duration_ms = int((time.monotonic() - start) * 1000)
return attempt
attempt.duration_ms = int((time.monotonic() - start) * 1000)
# --- Parse the OpenAI-compatible response ---
try:
data = resp.json()
except Exception:
attempt.error = "invalid_response_json"
attempt.raw_output = resp.text[:2000]
return attempt
choices = data.get("choices") or []
if not choices:
attempt.error = "empty_model_response"
return attempt
content = (
choices[0].get("message", {}).get("content", "")
if isinstance(choices[0], dict)
else ""
)
attempt.raw_output = content
if not content:
attempt.error = "empty_model_response"
return attempt
# Strip markdown fences if present
content = _strip_markdown_fences(content)
# Repair malformed JSON
content = _repair_json(content)
# Validate against extraction schema
attempt.validation = validate_extraction(content, document_text=document_text)
if not attempt.validation.valid:
attempt.error = "; ".join(attempt.validation.errors)
return attempt
async def close(self) -> None:
"""Release the underlying ``httpx.AsyncClient`` if we own it."""
if self._owns_client:
await self._http.aclose()
# ------------------------------------------------------------------
# Standalone health check
# ------------------------------------------------------------------
async def check_vllm_health(base_url: str, timeout: float = 10.0) -> bool:
"""Verify the vLLM server is reachable by querying ``/v1/models``.
Returns ``True`` when the server responds with HTTP 200, ``False``
otherwise. Logs INFO on success and WARNING on failure.
Requirements: 7.17.4
"""
url = f"{base_url}/v1/models"
try:
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
resp = await client.get(url)
resp.raise_for_status()
logger.info("vLLM health check passed: %s", url)
return True
except Exception as exc:
logger.warning("vLLM health check failed for %s: %s", url, exc)
return False
+31
View File
@@ -54,6 +54,24 @@ class OllamaConfig:
context_window: int = 0 # Ollama num_ctx; 0 = use model default
@dataclass
class VLLMConfig:
"""Configuration for the remote vLLM inference server.
Requirements: 3.1, 3.2
"""
base_url: str = "http://192.168.42.254:8000"
model: str = "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
timeout: int = 120
max_retries: int = 2
retry_base_delay: float = 1.0
retry_max_delay: float = 10.0
retry_backoff_multiplier: float = 2.0
max_tokens: int = 32768
temperature: float = 0.7
api_key: str = "" # Optional, for authenticated vLLM deployments
@dataclass
class TrinoConfig:
host: str = "localhost"
@@ -217,6 +235,7 @@ class AppConfig:
redis: RedisConfig = field(default_factory=RedisConfig)
minio: MinioConfig = field(default_factory=MinioConfig)
ollama: OllamaConfig = field(default_factory=OllamaConfig)
vllm: VLLMConfig = field(default_factory=VLLMConfig)
trino: TrinoConfig = field(default_factory=TrinoConfig)
market_data: MarketDataConfig = field(default_factory=MarketDataConfig)
broker: BrokerConfig = field(default_factory=BrokerConfig)
@@ -260,6 +279,18 @@ def load_config() -> AppConfig:
retry_max_delay=float(os.getenv("OLLAMA_RETRY_MAX_DELAY", "10.0")),
retry_backoff_multiplier=float(os.getenv("OLLAMA_RETRY_BACKOFF_MULTIPLIER", "2.0")),
),
vllm=VLLMConfig(
base_url=os.getenv("VLLM_BASE_URL", "http://192.168.42.254:8000"),
model=os.getenv("VLLM_MODEL", "RedHatAI/Qwen3.6-35B-A3B-NVFP4"),
timeout=int(os.getenv("VLLM_TIMEOUT", "120")),
max_retries=int(os.getenv("VLLM_MAX_RETRIES", "2")),
retry_base_delay=float(os.getenv("VLLM_RETRY_BASE_DELAY", "1.0")),
retry_max_delay=float(os.getenv("VLLM_RETRY_MAX_DELAY", "10.0")),
retry_backoff_multiplier=float(os.getenv("VLLM_RETRY_BACKOFF_MULTIPLIER", "2.0")),
max_tokens=int(os.getenv("VLLM_MAX_TOKENS", "32768")),
temperature=float(os.getenv("VLLM_TEMPERATURE", "0.7")),
api_key=os.getenv("VLLM_API_KEY", ""),
),
trino=TrinoConfig(
host=os.getenv("TRINO_HOST", "localhost"),
port=int(os.getenv("TRINO_PORT", "8080")),
+44
View File
@@ -0,0 +1,44 @@
"""LLM client protocol for provider abstraction.
Defines the structural interface that both OllamaClient and VLLMClient
must satisfy, using typing.Protocol for duck-typing compatibility.
Requirements: 1.1, 1.2
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol, runtime_checkable
if TYPE_CHECKING:
from services.extractor.client import ExtractionAttempt
@runtime_checkable
class LLMClient(Protocol):
"""Protocol defining the contract for LLM inference clients.
Both OllamaClient and VLLMClient satisfy this protocol via
structural subtyping — no inheritance required.
"""
async def call_llm(
self,
prompts: dict[str, str],
json_schema: dict[str, object],
document_text: str = "",
) -> ExtractionAttempt:
"""Send a chat completion request and return an extraction attempt.
Args:
prompts: Dict with 'system' and 'user' prompt strings.
json_schema: JSON schema hint for structured output.
document_text: Optional raw document text for context.
Returns:
An ExtractionAttempt with raw output, validation, and error info.
"""
...
async def close(self) -> None:
"""Release underlying HTTP resources."""
...