feat: add remote vLLM support with provider abstraction layer

- LLMClient Protocol for provider-agnostic inference
- VLLMClient for OpenAI-compatible /v1/chat/completions API
- LLM client factory with provider routing (ollama/vllm)
- VLLMConfig with VLLM_* environment variable loading
- Updated extractor worker with health check and provider switching
- Updated event classifier to use LLMClient protocol
- Helm values for vLLM configuration
- 18 unit tests + 6 property-based tests
- Full backward compatibility preserved
This commit is contained in:
Celes Renata
2026-04-23 08:17:23 +00:00
parent 63e4fb96ea
commit 117b693b19
15 changed files with 1876 additions and 77 deletions
+106 -46
View File
@@ -15,11 +15,13 @@ from services.aggregation.interpolation import (
filter_low_confidence_events,
persist_macro_impact_records,
)
from services.extractor.client import OllamaClient
from services.extractor.event_classifier import classify_global_event
from services.extractor.llm_factory import build_config_from_resolved, build_llm_client
from services.extractor.vllm_client import check_vllm_health
from services.extractor.worker import persist_extraction
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
from services.shared.config import OllamaConfig, load_config
from services.shared.llm_protocol import LLMClient
from services.shared.logging import inject_trace_context, setup_logging
from services.shared.redis_keys import (
QUEUE_AGGREGATION,
@@ -31,11 +33,22 @@ from services.shared.redis_keys import (
logger = logging.getLogger("extractor_main")
def _get_provider(resolved: ResolvedAgentConfig | None) -> str:
"""Return the normalised provider string for a resolved config."""
if resolved is None:
return "ollama"
return (resolved.model_provider or "").strip().lower() or "ollama"
def _build_ollama_config_from_resolved(
resolved: ResolvedAgentConfig,
base_config: OllamaConfig,
) -> OllamaConfig:
"""Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings."""
"""Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings.
Kept for backward compatibility — the factory's ``build_config_from_resolved``
is now the primary path.
"""
return OllamaConfig(
base_url=base_config.base_url,
model=resolved.model_name,
@@ -239,7 +252,7 @@ async def _process_macro_classification(
*,
pool: asyncpg.Pool,
minio_client: Minio,
ollama: OllamaClient,
ollama: LLMClient,
redis_client: aioredis.Redis,
document_id: str,
text: str,
@@ -258,7 +271,7 @@ async def _process_macro_classification(
event = await classify_global_event(
normalized_text=text,
document_id=document_id,
ollama_client=ollama,
client=ollama,
pool=pool,
minio_client=minio_client,
)
@@ -329,48 +342,69 @@ async def main() -> None:
# Resolve extractor config from DB (active variant override + TTL cache)
resolver = AgentConfigResolver(pool, ttl_seconds=60)
resolved_config: ResolvedAgentConfig | None = None
extractor_ollama_config = config.ollama
extractor_provider = "ollama"
try:
resolved_config = await resolver.resolve("document-extractor")
if resolved_config is not None:
extractor_ollama_config = _build_ollama_config_from_resolved(
resolved_config, config.ollama,
)
extractor_provider = _get_provider(resolved_config)
logger.info(
"Extractor using resolved config: model=%s variant=%s",
resolved_config.model_name, resolved_config.variant_id,
"Extractor using resolved config: model=%s variant=%s provider=%s",
resolved_config.model_name, resolved_config.variant_id, extractor_provider,
)
else:
logger.info("No DB config for document-extractor — using env defaults")
except Exception:
logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True)
ollama = OllamaClient(extractor_ollama_config)
# vLLM health check at startup when provider is vllm (Requirement 7.17.3)
if extractor_provider == "vllm":
healthy = await check_vllm_health(config.vllm.base_url)
if not healthy:
logger.warning(
"vLLM health check failed at startup — falling back to Ollama for extractor",
)
extractor_provider = "ollama"
# Override resolved config provider so factory builds OllamaClient
resolved_config = None
extractor_client: LLMClient = build_llm_client(
resolved_config, config.ollama, config.vllm,
)
# Resolve event classifier config separately (may use different model)
classifier_resolved: ResolvedAgentConfig | None = None
classifier_ollama_config = config.ollama
classifier_provider = "ollama"
try:
classifier_resolved = await resolver.resolve("event-classifier")
if classifier_resolved is not None:
classifier_ollama_config = _build_ollama_config_from_resolved(
classifier_resolved, config.ollama,
)
classifier_provider = _get_provider(classifier_resolved)
logger.info(
"Event classifier using resolved config: model=%s variant=%s",
classifier_resolved.model_name, classifier_resolved.variant_id,
"Event classifier using resolved config: model=%s variant=%s provider=%s",
classifier_resolved.model_name, classifier_resolved.variant_id, classifier_provider,
)
else:
logger.info("No DB config for event-classifier — using extractor config")
except Exception:
logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True)
# Use a separate OllamaClient for the classifier if it has a different model
classifier_ollama: OllamaClient
if classifier_ollama_config.model != extractor_ollama_config.model:
classifier_ollama = OllamaClient(classifier_ollama_config)
# vLLM health check for classifier if it uses vllm and extractor didn't already check
if classifier_provider == "vllm" and extractor_provider != "vllm":
healthy = await check_vllm_health(config.vllm.base_url)
if not healthy:
logger.warning(
"vLLM health check failed at startup — falling back to Ollama for classifier",
)
classifier_provider = "ollama"
classifier_resolved = None
# Build classifier client — share with extractor when configs match
classifier_client: LLMClient
if classifier_resolved is not None or classifier_provider != extractor_provider:
classifier_client = build_llm_client(
classifier_resolved, config.ollama, config.vllm,
)
else:
classifier_ollama = ollama
classifier_client = extractor_client
redis_client = aioredis.from_url(config.redis.url)
queue = queue_key(QUEUE_EXTRACTION)
@@ -441,40 +475,66 @@ async def main() -> None:
company_id_map = await _build_company_id_map(pool)
# Re-resolve extractor config (picks up active variant swaps)
try:
resolved_config = await resolver.resolve("document-extractor")
if resolved_config is not None:
new_ollama_cfg = _build_ollama_config_from_resolved(
resolved_config, config.ollama,
new_resolved = await resolver.resolve("document-extractor")
if new_resolved is not None:
new_provider = _get_provider(new_resolved)
new_cfg = build_config_from_resolved(
new_resolved, config.ollama, config.vllm,
)
if new_ollama_cfg.model != ollama._config.model:
old_provider = extractor_provider
provider_changed = new_provider != extractor_provider
model_changed = new_cfg.model != extractor_client._config.model
if provider_changed or model_changed:
logger.info(
"Extractor config changed: model=%s variant=%s",
resolved_config.model_name, resolved_config.variant_id,
"Extractor provider switch: old_provider=%s new_provider=%s "
"model=%s variant=%s",
old_provider, new_provider,
new_resolved.model_name, new_resolved.variant_id,
)
await ollama.close()
ollama = OllamaClient(new_ollama_cfg)
await extractor_client.close()
extractor_client = build_llm_client(
new_resolved, config.ollama, config.vllm,
)
extractor_provider = new_provider
else:
ollama._config = new_ollama_cfg
# Same provider and model — just update config in-place
extractor_client._config = new_cfg # type: ignore[assignment]
resolved_config = new_resolved
except Exception:
logger.warning("Failed to refresh extractor config", exc_info=True)
# Re-resolve event classifier config
try:
classifier_resolved = await resolver.resolve("event-classifier")
if classifier_resolved is not None:
new_cls_cfg = _build_ollama_config_from_resolved(
classifier_resolved, config.ollama,
new_cls_resolved = await resolver.resolve("event-classifier")
if new_cls_resolved is not None:
new_cls_provider = _get_provider(new_cls_resolved)
new_cls_cfg = build_config_from_resolved(
new_cls_resolved, config.ollama, config.vllm,
)
if new_cls_cfg.model != classifier_ollama._config.model:
old_cls_provider = classifier_provider
cls_provider_changed = new_cls_provider != classifier_provider
cls_model_changed = new_cls_cfg.model != classifier_client._config.model
if cls_provider_changed or cls_model_changed:
logger.info(
"Event classifier config changed: model=%s variant=%s",
classifier_resolved.model_name, classifier_resolved.variant_id,
"Classifier provider switch: old_provider=%s new_provider=%s "
"model=%s variant=%s",
old_cls_provider, new_cls_provider,
new_cls_resolved.model_name, new_cls_resolved.variant_id,
)
if classifier_ollama is not ollama:
await classifier_ollama.close()
classifier_ollama = OllamaClient(new_cls_cfg)
elif classifier_ollama is ollama and new_cls_cfg.model != ollama._config.model:
classifier_ollama = OllamaClient(new_cls_cfg)
if classifier_client is not extractor_client:
await classifier_client.close()
classifier_client = build_llm_client(
new_cls_resolved, config.ollama, config.vllm,
)
classifier_provider = new_cls_provider
elif classifier_client is extractor_client and new_cls_cfg.model != extractor_client._config.model:
classifier_client = build_llm_client(
new_cls_resolved, config.ollama, config.vllm,
)
classifier_provider = new_cls_provider
classifier_resolved = new_cls_resolved
except Exception:
logger.warning("Failed to refresh event-classifier config", exc_info=True)
@@ -490,7 +550,7 @@ async def main() -> None:
await _process_macro_classification(
pool=pool,
minio_client=minio_client,
ollama=classifier_ollama,
ollama=classifier_client,
redis_client=redis_client,
document_id=document_id,
text=text,
@@ -529,7 +589,7 @@ async def main() -> None:
# Pass all tracked tickers so the model can identify any mentioned companies
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
extraction_response = await ollama.extract(
extraction_response = await extractor_client.extract(
extraction_text,
document_id=document_id,
known_tickers=all_tickers,