feat: add remote vLLM support with provider abstraction layer
- LLMClient Protocol for provider-agnostic inference - VLLMClient for OpenAI-compatible /v1/chat/completions API - LLM client factory with provider routing (ollama/vllm) - VLLMConfig with VLLM_* environment variable loading - Updated extractor worker with health check and provider switching - Updated event classifier to use LLMClient protocol - Helm values for vLLM configuration - 18 unit tests + 6 property-based tests - Full backward compatibility preserved
This commit is contained in:
@@ -155,6 +155,19 @@ class OllamaClient:
|
||||
if self._owns_client:
|
||||
await self._http.aclose()
|
||||
|
||||
async def call_llm(
|
||||
self,
|
||||
prompts: dict[str, str],
|
||||
json_schema: dict[str, object],
|
||||
document_text: str = "",
|
||||
) -> ExtractionAttempt:
|
||||
"""Public LLM client interface — delegates to _call_ollama().
|
||||
|
||||
Satisfies the LLMClient protocol so OllamaClient can be used
|
||||
interchangeably with VLLMClient.
|
||||
"""
|
||||
return await self._call_ollama(prompts, json_schema, document_text)
|
||||
|
||||
async def extract(
|
||||
self,
|
||||
document_text: str,
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
"""Event classifier module for macro news articles.
|
||||
|
||||
Classifies global/geopolitical news articles into structured GlobalEvent
|
||||
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
|
||||
existing OllamaClient for inference and retry logic.
|
||||
objects using an LLM client (Ollama or vLLM) with a dedicated prompt and
|
||||
JSON schema. Uses the LLMClient protocol for provider-agnostic inference
|
||||
and retry logic.
|
||||
|
||||
Persists classification prompts, raw outputs, and final events to MinIO
|
||||
and PostgreSQL for audit and downstream interpolation.
|
||||
|
||||
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
|
||||
Requirements: 1.4, 2.1, 2.2, 2.3, 2.4, 2.5, 6.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -24,6 +25,8 @@ import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
||||
from services.shared.config import VLLMConfig
|
||||
from services.shared.llm_protocol import LLMClient
|
||||
from services.shared.schemas import (
|
||||
EstimatedDuration,
|
||||
ImpactType,
|
||||
@@ -281,6 +284,7 @@ def _parse_classification_response(
|
||||
raw_json: str,
|
||||
document_id: str,
|
||||
model_name: str,
|
||||
provider: str = "ollama",
|
||||
) -> GlobalEvent:
|
||||
"""Parse raw Ollama JSON output into a GlobalEvent.
|
||||
|
||||
@@ -345,7 +349,7 @@ def _parse_classification_response(
|
||||
confidence=confidence,
|
||||
source_document_id=document_id,
|
||||
model_metadata=ModelMetadata(
|
||||
provider="ollama",
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
prompt_version=PROMPT_VERSION,
|
||||
schema_version=SCHEMA_VERSION,
|
||||
@@ -479,21 +483,21 @@ async def persist_global_event(
|
||||
async def classify_global_event(
|
||||
normalized_text: str,
|
||||
document_id: str,
|
||||
ollama_client: Any,
|
||||
client: LLMClient,
|
||||
*,
|
||||
pool: asyncpg.Pool | None = None,
|
||||
minio_client: Minio | None = None,
|
||||
) -> GlobalEvent:
|
||||
"""Classify a macro news article into a GlobalEvent using Ollama.
|
||||
"""Classify a macro news article into a GlobalEvent using an LLM.
|
||||
|
||||
Uses the existing OllamaClient's streaming infrastructure with a
|
||||
dedicated event classification prompt and JSON schema. Follows the
|
||||
same retry policy as document extraction.
|
||||
Uses the LLMClient protocol's call_llm() method with a dedicated
|
||||
event classification prompt and JSON schema. Follows the same retry
|
||||
policy as document extraction.
|
||||
|
||||
Resolves runtime config for the "event-classifier" agent slug from
|
||||
the database, preferring an active variant's model_name and
|
||||
system_prompt if one exists. Falls back to the OllamaClient's
|
||||
existing config if resolution fails.
|
||||
system_prompt if one exists. Falls back to the client's existing
|
||||
config if resolution fails.
|
||||
|
||||
Persists prompt, raw output, and final event to MinIO and PostgreSQL
|
||||
when the respective clients are provided.
|
||||
@@ -501,7 +505,7 @@ async def classify_global_event(
|
||||
Args:
|
||||
normalized_text: Cleaned text content of the macro article.
|
||||
document_id: UUID of the source document.
|
||||
ollama_client: An OllamaClient instance (from services.extractor.client).
|
||||
client: An LLMClient instance (OllamaClient or VLLMClient).
|
||||
pool: Optional asyncpg pool for PostgreSQL persistence.
|
||||
minio_client: Optional MinIO client for artifact persistence.
|
||||
|
||||
@@ -528,7 +532,10 @@ async def classify_global_event(
|
||||
|
||||
prompts = build_event_classification_prompt(normalized_text)
|
||||
json_schema = get_event_json_schema()
|
||||
model_name = ollama_client._config.model
|
||||
model_name = client._config.model
|
||||
|
||||
# Detect provider from client config type
|
||||
provider = "vllm" if isinstance(client._config, VLLMConfig) else "ollama"
|
||||
|
||||
# Override model_name and system_prompt from resolved config
|
||||
if resolved is not None:
|
||||
@@ -562,16 +569,16 @@ async def classify_global_event(
|
||||
except Exception:
|
||||
logger.exception("Failed to upload classification prompt for doc %s", document_id)
|
||||
|
||||
# Call Ollama using the client's internal _call_ollama method
|
||||
# Call LLM using the client's call_llm method
|
||||
# We reuse the retry logic pattern from OllamaClient.extract()
|
||||
max_retries = ollama_client._max_retries
|
||||
max_retries = client._config.max_retries
|
||||
if resolved is not None:
|
||||
max_retries = resolved.max_retries
|
||||
last_error: str | None = None
|
||||
raw_output = ""
|
||||
|
||||
for attempt_num in range(max_retries + 1):
|
||||
attempt = await ollama_client._call_ollama(prompts, json_schema)
|
||||
attempt = await client.call_llm(prompts, json_schema)
|
||||
raw_output = attempt.raw_output
|
||||
|
||||
# _call_ollama validates against the *extraction* schema, which
|
||||
@@ -581,7 +588,7 @@ async def classify_global_event(
|
||||
# Try to parse the response
|
||||
try:
|
||||
event = _parse_classification_response(
|
||||
raw_output, document_id, model_name,
|
||||
raw_output, document_id, model_name, provider=provider,
|
||||
)
|
||||
|
||||
# Persist result to MinIO
|
||||
@@ -648,10 +655,10 @@ async def classify_global_event(
|
||||
|
||||
# Retry with backoff
|
||||
if attempt_num < max_retries:
|
||||
delay = ollama_client._base_delay * (
|
||||
ollama_client._backoff_multiplier ** attempt_num
|
||||
delay = client._config.retry_base_delay * (
|
||||
client._config.retry_backoff_multiplier ** attempt_num
|
||||
)
|
||||
delay = min(delay, ollama_client._max_delay)
|
||||
delay = min(delay, client._config.retry_max_delay)
|
||||
logger.warning(
|
||||
"Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
|
||||
attempt_num + 1, max_retries + 1, document_id, last_error, delay,
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
"""LLM client factory for provider-based routing.
|
||||
|
||||
Returns the appropriate LLM client (OllamaClient or VLLMClient) based on
|
||||
the resolved ``model_provider`` from the agent config. Falls back to
|
||||
OllamaClient for unknown or missing providers.
|
||||
|
||||
Requirements: 3.4, 3.5, 3.6, 9.5
|
||||
Design: LLM Client Factory
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
from services.extractor.client import OllamaClient
|
||||
from services.extractor.vllm_client import VLLMClient
|
||||
from services.shared.agent_config import ResolvedAgentConfig
|
||||
from services.shared.config import OllamaConfig, VLLMConfig
|
||||
from services.shared.llm_protocol import LLMClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Providers that map to OllamaClient (including empty / None).
|
||||
_OLLAMA_PROVIDERS = frozenset({"ollama", "", None})
|
||||
|
||||
|
||||
def build_config_from_resolved(
|
||||
resolved: ResolvedAgentConfig,
|
||||
base_ollama: OllamaConfig,
|
||||
base_vllm: VLLMConfig,
|
||||
) -> OllamaConfig | VLLMConfig:
|
||||
"""Build a provider-specific config from a resolved agent config.
|
||||
|
||||
Merges the resolved agent-level overrides (model_name, timeout, retries,
|
||||
max_tokens, context_window) with the base environment config (base_url,
|
||||
retry delays, provider-specific defaults).
|
||||
|
||||
Args:
|
||||
resolved: Runtime config resolved from the database.
|
||||
base_ollama: Base OllamaConfig loaded from environment variables.
|
||||
base_vllm: Base VLLMConfig loaded from environment variables.
|
||||
|
||||
Returns:
|
||||
An ``OllamaConfig`` or ``VLLMConfig`` depending on the provider.
|
||||
"""
|
||||
provider = (resolved.model_provider or "").strip().lower()
|
||||
|
||||
if provider == "vllm":
|
||||
return VLLMConfig(
|
||||
base_url=base_vllm.base_url,
|
||||
model=resolved.model_name,
|
||||
timeout=resolved.timeout_seconds,
|
||||
max_retries=resolved.max_retries,
|
||||
retry_base_delay=base_vllm.retry_base_delay,
|
||||
retry_max_delay=base_vllm.retry_max_delay,
|
||||
retry_backoff_multiplier=base_vllm.retry_backoff_multiplier,
|
||||
max_tokens=resolved.max_tokens,
|
||||
temperature=base_vllm.temperature,
|
||||
api_key=base_vllm.api_key,
|
||||
)
|
||||
|
||||
# Default: Ollama config (covers "ollama", "", None, and unknown)
|
||||
if provider not in _OLLAMA_PROVIDERS:
|
||||
logger.warning(
|
||||
"Unknown model_provider %r for agent %s — treating as ollama",
|
||||
resolved.model_provider,
|
||||
resolved.agent_id,
|
||||
)
|
||||
|
||||
return OllamaConfig(
|
||||
base_url=base_ollama.base_url,
|
||||
model=resolved.model_name,
|
||||
timeout=resolved.timeout_seconds,
|
||||
max_retries=resolved.max_retries,
|
||||
retry_base_delay=base_ollama.retry_base_delay,
|
||||
retry_max_delay=base_ollama.retry_max_delay,
|
||||
retry_backoff_multiplier=base_ollama.retry_backoff_multiplier,
|
||||
max_tokens=resolved.max_tokens,
|
||||
stall_timeout=base_ollama.stall_timeout,
|
||||
loop_window=base_ollama.loop_window,
|
||||
loop_threshold=base_ollama.loop_threshold,
|
||||
context_window=resolved.context_window,
|
||||
)
|
||||
|
||||
|
||||
def build_llm_client(
|
||||
resolved: ResolvedAgentConfig | None,
|
||||
ollama_config: OllamaConfig,
|
||||
vllm_config: VLLMConfig,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
) -> LLMClient:
|
||||
"""Return the appropriate LLM client based on the resolved provider.
|
||||
|
||||
Provider routing:
|
||||
- ``None`` / ``""`` / ``"ollama"`` → :class:`OllamaClient`
|
||||
- ``"vllm"`` → :class:`VLLMClient`
|
||||
- Unknown value → log warning, fall back to :class:`OllamaClient`
|
||||
|
||||
When *resolved* is ``None`` (DB lookup failed), the base
|
||||
``ollama_config`` is used directly.
|
||||
|
||||
Args:
|
||||
resolved: Resolved agent config (may be ``None``).
|
||||
ollama_config: Base OllamaConfig from environment.
|
||||
vllm_config: Base VLLMConfig from environment.
|
||||
http_client: Optional shared httpx client for testing.
|
||||
|
||||
Returns:
|
||||
An LLM client satisfying the :class:`LLMClient` protocol.
|
||||
"""
|
||||
if resolved is None:
|
||||
logger.info("No resolved agent config — defaulting to OllamaClient")
|
||||
return OllamaClient(ollama_config, http_client=http_client)
|
||||
|
||||
provider = (resolved.model_provider or "").strip().lower()
|
||||
|
||||
if provider == "vllm":
|
||||
cfg = build_config_from_resolved(resolved, ollama_config, vllm_config)
|
||||
logger.info(
|
||||
"Building VLLMClient for agent %s (model=%s)",
|
||||
resolved.agent_id,
|
||||
cfg.model, # type: ignore[union-attr]
|
||||
)
|
||||
return VLLMClient(cfg, http_client=http_client) # type: ignore[arg-type]
|
||||
|
||||
if provider not in _OLLAMA_PROVIDERS:
|
||||
logger.warning(
|
||||
"Unknown model_provider %r for agent %s — falling back to OllamaClient",
|
||||
resolved.model_provider,
|
||||
resolved.agent_id,
|
||||
)
|
||||
|
||||
cfg = build_config_from_resolved(resolved, ollama_config, vllm_config)
|
||||
return OllamaClient(cfg, http_client=http_client) # type: ignore[arg-type]
|
||||
+106
-46
@@ -15,11 +15,13 @@ from services.aggregation.interpolation import (
|
||||
filter_low_confidence_events,
|
||||
persist_macro_impact_records,
|
||||
)
|
||||
from services.extractor.client import OllamaClient
|
||||
from services.extractor.event_classifier import classify_global_event
|
||||
from services.extractor.llm_factory import build_config_from_resolved, build_llm_client
|
||||
from services.extractor.vllm_client import check_vllm_health
|
||||
from services.extractor.worker import persist_extraction
|
||||
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
||||
from services.shared.config import OllamaConfig, load_config
|
||||
from services.shared.llm_protocol import LLMClient
|
||||
from services.shared.logging import inject_trace_context, setup_logging
|
||||
from services.shared.redis_keys import (
|
||||
QUEUE_AGGREGATION,
|
||||
@@ -31,11 +33,22 @@ from services.shared.redis_keys import (
|
||||
logger = logging.getLogger("extractor_main")
|
||||
|
||||
|
||||
def _get_provider(resolved: ResolvedAgentConfig | None) -> str:
|
||||
"""Return the normalised provider string for a resolved config."""
|
||||
if resolved is None:
|
||||
return "ollama"
|
||||
return (resolved.model_provider or "").strip().lower() or "ollama"
|
||||
|
||||
|
||||
def _build_ollama_config_from_resolved(
|
||||
resolved: ResolvedAgentConfig,
|
||||
base_config: OllamaConfig,
|
||||
) -> OllamaConfig:
|
||||
"""Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings."""
|
||||
"""Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings.
|
||||
|
||||
Kept for backward compatibility — the factory's ``build_config_from_resolved``
|
||||
is now the primary path.
|
||||
"""
|
||||
return OllamaConfig(
|
||||
base_url=base_config.base_url,
|
||||
model=resolved.model_name,
|
||||
@@ -239,7 +252,7 @@ async def _process_macro_classification(
|
||||
*,
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
ollama: OllamaClient,
|
||||
ollama: LLMClient,
|
||||
redis_client: aioredis.Redis,
|
||||
document_id: str,
|
||||
text: str,
|
||||
@@ -258,7 +271,7 @@ async def _process_macro_classification(
|
||||
event = await classify_global_event(
|
||||
normalized_text=text,
|
||||
document_id=document_id,
|
||||
ollama_client=ollama,
|
||||
client=ollama,
|
||||
pool=pool,
|
||||
minio_client=minio_client,
|
||||
)
|
||||
@@ -329,48 +342,69 @@ async def main() -> None:
|
||||
# Resolve extractor config from DB (active variant override + TTL cache)
|
||||
resolver = AgentConfigResolver(pool, ttl_seconds=60)
|
||||
resolved_config: ResolvedAgentConfig | None = None
|
||||
extractor_ollama_config = config.ollama
|
||||
extractor_provider = "ollama"
|
||||
try:
|
||||
resolved_config = await resolver.resolve("document-extractor")
|
||||
if resolved_config is not None:
|
||||
extractor_ollama_config = _build_ollama_config_from_resolved(
|
||||
resolved_config, config.ollama,
|
||||
)
|
||||
extractor_provider = _get_provider(resolved_config)
|
||||
logger.info(
|
||||
"Extractor using resolved config: model=%s variant=%s",
|
||||
resolved_config.model_name, resolved_config.variant_id,
|
||||
"Extractor using resolved config: model=%s variant=%s provider=%s",
|
||||
resolved_config.model_name, resolved_config.variant_id, extractor_provider,
|
||||
)
|
||||
else:
|
||||
logger.info("No DB config for document-extractor — using env defaults")
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True)
|
||||
|
||||
ollama = OllamaClient(extractor_ollama_config)
|
||||
# vLLM health check at startup when provider is vllm (Requirement 7.1–7.3)
|
||||
if extractor_provider == "vllm":
|
||||
healthy = await check_vllm_health(config.vllm.base_url)
|
||||
if not healthy:
|
||||
logger.warning(
|
||||
"vLLM health check failed at startup — falling back to Ollama for extractor",
|
||||
)
|
||||
extractor_provider = "ollama"
|
||||
# Override resolved config provider so factory builds OllamaClient
|
||||
resolved_config = None
|
||||
|
||||
extractor_client: LLMClient = build_llm_client(
|
||||
resolved_config, config.ollama, config.vllm,
|
||||
)
|
||||
|
||||
# Resolve event classifier config separately (may use different model)
|
||||
classifier_resolved: ResolvedAgentConfig | None = None
|
||||
classifier_ollama_config = config.ollama
|
||||
classifier_provider = "ollama"
|
||||
try:
|
||||
classifier_resolved = await resolver.resolve("event-classifier")
|
||||
if classifier_resolved is not None:
|
||||
classifier_ollama_config = _build_ollama_config_from_resolved(
|
||||
classifier_resolved, config.ollama,
|
||||
)
|
||||
classifier_provider = _get_provider(classifier_resolved)
|
||||
logger.info(
|
||||
"Event classifier using resolved config: model=%s variant=%s",
|
||||
classifier_resolved.model_name, classifier_resolved.variant_id,
|
||||
"Event classifier using resolved config: model=%s variant=%s provider=%s",
|
||||
classifier_resolved.model_name, classifier_resolved.variant_id, classifier_provider,
|
||||
)
|
||||
else:
|
||||
logger.info("No DB config for event-classifier — using extractor config")
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True)
|
||||
|
||||
# Use a separate OllamaClient for the classifier if it has a different model
|
||||
classifier_ollama: OllamaClient
|
||||
if classifier_ollama_config.model != extractor_ollama_config.model:
|
||||
classifier_ollama = OllamaClient(classifier_ollama_config)
|
||||
# vLLM health check for classifier if it uses vllm and extractor didn't already check
|
||||
if classifier_provider == "vllm" and extractor_provider != "vllm":
|
||||
healthy = await check_vllm_health(config.vllm.base_url)
|
||||
if not healthy:
|
||||
logger.warning(
|
||||
"vLLM health check failed at startup — falling back to Ollama for classifier",
|
||||
)
|
||||
classifier_provider = "ollama"
|
||||
classifier_resolved = None
|
||||
|
||||
# Build classifier client — share with extractor when configs match
|
||||
classifier_client: LLMClient
|
||||
if classifier_resolved is not None or classifier_provider != extractor_provider:
|
||||
classifier_client = build_llm_client(
|
||||
classifier_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
else:
|
||||
classifier_ollama = ollama
|
||||
classifier_client = extractor_client
|
||||
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_EXTRACTION)
|
||||
@@ -441,40 +475,66 @@ async def main() -> None:
|
||||
company_id_map = await _build_company_id_map(pool)
|
||||
# Re-resolve extractor config (picks up active variant swaps)
|
||||
try:
|
||||
resolved_config = await resolver.resolve("document-extractor")
|
||||
if resolved_config is not None:
|
||||
new_ollama_cfg = _build_ollama_config_from_resolved(
|
||||
resolved_config, config.ollama,
|
||||
new_resolved = await resolver.resolve("document-extractor")
|
||||
if new_resolved is not None:
|
||||
new_provider = _get_provider(new_resolved)
|
||||
new_cfg = build_config_from_resolved(
|
||||
new_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
if new_ollama_cfg.model != ollama._config.model:
|
||||
old_provider = extractor_provider
|
||||
provider_changed = new_provider != extractor_provider
|
||||
model_changed = new_cfg.model != extractor_client._config.model
|
||||
|
||||
if provider_changed or model_changed:
|
||||
logger.info(
|
||||
"Extractor config changed: model=%s variant=%s",
|
||||
resolved_config.model_name, resolved_config.variant_id,
|
||||
"Extractor provider switch: old_provider=%s new_provider=%s "
|
||||
"model=%s variant=%s",
|
||||
old_provider, new_provider,
|
||||
new_resolved.model_name, new_resolved.variant_id,
|
||||
)
|
||||
await ollama.close()
|
||||
ollama = OllamaClient(new_ollama_cfg)
|
||||
await extractor_client.close()
|
||||
extractor_client = build_llm_client(
|
||||
new_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
extractor_provider = new_provider
|
||||
else:
|
||||
ollama._config = new_ollama_cfg
|
||||
# Same provider and model — just update config in-place
|
||||
extractor_client._config = new_cfg # type: ignore[assignment]
|
||||
resolved_config = new_resolved
|
||||
except Exception:
|
||||
logger.warning("Failed to refresh extractor config", exc_info=True)
|
||||
|
||||
# Re-resolve event classifier config
|
||||
try:
|
||||
classifier_resolved = await resolver.resolve("event-classifier")
|
||||
if classifier_resolved is not None:
|
||||
new_cls_cfg = _build_ollama_config_from_resolved(
|
||||
classifier_resolved, config.ollama,
|
||||
new_cls_resolved = await resolver.resolve("event-classifier")
|
||||
if new_cls_resolved is not None:
|
||||
new_cls_provider = _get_provider(new_cls_resolved)
|
||||
new_cls_cfg = build_config_from_resolved(
|
||||
new_cls_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
if new_cls_cfg.model != classifier_ollama._config.model:
|
||||
old_cls_provider = classifier_provider
|
||||
cls_provider_changed = new_cls_provider != classifier_provider
|
||||
cls_model_changed = new_cls_cfg.model != classifier_client._config.model
|
||||
|
||||
if cls_provider_changed or cls_model_changed:
|
||||
logger.info(
|
||||
"Event classifier config changed: model=%s variant=%s",
|
||||
classifier_resolved.model_name, classifier_resolved.variant_id,
|
||||
"Classifier provider switch: old_provider=%s new_provider=%s "
|
||||
"model=%s variant=%s",
|
||||
old_cls_provider, new_cls_provider,
|
||||
new_cls_resolved.model_name, new_cls_resolved.variant_id,
|
||||
)
|
||||
if classifier_ollama is not ollama:
|
||||
await classifier_ollama.close()
|
||||
classifier_ollama = OllamaClient(new_cls_cfg)
|
||||
elif classifier_ollama is ollama and new_cls_cfg.model != ollama._config.model:
|
||||
classifier_ollama = OllamaClient(new_cls_cfg)
|
||||
if classifier_client is not extractor_client:
|
||||
await classifier_client.close()
|
||||
classifier_client = build_llm_client(
|
||||
new_cls_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
classifier_provider = new_cls_provider
|
||||
elif classifier_client is extractor_client and new_cls_cfg.model != extractor_client._config.model:
|
||||
classifier_client = build_llm_client(
|
||||
new_cls_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
classifier_provider = new_cls_provider
|
||||
classifier_resolved = new_cls_resolved
|
||||
except Exception:
|
||||
logger.warning("Failed to refresh event-classifier config", exc_info=True)
|
||||
|
||||
@@ -490,7 +550,7 @@ async def main() -> None:
|
||||
await _process_macro_classification(
|
||||
pool=pool,
|
||||
minio_client=minio_client,
|
||||
ollama=classifier_ollama,
|
||||
ollama=classifier_client,
|
||||
redis_client=redis_client,
|
||||
document_id=document_id,
|
||||
text=text,
|
||||
@@ -529,7 +589,7 @@ async def main() -> None:
|
||||
|
||||
# Pass all tracked tickers so the model can identify any mentioned companies
|
||||
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
|
||||
extraction_response = await ollama.extract(
|
||||
extraction_response = await extractor_client.extract(
|
||||
extraction_text,
|
||||
document_id=document_id,
|
||||
known_tickers=all_tickers,
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
"""vLLM client for OpenAI-compatible chat completions.
|
||||
|
||||
Sends structured extraction requests to a remote vLLM server via the
|
||||
``/v1/chat/completions`` endpoint. Reuses the same markdown-fence
|
||||
stripping, JSON repair, and error-string conventions as OllamaClient
|
||||
so that ``_is_retryable()`` works without modification.
|
||||
|
||||
Requirements: 2.1–2.10, 7.1–7.4
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import httpx
|
||||
|
||||
from services.extractor.client import (
|
||||
ExtractionAttempt,
|
||||
_is_retryable,
|
||||
_repair_json,
|
||||
_strip_markdown_fences,
|
||||
)
|
||||
from services.extractor.schemas import validate_extraction
|
||||
from services.shared.config import VLLMConfig
|
||||
|
||||
logger = logging.getLogger("vllm_client")
|
||||
|
||||
|
||||
class VLLMClient:
|
||||
"""Async client for vLLM OpenAI-compatible chat completions.
|
||||
|
||||
Satisfies the ``LLMClient`` protocol defined in
|
||||
``services.shared.llm_protocol``.
|
||||
"""
|
||||
|
||||
_config: VLLMConfig
|
||||
_http: httpx.AsyncClient
|
||||
_owns_client: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VLLMConfig,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self._config = config
|
||||
self._owns_client = http_client is None
|
||||
self._http = http_client or httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(config.timeout, read=config.timeout),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LLMClient protocol
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def call_llm(
|
||||
self,
|
||||
prompts: dict[str, str],
|
||||
json_schema: dict[str, object],
|
||||
document_text: str = "",
|
||||
) -> ExtractionAttempt:
|
||||
"""Send a chat completion request to the vLLM server.
|
||||
|
||||
Builds an OpenAI-compatible payload, posts to
|
||||
``/v1/chat/completions``, and parses the response through the
|
||||
same markdown-fence / JSON-repair pipeline used by OllamaClient.
|
||||
"""
|
||||
attempt = ExtractionAttempt(model=self._config.model)
|
||||
start = time.monotonic()
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
if self._config.api_key:
|
||||
headers["Authorization"] = f"Bearer {self._config.api_key}"
|
||||
|
||||
payload: dict[str, object] = {
|
||||
"model": self._config.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": prompts["system"]},
|
||||
{"role": "user", "content": prompts["user"]},
|
||||
],
|
||||
"max_tokens": self._config.max_tokens,
|
||||
"temperature": self._config.temperature,
|
||||
"response_format": {"type": "json_object"},
|
||||
}
|
||||
|
||||
url = f"{self._config.base_url}/v1/chat/completions"
|
||||
logger.info(
|
||||
"vLLM POST %s model=%s input_chars=%d",
|
||||
url,
|
||||
self._config.model,
|
||||
len(prompts.get("user", "")),
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=payload, headers=headers)
|
||||
resp.raise_for_status()
|
||||
except httpx.TimeoutException:
|
||||
attempt.error = "timeout"
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
except httpx.HTTPStatusError as exc:
|
||||
attempt.error = f"http_{exc.response.status_code}"
|
||||
attempt.retryable = _is_retryable(attempt.error)
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
except httpx.HTTPError as exc:
|
||||
attempt.error = f"connection_error: {exc}"
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
# --- Parse the OpenAI-compatible response ---
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
attempt.error = "invalid_response_json"
|
||||
attempt.raw_output = resp.text[:2000]
|
||||
return attempt
|
||||
|
||||
choices = data.get("choices") or []
|
||||
if not choices:
|
||||
attempt.error = "empty_model_response"
|
||||
return attempt
|
||||
|
||||
content = (
|
||||
choices[0].get("message", {}).get("content", "")
|
||||
if isinstance(choices[0], dict)
|
||||
else ""
|
||||
)
|
||||
attempt.raw_output = content
|
||||
|
||||
if not content:
|
||||
attempt.error = "empty_model_response"
|
||||
return attempt
|
||||
|
||||
# Strip markdown fences if present
|
||||
content = _strip_markdown_fences(content)
|
||||
|
||||
# Repair malformed JSON
|
||||
content = _repair_json(content)
|
||||
|
||||
# Validate against extraction schema
|
||||
attempt.validation = validate_extraction(content, document_text=document_text)
|
||||
if not attempt.validation.valid:
|
||||
attempt.error = "; ".join(attempt.validation.errors)
|
||||
|
||||
return attempt
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Release the underlying ``httpx.AsyncClient`` if we own it."""
|
||||
if self._owns_client:
|
||||
await self._http.aclose()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Standalone health check
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
async def check_vllm_health(base_url: str, timeout: float = 10.0) -> bool:
|
||||
"""Verify the vLLM server is reachable by querying ``/v1/models``.
|
||||
|
||||
Returns ``True`` when the server responds with HTTP 200, ``False``
|
||||
otherwise. Logs INFO on success and WARNING on failure.
|
||||
|
||||
Requirements: 7.1–7.4
|
||||
"""
|
||||
url = f"{base_url}/v1/models"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
|
||||
resp = await client.get(url)
|
||||
resp.raise_for_status()
|
||||
logger.info("vLLM health check passed: %s", url)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("vLLM health check failed for %s: %s", url, exc)
|
||||
return False
|
||||
@@ -54,6 +54,24 @@ class OllamaConfig:
|
||||
context_window: int = 0 # Ollama num_ctx; 0 = use model default
|
||||
|
||||
|
||||
@dataclass
|
||||
class VLLMConfig:
|
||||
"""Configuration for the remote vLLM inference server.
|
||||
|
||||
Requirements: 3.1, 3.2
|
||||
"""
|
||||
base_url: str = "http://192.168.42.254:8000"
|
||||
model: str = "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
|
||||
timeout: int = 120
|
||||
max_retries: int = 2
|
||||
retry_base_delay: float = 1.0
|
||||
retry_max_delay: float = 10.0
|
||||
retry_backoff_multiplier: float = 2.0
|
||||
max_tokens: int = 32768
|
||||
temperature: float = 0.7
|
||||
api_key: str = "" # Optional, for authenticated vLLM deployments
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrinoConfig:
|
||||
host: str = "localhost"
|
||||
@@ -217,6 +235,7 @@ class AppConfig:
|
||||
redis: RedisConfig = field(default_factory=RedisConfig)
|
||||
minio: MinioConfig = field(default_factory=MinioConfig)
|
||||
ollama: OllamaConfig = field(default_factory=OllamaConfig)
|
||||
vllm: VLLMConfig = field(default_factory=VLLMConfig)
|
||||
trino: TrinoConfig = field(default_factory=TrinoConfig)
|
||||
market_data: MarketDataConfig = field(default_factory=MarketDataConfig)
|
||||
broker: BrokerConfig = field(default_factory=BrokerConfig)
|
||||
@@ -260,6 +279,18 @@ def load_config() -> AppConfig:
|
||||
retry_max_delay=float(os.getenv("OLLAMA_RETRY_MAX_DELAY", "10.0")),
|
||||
retry_backoff_multiplier=float(os.getenv("OLLAMA_RETRY_BACKOFF_MULTIPLIER", "2.0")),
|
||||
),
|
||||
vllm=VLLMConfig(
|
||||
base_url=os.getenv("VLLM_BASE_URL", "http://192.168.42.254:8000"),
|
||||
model=os.getenv("VLLM_MODEL", "RedHatAI/Qwen3.6-35B-A3B-NVFP4"),
|
||||
timeout=int(os.getenv("VLLM_TIMEOUT", "120")),
|
||||
max_retries=int(os.getenv("VLLM_MAX_RETRIES", "2")),
|
||||
retry_base_delay=float(os.getenv("VLLM_RETRY_BASE_DELAY", "1.0")),
|
||||
retry_max_delay=float(os.getenv("VLLM_RETRY_MAX_DELAY", "10.0")),
|
||||
retry_backoff_multiplier=float(os.getenv("VLLM_RETRY_BACKOFF_MULTIPLIER", "2.0")),
|
||||
max_tokens=int(os.getenv("VLLM_MAX_TOKENS", "32768")),
|
||||
temperature=float(os.getenv("VLLM_TEMPERATURE", "0.7")),
|
||||
api_key=os.getenv("VLLM_API_KEY", ""),
|
||||
),
|
||||
trino=TrinoConfig(
|
||||
host=os.getenv("TRINO_HOST", "localhost"),
|
||||
port=int(os.getenv("TRINO_PORT", "8080")),
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
"""LLM client protocol for provider abstraction.
|
||||
|
||||
Defines the structural interface that both OllamaClient and VLLMClient
|
||||
must satisfy, using typing.Protocol for duck-typing compatibility.
|
||||
|
||||
Requirements: 1.1, 1.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.extractor.client import ExtractionAttempt
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LLMClient(Protocol):
|
||||
"""Protocol defining the contract for LLM inference clients.
|
||||
|
||||
Both OllamaClient and VLLMClient satisfy this protocol via
|
||||
structural subtyping — no inheritance required.
|
||||
"""
|
||||
|
||||
async def call_llm(
|
||||
self,
|
||||
prompts: dict[str, str],
|
||||
json_schema: dict[str, object],
|
||||
document_text: str = "",
|
||||
) -> ExtractionAttempt:
|
||||
"""Send a chat completion request and return an extraction attempt.
|
||||
|
||||
Args:
|
||||
prompts: Dict with 'system' and 'user' prompt strings.
|
||||
json_schema: JSON schema hint for structured output.
|
||||
document_text: Optional raw document text for context.
|
||||
|
||||
Returns:
|
||||
An ExtractionAttempt with raw output, validation, and error info.
|
||||
"""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Release underlying HTTP resources."""
|
||||
...
|
||||
Reference in New Issue
Block a user