Files
Celes Renata 117b693b19 feat: add remote vLLM support with provider abstraction layer
- LLMClient Protocol for provider-agnostic inference
- VLLMClient for OpenAI-compatible /v1/chat/completions API
- LLM client factory with provider routing (ollama/vllm)
- VLLMConfig with VLLM_* environment variable loading
- Updated extractor worker with health check and provider switching
- Updated event classifier to use LLMClient protocol
- Helm values for vLLM configuration
- 18 unit tests + 6 property-based tests
- Full backward compatibility preserved
2026-04-23 08:17:23 +00:00

136 lines
4.8 KiB
Python

"""LLM client factory for provider-based routing.
Returns the appropriate LLM client (OllamaClient or VLLMClient) based on
the resolved ``model_provider`` from the agent config. Falls back to
OllamaClient for unknown or missing providers.
Requirements: 3.4, 3.5, 3.6, 9.5
Design: LLM Client Factory
"""
from __future__ import annotations
import logging
import httpx
from services.extractor.client import OllamaClient
from services.extractor.vllm_client import VLLMClient
from services.shared.agent_config import ResolvedAgentConfig
from services.shared.config import OllamaConfig, VLLMConfig
from services.shared.llm_protocol import LLMClient
logger = logging.getLogger(__name__)
# Provider identifiers that route to OllamaClient.
# The factory functions in this module normalise ``model_provider`` with
# ``(value or "").strip().lower()`` before the membership test, so a missing
# provider reaches this set as "". ``None`` is also a member, which lets a
# raw, un-normalised ``None`` match as well.
_OLLAMA_PROVIDERS = frozenset({"ollama", "", None})
def build_config_from_resolved(
    resolved: ResolvedAgentConfig,
    base_ollama: OllamaConfig,
    base_vllm: VLLMConfig,
) -> OllamaConfig | VLLMConfig:
    """Combine agent-level overrides with a base config for one provider.

    The provider is taken from ``resolved.model_provider`` after stripping
    whitespace and lower-casing it (a missing value is treated as ``""``).

    * ``"vllm"`` yields a ``VLLMConfig``: model name, timeout, retry count
      and max tokens come from *resolved*; base URL, retry back-off
      settings, temperature and API key come from *base_vllm*.
    * Anything else yields an ``OllamaConfig``: model name, timeout, retry
      count, max tokens and context window come from *resolved*; base URL,
      retry back-off settings, stall timeout and loop-detection settings
      come from *base_ollama*. A warning is logged first when the provider
      is not one of the recognised Ollama identifiers.

    Args:
        resolved: Runtime config resolved for the agent.
        base_ollama: Base OllamaConfig (environment-level defaults).
        base_vllm: Base VLLMConfig (environment-level defaults).

    Returns:
        An ``OllamaConfig`` or a ``VLLMConfig``, depending on the provider.
    """
    provider_key = (resolved.model_provider or "").strip().lower()

    if provider_key != "vllm":
        # Default path: "ollama", "", None and any unrecognised provider.
        recognised = provider_key in _OLLAMA_PROVIDERS
        if not recognised:
            logger.warning(
                "Unknown model_provider %r for agent %s — treating as ollama",
                resolved.model_provider,
                resolved.agent_id,
            )
        return OllamaConfig(
            base_url=base_ollama.base_url,
            model=resolved.model_name,
            timeout=resolved.timeout_seconds,
            max_retries=resolved.max_retries,
            retry_base_delay=base_ollama.retry_base_delay,
            retry_max_delay=base_ollama.retry_max_delay,
            retry_backoff_multiplier=base_ollama.retry_backoff_multiplier,
            max_tokens=resolved.max_tokens,
            stall_timeout=base_ollama.stall_timeout,
            loop_window=base_ollama.loop_window,
            loop_threshold=base_ollama.loop_threshold,
            context_window=resolved.context_window,
        )

    # Explicit vLLM provider: no context_window is passed on this path.
    return VLLMConfig(
        base_url=base_vllm.base_url,
        model=resolved.model_name,
        timeout=resolved.timeout_seconds,
        max_retries=resolved.max_retries,
        retry_base_delay=base_vllm.retry_base_delay,
        retry_max_delay=base_vllm.retry_max_delay,
        retry_backoff_multiplier=base_vllm.retry_backoff_multiplier,
        max_tokens=resolved.max_tokens,
        temperature=base_vllm.temperature,
        api_key=base_vllm.api_key,
    )
def build_llm_client(
    resolved: ResolvedAgentConfig | None,
    ollama_config: OllamaConfig,
    vllm_config: VLLMConfig,
    http_client: httpx.AsyncClient | None = None,
) -> LLMClient:
    """Construct the LLM client matching the agent's resolved provider.

    Routing, applied to ``resolved.model_provider`` after stripping
    whitespace and lower-casing:

    - ``None`` / ``""`` / ``"ollama"`` → :class:`OllamaClient`
    - ``"vllm"`` → :class:`VLLMClient`
    - any other value → a warning is logged and :class:`OllamaClient` is used

    If *resolved* itself is ``None`` (DB lookup failed), an
    :class:`OllamaClient` is built straight from *ollama_config* with no
    agent-level overrides applied.

    Args:
        resolved: Resolved agent config, or ``None`` when unavailable.
        ollama_config: Base OllamaConfig from the environment.
        vllm_config: Base VLLMConfig from the environment.
        http_client: Optional shared httpx client, forwarded unchanged to
            the client constructor.

    Returns:
        A client object satisfying the :class:`LLMClient` protocol.
    """
    if resolved is None:
        logger.info("No resolved agent config — defaulting to OllamaClient")
        return OllamaClient(ollama_config, http_client=http_client)

    provider_key = (resolved.model_provider or "").strip().lower()
    wants_vllm = provider_key == "vllm"

    if not wants_vllm and provider_key not in _OLLAMA_PROVIDERS:
        # build_config_from_resolved (called below) logs its own warning
        # for the same condition; both messages are emitted.
        logger.warning(
            "Unknown model_provider %r for agent %s — falling back to OllamaClient",
            resolved.model_provider,
            resolved.agent_id,
        )

    # One call covers both providers: it returns a VLLMConfig for "vllm"
    # and an OllamaConfig otherwise.
    cfg = build_config_from_resolved(resolved, ollama_config, vllm_config)

    if wants_vllm:
        logger.info(
            "Building VLLMClient for agent %s (model=%s)",
            resolved.agent_id,
            cfg.model,  # type: ignore[union-attr]
        )
        return VLLMClient(cfg, http_client=http_client)  # type: ignore[arg-type]

    return OllamaClient(cfg, http_client=http_client)  # type: ignore[arg-type]