feat: add remote vLLM support with provider abstraction layer
- LLMClient Protocol for provider-agnostic inference - VLLMClient for OpenAI-compatible /v1/chat/completions API - LLM client factory with provider routing (ollama/vllm) - VLLMConfig with VLLM_* environment variable loading - Updated extractor worker with health check and provider switching - Updated event classifier to use LLMClient protocol - Helm values for vLLM configuration - 18 unit tests + 6 property-based tests - Full backward compatibility preserved
This commit is contained in:
+106
-46
@@ -15,11 +15,13 @@ from services.aggregation.interpolation import (
|
||||
filter_low_confidence_events,
|
||||
persist_macro_impact_records,
|
||||
)
|
||||
from services.extractor.client import OllamaClient
|
||||
from services.extractor.event_classifier import classify_global_event
|
||||
from services.extractor.llm_factory import build_config_from_resolved, build_llm_client
|
||||
from services.extractor.vllm_client import check_vllm_health
|
||||
from services.extractor.worker import persist_extraction
|
||||
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
||||
from services.shared.config import OllamaConfig, load_config
|
||||
from services.shared.llm_protocol import LLMClient
|
||||
from services.shared.logging import inject_trace_context, setup_logging
|
||||
from services.shared.redis_keys import (
|
||||
QUEUE_AGGREGATION,
|
||||
@@ -31,11 +33,22 @@ from services.shared.redis_keys import (
|
||||
logger = logging.getLogger("extractor_main")
|
||||
|
||||
|
||||
def _get_provider(resolved: ResolvedAgentConfig | None) -> str:
|
||||
"""Return the normalised provider string for a resolved config."""
|
||||
if resolved is None:
|
||||
return "ollama"
|
||||
return (resolved.model_provider or "").strip().lower() or "ollama"
|
||||
|
||||
|
||||
def _build_ollama_config_from_resolved(
|
||||
resolved: ResolvedAgentConfig,
|
||||
base_config: OllamaConfig,
|
||||
) -> OllamaConfig:
|
||||
"""Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings."""
|
||||
"""Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings.
|
||||
|
||||
Kept for backward compatibility — the factory's ``build_config_from_resolved``
|
||||
is now the primary path.
|
||||
"""
|
||||
return OllamaConfig(
|
||||
base_url=base_config.base_url,
|
||||
model=resolved.model_name,
|
||||
@@ -239,7 +252,7 @@ async def _process_macro_classification(
|
||||
*,
|
||||
pool: asyncpg.Pool,
|
||||
minio_client: Minio,
|
||||
ollama: OllamaClient,
|
||||
ollama: LLMClient,
|
||||
redis_client: aioredis.Redis,
|
||||
document_id: str,
|
||||
text: str,
|
||||
@@ -258,7 +271,7 @@ async def _process_macro_classification(
|
||||
event = await classify_global_event(
|
||||
normalized_text=text,
|
||||
document_id=document_id,
|
||||
ollama_client=ollama,
|
||||
client=ollama,
|
||||
pool=pool,
|
||||
minio_client=minio_client,
|
||||
)
|
||||
@@ -329,48 +342,69 @@ async def main() -> None:
|
||||
# Resolve extractor config from DB (active variant override + TTL cache)
|
||||
resolver = AgentConfigResolver(pool, ttl_seconds=60)
|
||||
resolved_config: ResolvedAgentConfig | None = None
|
||||
extractor_ollama_config = config.ollama
|
||||
extractor_provider = "ollama"
|
||||
try:
|
||||
resolved_config = await resolver.resolve("document-extractor")
|
||||
if resolved_config is not None:
|
||||
extractor_ollama_config = _build_ollama_config_from_resolved(
|
||||
resolved_config, config.ollama,
|
||||
)
|
||||
extractor_provider = _get_provider(resolved_config)
|
||||
logger.info(
|
||||
"Extractor using resolved config: model=%s variant=%s",
|
||||
resolved_config.model_name, resolved_config.variant_id,
|
||||
"Extractor using resolved config: model=%s variant=%s provider=%s",
|
||||
resolved_config.model_name, resolved_config.variant_id, extractor_provider,
|
||||
)
|
||||
else:
|
||||
logger.info("No DB config for document-extractor — using env defaults")
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True)
|
||||
|
||||
ollama = OllamaClient(extractor_ollama_config)
|
||||
# vLLM health check at startup when provider is vllm (Requirement 7.1–7.3)
|
||||
if extractor_provider == "vllm":
|
||||
healthy = await check_vllm_health(config.vllm.base_url)
|
||||
if not healthy:
|
||||
logger.warning(
|
||||
"vLLM health check failed at startup — falling back to Ollama for extractor",
|
||||
)
|
||||
extractor_provider = "ollama"
|
||||
# Override resolved config provider so factory builds OllamaClient
|
||||
resolved_config = None
|
||||
|
||||
extractor_client: LLMClient = build_llm_client(
|
||||
resolved_config, config.ollama, config.vllm,
|
||||
)
|
||||
|
||||
# Resolve event classifier config separately (may use different model)
|
||||
classifier_resolved: ResolvedAgentConfig | None = None
|
||||
classifier_ollama_config = config.ollama
|
||||
classifier_provider = "ollama"
|
||||
try:
|
||||
classifier_resolved = await resolver.resolve("event-classifier")
|
||||
if classifier_resolved is not None:
|
||||
classifier_ollama_config = _build_ollama_config_from_resolved(
|
||||
classifier_resolved, config.ollama,
|
||||
)
|
||||
classifier_provider = _get_provider(classifier_resolved)
|
||||
logger.info(
|
||||
"Event classifier using resolved config: model=%s variant=%s",
|
||||
classifier_resolved.model_name, classifier_resolved.variant_id,
|
||||
"Event classifier using resolved config: model=%s variant=%s provider=%s",
|
||||
classifier_resolved.model_name, classifier_resolved.variant_id, classifier_provider,
|
||||
)
|
||||
else:
|
||||
logger.info("No DB config for event-classifier — using extractor config")
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True)
|
||||
|
||||
# Use a separate OllamaClient for the classifier if it has a different model
|
||||
classifier_ollama: OllamaClient
|
||||
if classifier_ollama_config.model != extractor_ollama_config.model:
|
||||
classifier_ollama = OllamaClient(classifier_ollama_config)
|
||||
# vLLM health check for classifier if it uses vllm and extractor didn't already check
|
||||
if classifier_provider == "vllm" and extractor_provider != "vllm":
|
||||
healthy = await check_vllm_health(config.vllm.base_url)
|
||||
if not healthy:
|
||||
logger.warning(
|
||||
"vLLM health check failed at startup — falling back to Ollama for classifier",
|
||||
)
|
||||
classifier_provider = "ollama"
|
||||
classifier_resolved = None
|
||||
|
||||
# Build classifier client — share with extractor when configs match
|
||||
classifier_client: LLMClient
|
||||
if classifier_resolved is not None or classifier_provider != extractor_provider:
|
||||
classifier_client = build_llm_client(
|
||||
classifier_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
else:
|
||||
classifier_ollama = ollama
|
||||
classifier_client = extractor_client
|
||||
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_EXTRACTION)
|
||||
@@ -441,40 +475,66 @@ async def main() -> None:
|
||||
company_id_map = await _build_company_id_map(pool)
|
||||
# Re-resolve extractor config (picks up active variant swaps)
|
||||
try:
|
||||
resolved_config = await resolver.resolve("document-extractor")
|
||||
if resolved_config is not None:
|
||||
new_ollama_cfg = _build_ollama_config_from_resolved(
|
||||
resolved_config, config.ollama,
|
||||
new_resolved = await resolver.resolve("document-extractor")
|
||||
if new_resolved is not None:
|
||||
new_provider = _get_provider(new_resolved)
|
||||
new_cfg = build_config_from_resolved(
|
||||
new_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
if new_ollama_cfg.model != ollama._config.model:
|
||||
old_provider = extractor_provider
|
||||
provider_changed = new_provider != extractor_provider
|
||||
model_changed = new_cfg.model != extractor_client._config.model
|
||||
|
||||
if provider_changed or model_changed:
|
||||
logger.info(
|
||||
"Extractor config changed: model=%s variant=%s",
|
||||
resolved_config.model_name, resolved_config.variant_id,
|
||||
"Extractor provider switch: old_provider=%s new_provider=%s "
|
||||
"model=%s variant=%s",
|
||||
old_provider, new_provider,
|
||||
new_resolved.model_name, new_resolved.variant_id,
|
||||
)
|
||||
await ollama.close()
|
||||
ollama = OllamaClient(new_ollama_cfg)
|
||||
await extractor_client.close()
|
||||
extractor_client = build_llm_client(
|
||||
new_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
extractor_provider = new_provider
|
||||
else:
|
||||
ollama._config = new_ollama_cfg
|
||||
# Same provider and model — just update config in-place
|
||||
extractor_client._config = new_cfg # type: ignore[assignment]
|
||||
resolved_config = new_resolved
|
||||
except Exception:
|
||||
logger.warning("Failed to refresh extractor config", exc_info=True)
|
||||
|
||||
# Re-resolve event classifier config
|
||||
try:
|
||||
classifier_resolved = await resolver.resolve("event-classifier")
|
||||
if classifier_resolved is not None:
|
||||
new_cls_cfg = _build_ollama_config_from_resolved(
|
||||
classifier_resolved, config.ollama,
|
||||
new_cls_resolved = await resolver.resolve("event-classifier")
|
||||
if new_cls_resolved is not None:
|
||||
new_cls_provider = _get_provider(new_cls_resolved)
|
||||
new_cls_cfg = build_config_from_resolved(
|
||||
new_cls_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
if new_cls_cfg.model != classifier_ollama._config.model:
|
||||
old_cls_provider = classifier_provider
|
||||
cls_provider_changed = new_cls_provider != classifier_provider
|
||||
cls_model_changed = new_cls_cfg.model != classifier_client._config.model
|
||||
|
||||
if cls_provider_changed or cls_model_changed:
|
||||
logger.info(
|
||||
"Event classifier config changed: model=%s variant=%s",
|
||||
classifier_resolved.model_name, classifier_resolved.variant_id,
|
||||
"Classifier provider switch: old_provider=%s new_provider=%s "
|
||||
"model=%s variant=%s",
|
||||
old_cls_provider, new_cls_provider,
|
||||
new_cls_resolved.model_name, new_cls_resolved.variant_id,
|
||||
)
|
||||
if classifier_ollama is not ollama:
|
||||
await classifier_ollama.close()
|
||||
classifier_ollama = OllamaClient(new_cls_cfg)
|
||||
elif classifier_ollama is ollama and new_cls_cfg.model != ollama._config.model:
|
||||
classifier_ollama = OllamaClient(new_cls_cfg)
|
||||
if classifier_client is not extractor_client:
|
||||
await classifier_client.close()
|
||||
classifier_client = build_llm_client(
|
||||
new_cls_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
classifier_provider = new_cls_provider
|
||||
elif classifier_client is extractor_client and new_cls_cfg.model != extractor_client._config.model:
|
||||
classifier_client = build_llm_client(
|
||||
new_cls_resolved, config.ollama, config.vllm,
|
||||
)
|
||||
classifier_provider = new_cls_provider
|
||||
classifier_resolved = new_cls_resolved
|
||||
except Exception:
|
||||
logger.warning("Failed to refresh event-classifier config", exc_info=True)
|
||||
|
||||
@@ -490,7 +550,7 @@ async def main() -> None:
|
||||
await _process_macro_classification(
|
||||
pool=pool,
|
||||
minio_client=minio_client,
|
||||
ollama=classifier_ollama,
|
||||
ollama=classifier_client,
|
||||
redis_client=redis_client,
|
||||
document_id=document_id,
|
||||
text=text,
|
||||
@@ -529,7 +589,7 @@ async def main() -> None:
|
||||
|
||||
# Pass all tracked tickers so the model can identify any mentioned companies
|
||||
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
|
||||
extraction_response = await ollama.extract(
|
||||
extraction_response = await extractor_client.extract(
|
||||
extraction_text,
|
||||
document_id=document_id,
|
||||
known_tickers=all_tickers,
|
||||
|
||||
Reference in New Issue
Block a user