117b693b19
- LLMClient Protocol for provider-agnostic inference
- VLLMClient for the OpenAI-compatible /v1/chat/completions API
- LLM client factory with provider routing (ollama/vllm)
- VLLMConfig with VLLM_* environment variable loading
- Updated extractor worker with health check and provider switching
- Updated event classifier to use the LLMClient protocol
- Helm values for vLLM configuration
- 18 unit tests + 6 property-based tests
- Full backward compatibility preserved
136 lines
4.8 KiB
Python
136 lines
4.8 KiB
Python
"""LLM client factory for provider-based routing.
|
|
|
|
Returns the appropriate LLM client (OllamaClient or VLLMClient) based on
|
|
the resolved ``model_provider`` from the agent config. Falls back to
|
|
OllamaClient for unknown or missing providers.
|
|
|
|
Requirements: 3.4, 3.5, 3.6, 9.5
|
|
Design: LLM Client Factory
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
import httpx
|
|
|
|
from services.extractor.client import OllamaClient
|
|
from services.extractor.vllm_client import VLLMClient
|
|
from services.shared.agent_config import ResolvedAgentConfig
|
|
from services.shared.config import OllamaConfig, VLLMConfig
|
|
from services.shared.llm_protocol import LLMClient
|
|
|
|
logger = logging.getLogger(__name__)

# Normalized ``model_provider`` values that are served by OllamaClient.
# The factory functions below normalize the provider to a (possibly empty)
# lowercase string before testing membership, so ``""`` covers a missing
# provider; ``None`` is kept in the set for callers that test the raw value.
_OLLAMA_PROVIDERS = frozenset(("ollama", "", None))
|
|
|
|
|
|
def build_config_from_resolved(
    resolved: ResolvedAgentConfig,
    base_ollama: OllamaConfig,
    base_vllm: VLLMConfig,
) -> OllamaConfig | VLLMConfig:
    """Create the provider-specific config for a single agent.

    The agent's resolved overrides come from *resolved*:

    - ``model_name``, ``timeout_seconds``, ``max_retries`` and ``max_tokens``
      are used for both providers.
    - ``context_window`` is used for Ollama only.

    The remaining values come from the selected provider's base config:

    - ``base_url`` and the retry/backoff settings, for both providers.
    - ``temperature`` and ``api_key`` for vLLM.
    - ``stall_timeout``, ``loop_window`` and ``loop_threshold`` for Ollama.

    ``model_provider`` is stripped and lowercased before comparison.
    ``"vllm"`` selects :class:`VLLMConfig`. Every other value, including a
    missing one, selects :class:`OllamaConfig`. Values other than
    ``"ollama"`` or empty also emit a warning.

    Args:
        resolved: Runtime config resolved from the database.
        base_ollama: Base OllamaConfig loaded from environment variables.
        base_vllm: Base VLLMConfig loaded from environment variables.

    Returns:
        A ``VLLMConfig`` when the provider is vLLM, otherwise an
        ``OllamaConfig``.
    """
    normalized = (resolved.model_provider or "").strip().lower()
    wants_vllm = normalized == "vllm"

    # Unrecognised providers are still served by Ollama, but surface them so
    # a misconfigured agent is visible in the logs.
    if not wants_vllm and normalized not in _OLLAMA_PROVIDERS:
        logger.warning(
            "Unknown model_provider %r for agent %s — treating as ollama",
            resolved.model_provider,
            resolved.agent_id,
        )

    env = base_vllm if wants_vllm else base_ollama

    # Keyword arguments passed to both config classes.
    shared_kwargs = {
        "base_url": env.base_url,
        "model": resolved.model_name,
        "timeout": resolved.timeout_seconds,
        "max_retries": resolved.max_retries,
        "retry_base_delay": env.retry_base_delay,
        "retry_max_delay": env.retry_max_delay,
        "retry_backoff_multiplier": env.retry_backoff_multiplier,
        "max_tokens": resolved.max_tokens,
    }

    if wants_vllm:
        return VLLMConfig(
            **shared_kwargs,
            temperature=base_vllm.temperature,
            api_key=base_vllm.api_key,
        )

    return OllamaConfig(
        **shared_kwargs,
        stall_timeout=base_ollama.stall_timeout,
        loop_window=base_ollama.loop_window,
        loop_threshold=base_ollama.loop_threshold,
        context_window=resolved.context_window,
    )
|
|
|
|
|
|
def build_llm_client(
    resolved: ResolvedAgentConfig | None,
    ollama_config: OllamaConfig,
    vllm_config: VLLMConfig,
    http_client: httpx.AsyncClient | None = None,
) -> LLMClient:
    """Return the appropriate LLM client based on the resolved provider.

    Provider routing (decided by :func:`build_config_from_resolved`):

    - ``None`` / ``""`` / ``"ollama"`` → :class:`OllamaClient`
    - ``"vllm"`` → :class:`VLLMClient`
    - Unknown value → a warning is logged once and the factory falls back
      to :class:`OllamaClient`.

    When *resolved* is ``None`` (e.g. the DB lookup failed), the base
    ``ollama_config`` is used unchanged.

    Args:
        resolved: Resolved agent config (may be ``None``).
        ollama_config: Base OllamaConfig from environment.
        vllm_config: Base VLLMConfig from environment.
        http_client: Optional shared httpx client, e.g. for testing; passed
            through to whichever client is built.

    Returns:
        An LLM client satisfying the :class:`LLMClient` protocol.
    """
    if resolved is None:
        logger.info("No resolved agent config — defaulting to OllamaClient")
        return OllamaClient(ollama_config, http_client=http_client)

    # build_config_from_resolved is the single place that interprets
    # ``model_provider``, and it already warns about unknown values.
    # Route on the type of config it returns instead of re-parsing the
    # provider here. Re-parsing previously caused two warnings per call for
    # an unknown provider, and needed ``type: ignore`` because the union was
    # never narrowed.
    cfg = build_config_from_resolved(resolved, ollama_config, vllm_config)

    if isinstance(cfg, VLLMConfig):
        logger.info(
            "Building VLLMClient for agent %s (model=%s)",
            resolved.agent_id,
            cfg.model,
        )
        return VLLMClient(cfg, http_client=http_client)

    return OllamaClient(cfg, http_client=http_client)
|