feat: add remote vLLM support with provider abstraction layer
- LLMClient Protocol for provider-agnostic inference - VLLMClient for OpenAI-compatible /v1/chat/completions API - LLM client factory with provider routing (ollama/vllm) - VLLMConfig with VLLM_* environment variable loading - Updated extractor worker with health check and provider switching - Updated event classifier to use LLMClient protocol - Helm values for vLLM configuration - 18 unit tests + 6 property-based tests - Full backward compatibility preserved
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
"""Event classifier module for macro news articles.
|
||||
|
||||
Classifies global/geopolitical news articles into structured GlobalEvent
|
||||
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
|
||||
existing OllamaClient for inference and retry logic.
|
||||
objects using an LLM client (Ollama or vLLM) with a dedicated prompt and
|
||||
JSON schema. Uses the LLMClient protocol for provider-agnostic inference
|
||||
and retry logic.
|
||||
|
||||
Persists classification prompts, raw outputs, and final events to MinIO
|
||||
and PostgreSQL for audit and downstream interpolation.
|
||||
|
||||
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
|
||||
Requirements: 1.4, 2.1, 2.2, 2.3, 2.4, 2.5, 6.2
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -24,6 +25,8 @@ import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
||||
from services.shared.config import VLLMConfig
|
||||
from services.shared.llm_protocol import LLMClient
|
||||
from services.shared.schemas import (
|
||||
EstimatedDuration,
|
||||
ImpactType,
|
||||
@@ -281,6 +284,7 @@ def _parse_classification_response(
|
||||
raw_json: str,
|
||||
document_id: str,
|
||||
model_name: str,
|
||||
provider: str = "ollama",
|
||||
) -> GlobalEvent:
|
||||
"""Parse raw Ollama JSON output into a GlobalEvent.
|
||||
|
||||
@@ -345,7 +349,7 @@ def _parse_classification_response(
|
||||
confidence=confidence,
|
||||
source_document_id=document_id,
|
||||
model_metadata=ModelMetadata(
|
||||
provider="ollama",
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
prompt_version=PROMPT_VERSION,
|
||||
schema_version=SCHEMA_VERSION,
|
||||
@@ -479,21 +483,21 @@ async def persist_global_event(
|
||||
async def classify_global_event(
|
||||
normalized_text: str,
|
||||
document_id: str,
|
||||
ollama_client: Any,
|
||||
client: LLMClient,
|
||||
*,
|
||||
pool: asyncpg.Pool | None = None,
|
||||
minio_client: Minio | None = None,
|
||||
) -> GlobalEvent:
|
||||
"""Classify a macro news article into a GlobalEvent using Ollama.
|
||||
"""Classify a macro news article into a GlobalEvent using an LLM.
|
||||
|
||||
Uses the existing OllamaClient's streaming infrastructure with a
|
||||
dedicated event classification prompt and JSON schema. Follows the
|
||||
same retry policy as document extraction.
|
||||
Uses the LLMClient protocol's call_llm() method with a dedicated
|
||||
event classification prompt and JSON schema. Follows the same retry
|
||||
policy as document extraction.
|
||||
|
||||
Resolves runtime config for the "event-classifier" agent slug from
|
||||
the database, preferring an active variant's model_name and
|
||||
system_prompt if one exists. Falls back to the OllamaClient's
|
||||
existing config if resolution fails.
|
||||
system_prompt if one exists. Falls back to the client's existing
|
||||
config if resolution fails.
|
||||
|
||||
Persists prompt, raw output, and final event to MinIO and PostgreSQL
|
||||
when the respective clients are provided.
|
||||
@@ -501,7 +505,7 @@ async def classify_global_event(
|
||||
Args:
|
||||
normalized_text: Cleaned text content of the macro article.
|
||||
document_id: UUID of the source document.
|
||||
ollama_client: An OllamaClient instance (from services.extractor.client).
|
||||
client: An LLMClient instance (OllamaClient or VLLMClient).
|
||||
pool: Optional asyncpg pool for PostgreSQL persistence.
|
||||
minio_client: Optional MinIO client for artifact persistence.
|
||||
|
||||
@@ -528,7 +532,10 @@ async def classify_global_event(
|
||||
|
||||
prompts = build_event_classification_prompt(normalized_text)
|
||||
json_schema = get_event_json_schema()
|
||||
model_name = ollama_client._config.model
|
||||
model_name = client._config.model
|
||||
|
||||
# Detect provider from client config type
|
||||
provider = "vllm" if isinstance(client._config, VLLMConfig) else "ollama"
|
||||
|
||||
# Override model_name and system_prompt from resolved config
|
||||
if resolved is not None:
|
||||
@@ -562,16 +569,16 @@ async def classify_global_event(
|
||||
except Exception:
|
||||
logger.exception("Failed to upload classification prompt for doc %s", document_id)
|
||||
|
||||
# Call Ollama using the client's internal _call_ollama method
|
||||
# Call LLM using the client's call_llm method
|
||||
# We reuse the retry logic pattern from OllamaClient.extract()
|
||||
max_retries = ollama_client._max_retries
|
||||
max_retries = client._config.max_retries
|
||||
if resolved is not None:
|
||||
max_retries = resolved.max_retries
|
||||
last_error: str | None = None
|
||||
raw_output = ""
|
||||
|
||||
for attempt_num in range(max_retries + 1):
|
||||
attempt = await ollama_client._call_ollama(prompts, json_schema)
|
||||
attempt = await client.call_llm(prompts, json_schema)
|
||||
raw_output = attempt.raw_output
|
||||
|
||||
# _call_ollama validates against the *extraction* schema, which
|
||||
@@ -581,7 +588,7 @@ async def classify_global_event(
|
||||
# Try to parse the response
|
||||
try:
|
||||
event = _parse_classification_response(
|
||||
raw_output, document_id, model_name,
|
||||
raw_output, document_id, model_name, provider=provider,
|
||||
)
|
||||
|
||||
# Persist result to MinIO
|
||||
@@ -648,10 +655,10 @@ async def classify_global_event(
|
||||
|
||||
# Retry with backoff
|
||||
if attempt_num < max_retries:
|
||||
delay = ollama_client._base_delay * (
|
||||
ollama_client._backoff_multiplier ** attempt_num
|
||||
delay = client._config.retry_base_delay * (
|
||||
client._config.retry_backoff_multiplier ** attempt_num
|
||||
)
|
||||
delay = min(delay, ollama_client._max_delay)
|
||||
delay = min(delay, client._config.retry_max_delay)
|
||||
logger.warning(
|
||||
"Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
|
||||
attempt_num + 1, max_retries + 1, document_id, last_error, delay,
|
||||
|
||||
Reference in New Issue
Block a user