feat: add remote vLLM support with provider abstraction layer
- LLMClient Protocol for provider-agnostic inference - VLLMClient for OpenAI-compatible /v1/chat/completions API - LLM client factory with provider routing (ollama/vllm) - VLLMConfig with VLLM_* environment variable loading - Updated extractor worker with health check and provider switching - Updated event classifier to use LLMClient protocol - Helm values for vLLM configuration - 18 unit tests + 6 property-based tests - Full backward compatibility preserved
This commit is contained in:
@@ -0,0 +1 @@
|
|||||||
|
{"specId": "a7e3f1b2-9c4d-4e8a-b5f6-d2a1c3e7f9b0", "workflowType": "requirements-first", "specType": "feature"}
|
||||||
@@ -0,0 +1,350 @@
|
|||||||
|
# Design Document: Remote vLLM Support
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This design introduces an LLM provider abstraction layer into Stonks Oracle so that both the existing Ollama backend and a new remote vLLM backend can be used interchangeably for document extraction and event classification. The vLLM server at `http://192.168.42.254:8000` runs `RedHatAI/Qwen3.6-35B-A3B-NVFP4` on an NVIDIA RTX 5090 with tensor parallelism and exposes an OpenAI-compatible `/v1/chat/completions` API.
|
||||||
|
|
||||||
|
The design preserves full backward compatibility — existing Ollama deployments work without any configuration changes. Provider selection is driven by the existing `model_provider` column in the `ai_agents` and `agent_variants` database tables, requiring no new migrations.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph TD
|
||||||
|
subgraph "Extractor Worker"
|
||||||
|
MAIN[main.py]
|
||||||
|
FACTORY[LLMClientFactory]
|
||||||
|
EXTRACT[Extraction Pipeline]
|
||||||
|
CLASSIFY[Event Classification Pipeline]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph "Provider Abstraction"
|
||||||
|
PROTO[LLMClient Protocol]
|
||||||
|
OLLAMA_IMPL[OllamaClient]
|
||||||
|
VLLM_IMPL[VLLMClient]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph "Configuration"
|
||||||
|
RESOLVER[AgentConfigResolver]
|
||||||
|
OLLAMA_CFG[OllamaConfig]
|
||||||
|
VLLM_CFG[VLLMConfig]
|
||||||
|
APP_CFG[AppConfig]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph "External Services"
|
||||||
|
OLLAMA_SRV[Ollama Server<br/>:11434/api/chat]
|
||||||
|
VLLM_SRV[vLLM Server<br/>:8000/v1/chat/completions]
|
||||||
|
end
|
||||||
|
|
||||||
|
MAIN --> FACTORY
|
||||||
|
FACTORY --> PROTO
|
||||||
|
PROTO --> OLLAMA_IMPL
|
||||||
|
PROTO --> VLLM_IMPL
|
||||||
|
EXTRACT --> PROTO
|
||||||
|
CLASSIFY --> PROTO
|
||||||
|
|
||||||
|
RESOLVER --> FACTORY
|
||||||
|
OLLAMA_CFG --> FACTORY
|
||||||
|
VLLM_CFG --> FACTORY
|
||||||
|
APP_CFG --> OLLAMA_CFG
|
||||||
|
APP_CFG --> VLLM_CFG
|
||||||
|
|
||||||
|
OLLAMA_IMPL --> OLLAMA_SRV
|
||||||
|
VLLM_IMPL --> VLLM_SRV
|
||||||
|
```
|
||||||
|
|
||||||
|
The key architectural decision is to use a Python `Protocol` (structural typing) rather than an ABC for the LLM client interface. This allows the existing `OllamaClient` to satisfy the protocol without inheritance changes, maintaining backward compatibility. The `VLLMClient` is a new class that also satisfies the protocol.
|
||||||
|
|
||||||
|
A factory function in `services/extractor/llm_factory.py` takes a `ResolvedAgentConfig` and the base configs, returning the appropriate client. The extractor worker (`main.py`) uses this factory instead of directly constructing `OllamaClient`.
|
||||||
|
|
||||||
|
## Components and Interfaces
|
||||||
|
|
||||||
|
### 1. LLM Client Protocol (`services/shared/llm_protocol.py`)
|
||||||
|
|
||||||
|
A `typing.Protocol` defining the contract both clients must satisfy:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class LLMClient(Protocol):
|
||||||
|
async def call_llm(
|
||||||
|
self,
|
||||||
|
prompts: dict[str, str],
|
||||||
|
json_schema: dict[str, object],
|
||||||
|
document_text: str = "",
|
||||||
|
) -> "ExtractionAttempt": ...
|
||||||
|
|
||||||
|
async def close(self) -> None: ...
|
||||||
|
```
|
||||||
|
|
||||||
|
The `call_llm` method signature matches the existing `OllamaClient._call_ollama()` parameters and return type. The `OllamaClient` gains a public `call_llm` method that delegates to `_call_ollama()`, preserving the private method for internal backward compatibility.
|
||||||
|
|
||||||
|
### 2. VLLMClient (`services/extractor/vllm_client.py`)
|
||||||
|
|
||||||
|
New client implementing the `LLMClient` protocol for the OpenAI-compatible API:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class VLLMClient:
|
||||||
|
_config: VLLMConfig
|
||||||
|
_http: httpx.AsyncClient
|
||||||
|
_owns_client: bool
|
||||||
|
|
||||||
|
async def call_llm(
|
||||||
|
self,
|
||||||
|
prompts: dict[str, str],
|
||||||
|
json_schema: dict[str, object],
|
||||||
|
document_text: str = "",
|
||||||
|
) -> ExtractionAttempt: ...
|
||||||
|
|
||||||
|
async def close(self) -> None: ...
|
||||||
|
```
|
||||||
|
|
||||||
|
**Request format** (OpenAI-compatible):
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"model": "RedHatAI/Qwen3.6-35B-A3B-NVFP4",
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "..."},
|
||||||
|
{"role": "user", "content": "..."}
|
||||||
|
],
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"response_format": {"type": "json_object"}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response parsing**: Extracts `choices[0].message.content`, then applies the same `_strip_markdown_fences()` and `_repair_json()` pipeline as `OllamaClient`.
|
||||||
|
|
||||||
|
**Error handling**: Maps HTTP errors to the same string format as `OllamaClient` (`timeout`, `http_{code}`, `connection_error: {details}`, `empty_model_response`), so the existing `_is_retryable()` function works without modification.
|
||||||
|
|
||||||
|
**Key differences from OllamaClient**:
|
||||||
|
- Endpoint: `/v1/chat/completions` instead of `/api/chat`
|
||||||
|
- No `think: false`, `stream: false`, or `options` block
|
||||||
|
- Uses `max_tokens` instead of `options.num_predict`
|
||||||
|
- Uses `response_format: {"type": "json_object"}` for structured output
|
||||||
|
- Supports `temperature` parameter (Ollama uses model defaults)
|
||||||
|
- Response in `choices[0].message.content` instead of `message.content`
|
||||||
|
|
||||||
|
### 3. VLLMConfig (`services/shared/config.py`)
|
||||||
|
|
||||||
|
New dataclass alongside `OllamaConfig`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@dataclass
|
||||||
|
class VLLMConfig:
|
||||||
|
base_url: str = "http://192.168.42.254:8000"
|
||||||
|
model: str = "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
|
||||||
|
timeout: int = 120
|
||||||
|
max_retries: int = 2
|
||||||
|
retry_base_delay: float = 1.0
|
||||||
|
retry_max_delay: float = 10.0
|
||||||
|
retry_backoff_multiplier: float = 2.0
|
||||||
|
max_tokens: int = 32768
|
||||||
|
temperature: float = 0.7
|
||||||
|
api_key: str = "" # Optional, for authenticated vLLM deployments
|
||||||
|
```
|
||||||
|
|
||||||
|
Loaded from `VLLM_*` environment variables in `load_config()`. Added to `AppConfig` as `vllm: VLLMConfig`.
|
||||||
|
|
||||||
|
### 4. LLM Client Factory (`services/extractor/llm_factory.py`)
|
||||||
|
|
||||||
|
Factory function that replaces the hardcoded `OllamaClient` construction:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def build_llm_client(
|
||||||
|
resolved: ResolvedAgentConfig | None,
|
||||||
|
ollama_config: OllamaConfig,
|
||||||
|
vllm_config: VLLMConfig,
|
||||||
|
http_client: httpx.AsyncClient | None = None,
|
||||||
|
) -> LLMClient:
|
||||||
|
"""Return the appropriate LLM client based on resolved provider."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def build_config_from_resolved(
|
||||||
|
resolved: ResolvedAgentConfig,
|
||||||
|
base_ollama: OllamaConfig,
|
||||||
|
base_vllm: VLLMConfig,
|
||||||
|
) -> OllamaConfig | VLLMConfig:
|
||||||
|
"""Build provider-specific config from resolved agent config."""
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
Provider routing logic:
|
||||||
|
1. If `resolved` is `None` or `resolved.model_provider` is `"ollama"` or empty → `OllamaClient`
|
||||||
|
2. If `resolved.model_provider` is `"vllm"` → `VLLMClient`
|
||||||
|
3. Unknown provider → log warning, fall back to `OllamaClient`
|
||||||
|
|
||||||
|
### 5. Updated Extractor Worker (`services/extractor/main.py`)
|
||||||
|
|
||||||
|
Changes to `main()`:
|
||||||
|
- Replace `_build_ollama_config_from_resolved()` with `build_llm_client()` from the factory
|
||||||
|
- Store clients as `LLMClient` type instead of `OllamaClient`
|
||||||
|
- On config refresh (every 100 jobs), detect provider changes and swap clients
|
||||||
|
- Log provider switches at INFO level
|
||||||
|
|
||||||
|
Changes to `_process_macro_classification()`:
|
||||||
|
- Accept `LLMClient` instead of `OllamaClient` for the classifier parameter
|
||||||
|
|
||||||
|
### 6. Updated OllamaClient (`services/extractor/client.py`)
|
||||||
|
|
||||||
|
Minimal changes to satisfy the protocol:
|
||||||
|
- Add public `call_llm()` method that delegates to `_call_ollama()`
|
||||||
|
- Keep `_call_ollama()` as-is for backward compatibility
|
||||||
|
- The `extract()` method continues to call `_call_ollama()` internally
|
||||||
|
|
||||||
|
### 7. Updated Event Classifier (`services/extractor/event_classifier.py`)
|
||||||
|
|
||||||
|
Changes to `classify_global_event()`:
|
||||||
|
- Accept `LLMClient` instead of `Any` for the `ollama_client` parameter
|
||||||
|
- Call `client.call_llm()` instead of `ollama_client._call_ollama()`
|
||||||
|
- Set `ModelMetadata.provider` based on the actual client type (inspect `_config` or pass provider string)
|
||||||
|
|
||||||
|
### 8. Helm Values (`infra/helm/stonks-oracle/values.yaml`)
|
||||||
|
|
||||||
|
New config entries:
|
||||||
|
```yaml
|
||||||
|
config:
|
||||||
|
VLLM_BASE_URL: "http://192.168.42.254:8000"
|
||||||
|
VLLM_MODEL: "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
|
||||||
|
VLLM_TIMEOUT: "120"
|
||||||
|
VLLM_MAX_RETRIES: "2"
|
||||||
|
VLLM_TEMPERATURE: "0.7"
|
||||||
|
VLLM_API_KEY: ""
|
||||||
|
```
|
||||||
|
|
||||||
|
### 9. Health Check (`services/extractor/vllm_client.py`)
|
||||||
|
|
||||||
|
Startup validation function:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def check_vllm_health(base_url: str, timeout: float = 10.0) -> bool:
|
||||||
|
"""GET {base_url}/v1/models to verify vLLM is reachable."""
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
Called from `main()` when the resolved or default config specifies vLLM. On failure, logs WARNING and falls back to Ollama. On success, logs INFO with server URL and model list.
|
||||||
|
|
||||||
|
## Data Models
|
||||||
|
|
||||||
|
### VLLMConfig Dataclass
|
||||||
|
|
||||||
|
| Field | Type | Default | Env Var |
|
||||||
|
|-------|------|---------|---------|
|
||||||
|
| `base_url` | `str` | `http://192.168.42.254:8000` | `VLLM_BASE_URL` |
|
||||||
|
| `model` | `str` | `RedHatAI/Qwen3.6-35B-A3B-NVFP4` | `VLLM_MODEL` |
|
||||||
|
| `timeout` | `int` | `120` | `VLLM_TIMEOUT` |
|
||||||
|
| `max_retries` | `int` | `2` | `VLLM_MAX_RETRIES` |
|
||||||
|
| `retry_base_delay` | `float` | `1.0` | `VLLM_RETRY_BASE_DELAY` |
|
||||||
|
| `retry_max_delay` | `float` | `10.0` | `VLLM_RETRY_MAX_DELAY` |
|
||||||
|
| `retry_backoff_multiplier` | `float` | `2.0` | `VLLM_RETRY_BACKOFF_MULTIPLIER` |
|
||||||
|
| `max_tokens` | `int` | `32768` | `VLLM_MAX_TOKENS` |
|
||||||
|
| `temperature` | `float` | `0.7` | `VLLM_TEMPERATURE` |
|
||||||
|
| `api_key` | `str` | `""` | `VLLM_API_KEY` |
|
||||||
|
|
||||||
|
### ExtractionAttempt (unchanged)
|
||||||
|
|
||||||
|
The existing `ExtractionAttempt` dataclass is reused as-is for both providers. No changes needed.
|
||||||
|
|
||||||
|
### ModelMetadata (unchanged structure, new values)
|
||||||
|
|
||||||
|
The `provider` field now accepts `"vllm"` in addition to `"ollama"`. No schema change needed.
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
### Error String Format Parity
|
||||||
|
|
||||||
|
Both clients produce identical error string formats so `_is_retryable()` works unchanged:
|
||||||
|
|
||||||
|
| Condition | Error String | Retryable |
|
||||||
|
|-----------|-------------|-----------|
|
||||||
|
| HTTP timeout | `timeout` | Yes |
|
||||||
|
| HTTP 400/401/403/404/422 | `http_{code}` | No |
|
||||||
|
| HTTP 500/502/503/429 | `http_{code}` | Yes |
|
||||||
|
| Connection refused/reset | `connection_error: {details}` | Yes |
|
||||||
|
| Empty response body | `empty_model_response` | Yes |
|
||||||
|
| Invalid JSON in response | `invalid_response_json` | Yes |
|
||||||
|
|
||||||
|
### Health Check Failure
|
||||||
|
|
||||||
|
If the vLLM health check fails at startup:
|
||||||
|
1. Log WARNING with the error details
|
||||||
|
2. Fall back to `OllamaClient` using `OllamaConfig`
|
||||||
|
3. Continue operation — the system degrades gracefully rather than crashing
|
||||||
|
|
||||||
|
### Provider Switch During Refresh
|
||||||
|
|
||||||
|
When the config refresh (every 100 jobs) detects a provider change:
|
||||||
|
1. Close the old client (`await old_client.close()`)
|
||||||
|
2. Construct the new client via the factory
|
||||||
|
3. Log the switch at INFO level
|
||||||
|
4. If new client construction fails, keep the old client and log ERROR
|
||||||
|
|
||||||
|
## Testing Strategy
|
||||||
|
|
||||||
|
### Property-Based Tests (`tests/test_pbt_llm_provider.py`)
|
||||||
|
|
||||||
|
Property-based tests using Hypothesis to verify the provider abstraction:
|
||||||
|
|
||||||
|
**P1: Provider factory routing property** (Req 3.4, 3.5, 9.5)
|
||||||
|
For all `model_provider` values in `{"ollama", "vllm", "", None}`, the factory returns the correct client type. For `"ollama"`, empty, or `None`, returns `OllamaClient`. For `"vllm"`, returns `VLLMClient`.
|
||||||
|
|
||||||
|
**P2: Error string format consistency property** (Req 5.6)
|
||||||
|
For all HTTP status codes (100-599), both `OllamaClient` and `VLLMClient` produce error strings in the same format (`http_{code}`), and `_is_retryable()` returns the same result for both.
|
||||||
|
|
||||||
|
**P3: VLLMClient request payload structure property** (Req 2.1, 8.1)
|
||||||
|
For all generated prompt dicts (system + user messages of arbitrary text), the VLLMClient produces a request payload that: contains `model`, `messages`, `max_tokens`, `temperature`; does NOT contain `think`, `stream`, `options`, `num_ctx`, `num_predict`.
|
||||||
|
|
||||||
|
**P4: JSON repair idempotence property** (Req 2.4)
|
||||||
|
For all valid JSON strings, `_repair_json(json_str)` returns a string that `json.loads()` can parse, and `_repair_json(_repair_json(json_str)) == _repair_json(json_str)` (idempotence).
|
||||||
|
|
||||||
|
**P5: Markdown fence stripping round-trip property** (Req 2.3)
|
||||||
|
For all strings `s`, `_strip_markdown_fences(f"```json\n{s}\n```")` returns `s` (stripped), and `_strip_markdown_fences(s)` returns `s` when no fences are present (identity).
|
||||||
|
|
||||||
|
**P6: VLLMConfig default construction property** (Req 3.1)
|
||||||
|
For all VLLMConfig instances constructed with default values, `base_url` is non-empty, `timeout > 0`, `max_retries >= 0`, `temperature` is between 0.0 and 2.0, and `max_tokens > 0`.
|
||||||
|
|
||||||
|
### Unit Tests (`tests/test_vllm_client.py`)
|
||||||
|
|
||||||
|
Example-based tests for specific behaviors:
|
||||||
|
|
||||||
|
- VLLMClient sends correct payload to `/v1/chat/completions` (mock httpx)
|
||||||
|
- VLLMClient extracts content from `choices[0].message.content`
|
||||||
|
- VLLMClient handles empty choices array → `empty_model_response`
|
||||||
|
- VLLMClient handles timeout → `timeout` error
|
||||||
|
- VLLMClient handles HTTP 500 → `http_500` error, retryable
|
||||||
|
- VLLMClient handles HTTP 400 → `http_400` error, non-retryable
|
||||||
|
- VLLMClient handles connection refused → `connection_error: ...`
|
||||||
|
- VLLMClient applies markdown fence stripping
|
||||||
|
- VLLMClient applies JSON repair
|
||||||
|
- VLLMClient includes temperature in payload
|
||||||
|
- VLLMClient includes `response_format` in payload
|
||||||
|
- Health check success logs INFO
|
||||||
|
- Health check failure logs WARNING and returns False
|
||||||
|
- Factory returns OllamaClient for provider="ollama"
|
||||||
|
- Factory returns VLLMClient for provider="vllm"
|
||||||
|
- Factory returns OllamaClient for provider="" (default)
|
||||||
|
- Factory returns OllamaClient for unknown provider with warning
|
||||||
|
- VLLMConfig loads from environment variables
|
||||||
|
- AppConfig includes vllm field with defaults
|
||||||
|
- OllamaClient.call_llm() delegates to _call_ollama()
|
||||||
|
|
||||||
|
### Existing Tests (unchanged)
|
||||||
|
|
||||||
|
- `tests/test_ollama_client.py` — continues to pass without modification
|
||||||
|
- All other existing test files — unaffected
|
||||||
|
|
||||||
|
## File Changes Summary
|
||||||
|
|
||||||
|
| File | Change Type | Description |
|
||||||
|
|------|-------------|-------------|
|
||||||
|
| `services/shared/llm_protocol.py` | **New** | `LLMClient` Protocol definition |
|
||||||
|
| `services/extractor/vllm_client.py` | **New** | `VLLMClient` implementation + health check |
|
||||||
|
| `services/extractor/llm_factory.py` | **New** | Factory function for provider routing |
|
||||||
|
| `services/shared/config.py` | **Modified** | Add `VLLMConfig`, update `AppConfig`, update `load_config()` |
|
||||||
|
| `services/extractor/client.py` | **Modified** | Add `call_llm()` public method to `OllamaClient` |
|
||||||
|
| `services/extractor/event_classifier.py` | **Modified** | Use `call_llm()` instead of `_call_ollama()`, accept `LLMClient` type |
|
||||||
|
| `services/extractor/main.py` | **Modified** | Use factory, support provider switching, health check |
|
||||||
|
| `infra/helm/stonks-oracle/values.yaml` | **Modified** | Add `VLLM_*` config entries |
|
||||||
|
| `tests/test_pbt_llm_provider.py` | **New** | Property-based tests for provider abstraction |
|
||||||
|
| `tests/test_vllm_client.py` | **New** | Unit tests for VLLMClient and factory |
|
||||||
@@ -0,0 +1,136 @@
|
|||||||
|
# Requirements Document
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
Add remote vLLM support to the Stonks Oracle platform. The system currently uses Ollama exclusively for LLM inference via the `/api/chat` endpoint. A remote vLLM server running `RedHatAI/Qwen3.6-35B-A3B-NVFP4` on a 5090 GPU with tensor parallelism is available at `http://192.168.42.254:8000` and exposes an OpenAI-compatible `/v1/chat/completions` API. This feature introduces a provider abstraction layer so that both Ollama and vLLM backends can be used interchangeably, selected per-agent via the existing `model_provider` database column and environment variable configuration. The abstraction preserves all existing behavior (retry logic, JSON repair, audit trail, backoff, context window override) while adapting to the differences between the two API protocols.
|
||||||
|
|
||||||
|
## Glossary
|
||||||
|
|
||||||
|
- **LLM_Client**: An abstract interface defining the contract for sending chat completion requests to any LLM backend. Concrete implementations exist for Ollama and vLLM.
|
||||||
|
- **Ollama_Backend**: The existing Ollama inference server at `ollama.ollama-service.svc.cluster.local:11434` (cluster) or `http://10.1.1.12:2701` (external), using the `/api/chat` endpoint with Ollama-specific payload fields (`think`, `options.num_ctx`, `options.num_predict`).
|
||||||
|
- **VLLM_Backend**: A remote vLLM inference server at `http://192.168.42.254:8000` exposing the OpenAI-compatible `/v1/chat/completions` endpoint. Runs `RedHatAI/Qwen3.6-35B-A3B-NVFP4` on a 5090 GPU with tensor parallelism.
|
||||||
|
- **Provider**: A string identifier (`ollama` or `vllm`) that determines which LLM_Client implementation is used for a given agent. Stored in the `model_provider` column of `ai_agents` and `agent_variants` tables.
|
||||||
|
- **LLM_Config**: A provider-agnostic configuration dataclass containing connection and inference parameters (base_url, model, timeout, retries, max_tokens, context_window) used to construct an LLM_Client.
|
||||||
|
- **Extraction_Pipeline**: The document intelligence extraction workflow in `services/extractor/client.py` that sends documents to an LLM and parses structured JSON responses.
|
||||||
|
- **Event_Classification_Pipeline**: The macro event classification workflow in `services/extractor/event_classifier.py` that classifies global news articles via an LLM.
|
||||||
|
- **Agent_Config_Resolver**: The `AgentConfigResolver` in `services/shared/agent_config.py` that resolves runtime configuration from the `ai_agents` and `agent_variants` database tables, including the `model_provider` field.
|
||||||
|
- **OpenAI_Chat_Format**: The request/response format used by `/v1/chat/completions` — messages array with role/content, `max_tokens`, `temperature`, and response in `choices[0].message.content`.
|
||||||
|
- **JSON_Repair**: The existing `json-repair` library usage that fixes malformed JSON from model output, applied regardless of provider.
|
||||||
|
- **Model_Metadata**: The `ModelMetadata` Pydantic model in `services/shared/schemas.py` that tracks `provider`, `model_name`, `prompt_version`, and `schema_version` for audit.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
### Requirement 1: Provider Abstraction Layer
|
||||||
|
|
||||||
|
**User Story:** As a developer, I want a provider abstraction layer that decouples LLM inference from any specific backend, so that the extraction and classification pipelines can use either Ollama or vLLM without code changes in the calling services.
|
||||||
|
|
||||||
|
#### Acceptance Criteria
|
||||||
|
|
||||||
|
1. THE LLM_Client interface SHALL define an async method that accepts a messages list (system and user prompts), a JSON schema hint, and optional document text, and returns an attempt result containing raw output, validation report, error string, duration, and model name.
|
||||||
|
2. THE LLM_Client interface SHALL define an async `close` method for releasing underlying HTTP resources.
|
||||||
|
3. WHEN the Extraction_Pipeline calls the LLM, THE Extraction_Pipeline SHALL use the LLM_Client interface instead of calling Ollama-specific endpoints directly.
|
||||||
|
4. WHEN the Event_Classification_Pipeline calls the LLM, THE Event_Classification_Pipeline SHALL use the LLM_Client interface instead of calling `_call_ollama()` directly.
|
||||||
|
5. THE Ollama_Backend implementation of LLM_Client SHALL preserve the existing `/api/chat` payload structure including `think: false`, `stream: false`, `options.num_predict`, and `options.num_ctx`.
|
||||||
|
6. THE VLLM_Backend implementation of LLM_Client SHALL send requests to `/v1/chat/completions` using the OpenAI_Chat_Format with `model`, `messages`, `max_tokens`, and `temperature` fields.
|
||||||
|
7. FOR ALL valid prompt inputs, sending a prompt through the Ollama_Backend and parsing the response SHALL produce the same ExtractionAttempt structure as the current `_call_ollama()` method (round-trip equivalence with existing behavior).
|
||||||
|
|
||||||
|
### Requirement 2: vLLM Client Implementation
|
||||||
|
|
||||||
|
**User Story:** As a developer, I want a vLLM client that communicates with the remote vLLM server using the OpenAI-compatible API, so that the platform can leverage the 5090 GPU for inference.
|
||||||
|
|
||||||
|
#### Acceptance Criteria
|
||||||
|
|
||||||
|
1. THE VLLM_Backend SHALL send POST requests to `{base_url}/v1/chat/completions` with a JSON payload containing `model`, `messages` (array of role/content objects), `max_tokens`, and `temperature`.
|
||||||
|
2. THE VLLM_Backend SHALL extract the response content from `choices[0].message.content` in the OpenAI-compatible response format.
|
||||||
|
3. THE VLLM_Backend SHALL apply the same markdown fence stripping logic as the Ollama_Backend to handle model output wrapped in ```json ... ``` blocks.
|
||||||
|
4. THE VLLM_Backend SHALL apply the same JSON_Repair logic as the Ollama_Backend to fix malformed JSON in model output.
|
||||||
|
5. WHEN the vLLM server returns an HTTP timeout, THE VLLM_Backend SHALL report the error as `timeout` in the attempt result, consistent with the Ollama_Backend error format.
|
||||||
|
6. WHEN the vLLM server returns an HTTP error status, THE VLLM_Backend SHALL report the error as `http_{status_code}` in the attempt result, consistent with the Ollama_Backend error format.
|
||||||
|
7. WHEN the vLLM server returns an empty `choices` array or missing `content`, THE VLLM_Backend SHALL report the error as `empty_model_response`.
|
||||||
|
8. IF the vLLM server is unreachable, THEN THE VLLM_Backend SHALL report the error as `connection_error: {details}`, consistent with the Ollama_Backend error format.
|
||||||
|
9. THE VLLM_Backend SHALL use the same `httpx.AsyncClient` timeout configuration as the Ollama_Backend, derived from the LLM_Config timeout value.
|
||||||
|
10. THE VLLM_Backend SHALL support an optional `temperature` parameter from the resolved agent config, defaulting to 0.7 when not specified.
|
||||||
|
|
||||||
|
### Requirement 3: Provider-Aware Configuration
|
||||||
|
|
||||||
|
**User Story:** As an operator, I want to configure the vLLM backend via environment variables and database agent config, so that I can switch providers without code changes.
|
||||||
|
|
||||||
|
#### Acceptance Criteria
|
||||||
|
|
||||||
|
1. THE Configuration SHALL include a `VLLMConfig` dataclass with fields: `base_url` (default `http://192.168.42.254:8000`), `model` (default `RedHatAI/Qwen3.6-35B-A3B-NVFP4`), `timeout` (default 120), `max_retries` (default 2), `retry_base_delay`, `retry_max_delay`, `retry_backoff_multiplier`, `max_tokens` (default 32768), and `temperature` (default 0.7).
|
||||||
|
2. THE Configuration SHALL load VLLMConfig values from environment variables prefixed with `VLLM_` (e.g., `VLLM_BASE_URL`, `VLLM_MODEL`, `VLLM_TIMEOUT`), following the same pattern as OllamaConfig.
|
||||||
|
3. THE AppConfig dataclass SHALL include a `vllm` field of type VLLMConfig alongside the existing `ollama` field.
|
||||||
|
4. WHEN the Agent_Config_Resolver resolves a `model_provider` value of `vllm`, THE service SHALL use the VLLMConfig base_url and construct a VLLM_Backend client instead of an Ollama_Backend client.
|
||||||
|
5. WHEN the Agent_Config_Resolver resolves a `model_provider` value of `ollama` or when no `model_provider` is specified, THE service SHALL continue to use the OllamaConfig and Ollama_Backend client as the default.
|
||||||
|
6. THE `_build_ollama_config_from_resolved` function in `services/extractor/main.py` SHALL be generalized to a provider-aware factory that returns the appropriate config and client type based on the resolved `model_provider`.
|
||||||
|
|
||||||
|
### Requirement 4: Provider Selection in Extractor Worker
|
||||||
|
|
||||||
|
**User Story:** As a developer, I want the extractor worker to select the correct LLM client based on the resolved agent config provider, so that each agent can independently use Ollama or vLLM.
|
||||||
|
|
||||||
|
#### Acceptance Criteria
|
||||||
|
|
||||||
|
1. WHEN the extractor worker starts, THE worker SHALL construct the default LLM_Client based on the environment variable configuration (defaulting to Ollama_Backend).
|
||||||
|
2. WHEN the Agent_Config_Resolver returns a resolved config with `model_provider = "vllm"` for the `document-extractor` slug, THE worker SHALL construct a VLLM_Backend client using the VLLMConfig base_url and the resolved model_name.
|
||||||
|
3. WHEN the Agent_Config_Resolver returns a resolved config with `model_provider = "vllm"` for the `event-classifier` slug, THE worker SHALL construct a VLLM_Backend client for the event classification pipeline.
|
||||||
|
4. WHEN the resolved config changes provider during a config refresh cycle (every 100 jobs), THE worker SHALL close the old LLM_Client and construct a new one matching the updated provider.
|
||||||
|
5. WHEN the resolved config changes from `ollama` to `vllm` or vice versa, THE worker SHALL log the provider switch at INFO level including the old and new provider, model name, and variant ID.
|
||||||
|
|
||||||
|
### Requirement 5: Retry and Error Handling Parity
|
||||||
|
|
||||||
|
**User Story:** As a developer, I want the vLLM client to use the same retry logic, backoff strategy, and error classification as the Ollama client, so that reliability behavior is consistent across providers.
|
||||||
|
|
||||||
|
#### Acceptance Criteria
|
||||||
|
|
||||||
|
1. THE VLLM_Backend SHALL use the same exponential backoff computation as the Ollama_Backend, using `retry_base_delay`, `retry_max_delay`, and `retry_backoff_multiplier` from the LLM_Config.
|
||||||
|
2. THE VLLM_Backend SHALL classify HTTP 400, 401, 403, 404, and 422 errors as non-retryable, consistent with the Ollama_Backend.
|
||||||
|
3. THE VLLM_Backend SHALL classify HTTP 500, 502, 503, 429, timeout, and connection errors as retryable, consistent with the Ollama_Backend.
|
||||||
|
4. WHEN the VLLM_Backend encounters a retryable error, THE Extraction_Pipeline SHALL retry up to `max_retries` times with exponential backoff, preserving each attempt in the audit trail.
|
||||||
|
5. WHEN the VLLM_Backend encounters a non-retryable error, THE Extraction_Pipeline SHALL stop retries immediately and record the attempt as non-retryable.
|
||||||
|
6. FOR ALL error types, the VLLM_Backend error string format SHALL match the Ollama_Backend error string format so that `_is_retryable()` works without modification.
|
||||||
|
|
||||||
|
### Requirement 6: Audit Trail and Model Metadata
|
||||||
|
|
||||||
|
**User Story:** As a developer, I want the audit trail and model metadata to correctly reflect which provider and model were used for each extraction, so that I can trace results back to the specific backend.
|
||||||
|
|
||||||
|
#### Acceptance Criteria
|
||||||
|
|
||||||
|
1. WHEN the VLLM_Backend completes an extraction attempt, THE attempt record SHALL include the vLLM model name in the `model` field.
|
||||||
|
2. WHEN an extraction or classification succeeds via the VLLM_Backend, THE Model_Metadata in the result SHALL have `provider` set to `"vllm"` and `model_name` set to the vLLM model identifier.
|
||||||
|
3. WHEN the `agent_performance_log` records an invocation that used the VLLM_Backend, THE log entry SHALL be attributed to the correct agent_id and variant_id, consistent with Ollama_Backend logging.
|
||||||
|
4. THE MinIO prompt and result artifacts persisted by the Event_Classification_Pipeline SHALL include the provider name and model name in the stored JSON, regardless of which backend was used.
|
||||||
|
|
||||||
|
### Requirement 7: Health Check and Connectivity Validation
|
||||||
|
|
||||||
|
**User Story:** As an operator, I want the system to validate connectivity to the vLLM server at startup, so that misconfiguration is detected early rather than failing silently on the first inference request.
|
||||||
|
|
||||||
|
#### Acceptance Criteria
|
||||||
|
|
||||||
|
1. WHEN the extractor worker starts and the resolved or default config specifies `model_provider = "vllm"`, THE worker SHALL send a GET request to `{vllm_base_url}/v1/models` to verify the vLLM server is reachable.
|
||||||
|
2. IF the vLLM health check fails at startup, THEN THE worker SHALL log a WARNING and fall back to the Ollama_Backend, continuing operation with degraded capability.
|
||||||
|
3. IF the vLLM health check succeeds, THEN THE worker SHALL log an INFO message confirming the vLLM connection including the server URL and available model name.
|
||||||
|
4. THE health check SHALL use a timeout of 10 seconds to avoid blocking worker startup on an unresponsive server.
|
||||||
|
|
||||||
|
### Requirement 8: Context Window and Token Handling for vLLM
|
||||||
|
|
||||||
|
**User Story:** As a developer, I want the vLLM client to handle context window and token limits appropriately for the vLLM API, so that large documents are processed correctly on the remote GPU.
|
||||||
|
|
||||||
|
#### Acceptance Criteria
|
||||||
|
|
||||||
|
1. WHEN the resolved agent config specifies a non-zero `context_window`, THE VLLM_Backend SHALL omit the `num_ctx` Ollama-specific option and instead rely on the vLLM server's model configuration for context window sizing.
|
||||||
|
2. THE VLLM_Backend SHALL pass `max_tokens` in the OpenAI-compatible request payload to control the maximum number of output tokens generated.
|
||||||
|
3. WHEN the resolved agent config specifies a non-zero `input_token_limit`, THE Extraction_Pipeline SHALL truncate the input text before sending it to the VLLM_Backend, using the same truncation logic as for the Ollama_Backend.
|
||||||
|
4. WHEN the resolved agent config specifies a non-zero `token_budget`, THE worker SHALL enforce the same hourly token budget check for vLLM invocations as for Ollama invocations.
|
||||||
|
|
||||||
|
### Requirement 9: Backward Compatibility
|
||||||
|
|
||||||
|
**User Story:** As a developer, I want the vLLM integration to be fully backward compatible, so that existing Ollama-based deployments continue to work without any configuration changes.
|
||||||
|
|
||||||
|
#### Acceptance Criteria
|
||||||
|
|
||||||
|
1. WHEN no `VLLM_BASE_URL` environment variable is set and no agent config specifies `model_provider = "vllm"`, THE system SHALL behave identically to the current Ollama-only implementation.
|
||||||
|
2. THE existing `OllamaConfig` dataclass and its environment variable loading SHALL remain unchanged.
|
||||||
|
3. THE existing `OllamaClient` class SHALL continue to function for Ollama-specific usage, with the LLM_Client interface added as a compatible layer on top.
|
||||||
|
4. THE existing test suite in `tests/test_ollama_client.py` SHALL continue to pass without modification.
|
||||||
|
5. WHEN the `model_provider` column in `ai_agents` or `agent_variants` contains `"ollama"` or NULL, THE system SHALL use the Ollama_Backend, preserving current behavior.
|
||||||
|
6. THE database migration for this feature SHALL NOT alter existing table structures; it SHALL only add new columns or tables if needed.
|
||||||
@@ -0,0 +1,82 @@
|
|||||||
|
# Tasks
|
||||||
|
|
||||||
|
## Task 1: LLM Client Protocol and VLLMConfig
|
||||||
|
|
||||||
|
- [x] 1.1 Create `services/shared/llm_protocol.py` with `LLMClient` Protocol defining `call_llm(prompts, json_schema, document_text) -> ExtractionAttempt` and `close()` methods
|
||||||
|
- [x] 1.2 Add `VLLMConfig` dataclass to `services/shared/config.py` with fields: `base_url`, `model`, `timeout`, `max_retries`, `retry_base_delay`, `retry_max_delay`, `retry_backoff_multiplier`, `max_tokens`, `temperature`, `api_key`
|
||||||
|
- [x] 1.3 Add `vllm: VLLMConfig` field to `AppConfig` dataclass
|
||||||
|
- [x] 1.4 Add `VLLM_*` environment variable loading to `load_config()` function
|
||||||
|
- [x] 1.5 Add public `call_llm()` method to `OllamaClient` in `services/extractor/client.py` that delegates to `_call_ollama()`
|
||||||
|
|
||||||
|
## Task 2: VLLMClient Implementation
|
||||||
|
|
||||||
|
- [x] 2.1 Create `services/extractor/vllm_client.py` with `VLLMClient` class that satisfies the `LLMClient` protocol
|
||||||
|
- [x] 2.2 Implement `call_llm()` method that sends POST to `/v1/chat/completions` with OpenAI-compatible payload (`model`, `messages`, `max_tokens`, `temperature`, `response_format`)
|
||||||
|
- [x] 2.3 Implement response parsing: extract content from `choices[0].message.content`, apply `_strip_markdown_fences()` and `_repair_json()`
|
||||||
|
- [x] 2.4 Implement error handling: map timeout → `timeout`, HTTP errors → `http_{code}`, connection errors → `connection_error: {details}`, empty response → `empty_model_response`
|
||||||
|
- [x] 2.5 Implement `close()` method to release the underlying `httpx.AsyncClient`
|
||||||
|
- [x] 2.6 Implement `check_vllm_health(base_url, timeout=10.0)` async function that GETs `/v1/models` and returns bool
|
||||||
|
|
||||||
|
## Task 3: LLM Client Factory
|
||||||
|
|
||||||
|
- [x] 3.1 Create `services/extractor/llm_factory.py` with `build_llm_client()` function that returns `OllamaClient` or `VLLMClient` based on resolved `model_provider`
|
||||||
|
- [x] 3.2 Implement `build_config_from_resolved()` function that creates provider-specific config from `ResolvedAgentConfig` and base configs
|
||||||
|
- [x] 3.3 Handle unknown provider values: log warning and fall back to `OllamaClient`
|
||||||
|
|
||||||
|
## Task 4: Update Extractor Worker for Provider Abstraction
|
||||||
|
|
||||||
|
- [x] 4.1 Update `services/extractor/main.py` to import and use `build_llm_client()` from the factory instead of directly constructing `OllamaClient`
|
||||||
|
- [x] 4.2 Replace `_build_ollama_config_from_resolved()` usage with the factory's `build_config_from_resolved()` for both extractor and classifier clients
|
||||||
|
- [x] 4.3 Add vLLM health check call at startup when resolved config specifies `model_provider = "vllm"`, with fallback to Ollama on failure
|
||||||
|
- [x] 4.4 Update config refresh logic (every 100 jobs) to detect provider changes, close old client, and construct new client via factory
|
||||||
|
- [x] 4.5 Add INFO-level logging for provider switches including old/new provider, model name, and variant ID
|
||||||
|
|
||||||
|
## Task 5: Update Event Classifier for Provider Abstraction
|
||||||
|
|
||||||
|
- [x] 5.1 Update `classify_global_event()` in `services/extractor/event_classifier.py` to accept `LLMClient` protocol type instead of `Any` for the client parameter
|
||||||
|
- [x] 5.2 Replace `ollama_client._call_ollama()` calls with `client.call_llm()` calls
|
||||||
|
- [x] 5.3 Update `ModelMetadata.provider` assignment to use the actual provider string from the client (detect from config type or pass explicitly)
|
||||||
|
- [x] 5.4 Update retry logic to use client config attributes instead of accessing `ollama_client._base_delay` and `ollama_client._backoff_multiplier` directly
|
||||||
|
|
||||||
|
## Task 6: Helm Configuration
|
||||||
|
|
||||||
|
- [x] 6.1 Add `VLLM_BASE_URL`, `VLLM_MODEL`, `VLLM_TIMEOUT`, `VLLM_MAX_RETRIES`, `VLLM_TEMPERATURE`, and `VLLM_API_KEY` entries to the `config:` section in `infra/helm/stonks-oracle/values.yaml`
|
||||||
|
|
||||||
|
## Task 7: Unit Tests for VLLMClient
|
||||||
|
|
||||||
|
- [x] 7.1 Create `tests/test_vllm_client.py` with test for VLLMClient sending correct payload to `/v1/chat/completions` using mock httpx transport
|
||||||
|
- [x] 7.2 Add test for VLLMClient extracting content from `choices[0].message.content`
|
||||||
|
- [x] 7.3 Add test for VLLMClient handling empty choices array returning `empty_model_response` error
|
||||||
|
- [x] 7.4 Add test for VLLMClient handling HTTP timeout returning `timeout` error
|
||||||
|
- [x] 7.5 Add test for VLLMClient handling HTTP 500 returning `http_500` retryable error
|
||||||
|
- [x] 7.6 Add test for VLLMClient handling HTTP 400 returning `http_400` non-retryable error
|
||||||
|
- [x] 7.7 Add test for VLLMClient handling connection error returning `connection_error: ...`
|
||||||
|
- [x] 7.8 Add test for VLLMClient applying markdown fence stripping and JSON repair to response
|
||||||
|
- [x] 7.9 Add test for VLLMClient including temperature and response_format in payload
|
||||||
|
- [x] 7.10 Add test for health check success returning True and logging INFO
|
||||||
|
- [x] 7.11 Add test for health check failure returning False and logging WARNING
|
||||||
|
- [x] 7.12 Add test for OllamaClient.call_llm() delegating to _call_ollama()
|
||||||
|
- [x] 7.13 Add test for VLLMConfig loading from environment variables
|
||||||
|
- [x] 7.14 Add test for AppConfig including vllm field with correct defaults
|
||||||
|
|
||||||
|
## Task 8: Unit Tests for LLM Factory
|
||||||
|
|
||||||
|
- [x] 8.1 Add tests to `tests/test_vllm_client.py` for factory returning OllamaClient when provider is "ollama"
|
||||||
|
- [x] 8.2 Add test for factory returning VLLMClient when provider is "vllm"
|
||||||
|
- [x] 8.3 Add test for factory returning OllamaClient when provider is empty string (default)
|
||||||
|
- [x] 8.4 Add test for factory returning OllamaClient with warning when provider is unknown value
|
||||||
|
|
||||||
|
## Task 9: Property-Based Tests
|
||||||
|
|
||||||
|
- [x] 9.1 Create `tests/test_pbt_llm_provider.py` with property test for factory routing: for all model_provider in {"ollama", "vllm", "", None}, factory returns correct client type [PBT]
|
||||||
|
- [x] 9.2 Add property test for error string format consistency: for all HTTP status codes (100-599), `_is_retryable()` classifies them consistently [PBT]
|
||||||
|
- [x] 9.3 Add property test for VLLMClient request payload structure: for all generated prompt dicts, payload contains required OpenAI fields and excludes Ollama-specific fields [PBT]
|
||||||
|
- [x] 9.4 Add property test for JSON repair idempotence: for all valid JSON strings, `_repair_json()` is idempotent [PBT]
|
||||||
|
- [x] 9.5 Add property test for markdown fence stripping: for all strings, wrapping in fences then stripping recovers the original [PBT]
|
||||||
|
- [x] 9.6 Add property test for VLLMConfig defaults: for all default-constructed instances, invariants hold (timeout > 0, max_retries >= 0, 0 <= temperature <= 2, max_tokens > 0) [PBT]
|
||||||
|
|
||||||
|
## Task 10: Verification and Backward Compatibility
|
||||||
|
|
||||||
|
- [x] 10.1 Run existing `tests/test_ollama_client.py` to verify no regressions
|
||||||
|
- [x] 10.2 Run `ruff check services/` to verify no lint errors in modified files
|
||||||
|
- [x] 10.3 Run full test suite `python -m pytest tests/ -x --tb=short -q` to verify all tests pass
|
||||||
@@ -181,6 +181,12 @@ config:
|
|||||||
OLLAMA_RETRY_BASE_DELAY: "1.0"
|
OLLAMA_RETRY_BASE_DELAY: "1.0"
|
||||||
OLLAMA_RETRY_MAX_DELAY: "10.0"
|
OLLAMA_RETRY_MAX_DELAY: "10.0"
|
||||||
OLLAMA_RETRY_BACKOFF_MULTIPLIER: "2.0"
|
OLLAMA_RETRY_BACKOFF_MULTIPLIER: "2.0"
|
||||||
|
VLLM_BASE_URL: "http://192.168.42.254:8000"
|
||||||
|
VLLM_MODEL: "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
|
||||||
|
VLLM_TIMEOUT: "120"
|
||||||
|
VLLM_MAX_RETRIES: "2"
|
||||||
|
VLLM_TEMPERATURE: "0.7"
|
||||||
|
VLLM_API_KEY: ""
|
||||||
TRINO_HOST: "trino.stonks-oracle.svc.cluster.local"
|
TRINO_HOST: "trino.stonks-oracle.svc.cluster.local"
|
||||||
TRINO_PORT: "8080"
|
TRINO_PORT: "8080"
|
||||||
TRINO_CATALOG: "lakehouse"
|
TRINO_CATALOG: "lakehouse"
|
||||||
|
|||||||
@@ -155,6 +155,19 @@ class OllamaClient:
|
|||||||
if self._owns_client:
|
if self._owns_client:
|
||||||
await self._http.aclose()
|
await self._http.aclose()
|
||||||
|
|
||||||
|
async def call_llm(
|
||||||
|
self,
|
||||||
|
prompts: dict[str, str],
|
||||||
|
json_schema: dict[str, object],
|
||||||
|
document_text: str = "",
|
||||||
|
) -> ExtractionAttempt:
|
||||||
|
"""Public LLM client interface — delegates to _call_ollama().
|
||||||
|
|
||||||
|
Satisfies the LLMClient protocol so OllamaClient can be used
|
||||||
|
interchangeably with VLLMClient.
|
||||||
|
"""
|
||||||
|
return await self._call_ollama(prompts, json_schema, document_text)
|
||||||
|
|
||||||
async def extract(
|
async def extract(
|
||||||
self,
|
self,
|
||||||
document_text: str,
|
document_text: str,
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
"""Event classifier module for macro news articles.
|
"""Event classifier module for macro news articles.
|
||||||
|
|
||||||
Classifies global/geopolitical news articles into structured GlobalEvent
|
Classifies global/geopolitical news articles into structured GlobalEvent
|
||||||
objects using Ollama with a dedicated prompt and JSON schema. Reuses the
|
objects using an LLM client (Ollama or vLLM) with a dedicated prompt and
|
||||||
existing OllamaClient for inference and retry logic.
|
JSON schema. Uses the LLMClient protocol for provider-agnostic inference
|
||||||
|
and retry logic.
|
||||||
|
|
||||||
Persists classification prompts, raw outputs, and final events to MinIO
|
Persists classification prompts, raw outputs, and final events to MinIO
|
||||||
and PostgreSQL for audit and downstream interpolation.
|
and PostgreSQL for audit and downstream interpolation.
|
||||||
|
|
||||||
Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
|
Requirements: 1.4, 2.1, 2.2, 2.3, 2.4, 2.5, 6.2
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -24,6 +25,8 @@ import asyncpg
|
|||||||
from minio import Minio
|
from minio import Minio
|
||||||
|
|
||||||
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
||||||
|
from services.shared.config import VLLMConfig
|
||||||
|
from services.shared.llm_protocol import LLMClient
|
||||||
from services.shared.schemas import (
|
from services.shared.schemas import (
|
||||||
EstimatedDuration,
|
EstimatedDuration,
|
||||||
ImpactType,
|
ImpactType,
|
||||||
@@ -281,6 +284,7 @@ def _parse_classification_response(
|
|||||||
raw_json: str,
|
raw_json: str,
|
||||||
document_id: str,
|
document_id: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
|
provider: str = "ollama",
|
||||||
) -> GlobalEvent:
|
) -> GlobalEvent:
|
||||||
"""Parse raw Ollama JSON output into a GlobalEvent.
|
"""Parse raw Ollama JSON output into a GlobalEvent.
|
||||||
|
|
||||||
@@ -345,7 +349,7 @@ def _parse_classification_response(
|
|||||||
confidence=confidence,
|
confidence=confidence,
|
||||||
source_document_id=document_id,
|
source_document_id=document_id,
|
||||||
model_metadata=ModelMetadata(
|
model_metadata=ModelMetadata(
|
||||||
provider="ollama",
|
provider=provider,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
prompt_version=PROMPT_VERSION,
|
prompt_version=PROMPT_VERSION,
|
||||||
schema_version=SCHEMA_VERSION,
|
schema_version=SCHEMA_VERSION,
|
||||||
@@ -479,21 +483,21 @@ async def persist_global_event(
|
|||||||
async def classify_global_event(
|
async def classify_global_event(
|
||||||
normalized_text: str,
|
normalized_text: str,
|
||||||
document_id: str,
|
document_id: str,
|
||||||
ollama_client: Any,
|
client: LLMClient,
|
||||||
*,
|
*,
|
||||||
pool: asyncpg.Pool | None = None,
|
pool: asyncpg.Pool | None = None,
|
||||||
minio_client: Minio | None = None,
|
minio_client: Minio | None = None,
|
||||||
) -> GlobalEvent:
|
) -> GlobalEvent:
|
||||||
"""Classify a macro news article into a GlobalEvent using Ollama.
|
"""Classify a macro news article into a GlobalEvent using an LLM.
|
||||||
|
|
||||||
Uses the existing OllamaClient's streaming infrastructure with a
|
Uses the LLMClient protocol's call_llm() method with a dedicated
|
||||||
dedicated event classification prompt and JSON schema. Follows the
|
event classification prompt and JSON schema. Follows the same retry
|
||||||
same retry policy as document extraction.
|
policy as document extraction.
|
||||||
|
|
||||||
Resolves runtime config for the "event-classifier" agent slug from
|
Resolves runtime config for the "event-classifier" agent slug from
|
||||||
the database, preferring an active variant's model_name and
|
the database, preferring an active variant's model_name and
|
||||||
system_prompt if one exists. Falls back to the OllamaClient's
|
system_prompt if one exists. Falls back to the client's existing
|
||||||
existing config if resolution fails.
|
config if resolution fails.
|
||||||
|
|
||||||
Persists prompt, raw output, and final event to MinIO and PostgreSQL
|
Persists prompt, raw output, and final event to MinIO and PostgreSQL
|
||||||
when the respective clients are provided.
|
when the respective clients are provided.
|
||||||
@@ -501,7 +505,7 @@ async def classify_global_event(
|
|||||||
Args:
|
Args:
|
||||||
normalized_text: Cleaned text content of the macro article.
|
normalized_text: Cleaned text content of the macro article.
|
||||||
document_id: UUID of the source document.
|
document_id: UUID of the source document.
|
||||||
ollama_client: An OllamaClient instance (from services.extractor.client).
|
client: An LLMClient instance (OllamaClient or VLLMClient).
|
||||||
pool: Optional asyncpg pool for PostgreSQL persistence.
|
pool: Optional asyncpg pool for PostgreSQL persistence.
|
||||||
minio_client: Optional MinIO client for artifact persistence.
|
minio_client: Optional MinIO client for artifact persistence.
|
||||||
|
|
||||||
@@ -528,7 +532,10 @@ async def classify_global_event(
|
|||||||
|
|
||||||
prompts = build_event_classification_prompt(normalized_text)
|
prompts = build_event_classification_prompt(normalized_text)
|
||||||
json_schema = get_event_json_schema()
|
json_schema = get_event_json_schema()
|
||||||
model_name = ollama_client._config.model
|
model_name = client._config.model
|
||||||
|
|
||||||
|
# Detect provider from client config type
|
||||||
|
provider = "vllm" if isinstance(client._config, VLLMConfig) else "ollama"
|
||||||
|
|
||||||
# Override model_name and system_prompt from resolved config
|
# Override model_name and system_prompt from resolved config
|
||||||
if resolved is not None:
|
if resolved is not None:
|
||||||
@@ -562,16 +569,16 @@ async def classify_global_event(
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to upload classification prompt for doc %s", document_id)
|
logger.exception("Failed to upload classification prompt for doc %s", document_id)
|
||||||
|
|
||||||
# Call Ollama using the client's internal _call_ollama method
|
# Call LLM using the client's call_llm method
|
||||||
# We reuse the retry logic pattern from OllamaClient.extract()
|
# We reuse the retry logic pattern from OllamaClient.extract()
|
||||||
max_retries = ollama_client._max_retries
|
max_retries = client._config.max_retries
|
||||||
if resolved is not None:
|
if resolved is not None:
|
||||||
max_retries = resolved.max_retries
|
max_retries = resolved.max_retries
|
||||||
last_error: str | None = None
|
last_error: str | None = None
|
||||||
raw_output = ""
|
raw_output = ""
|
||||||
|
|
||||||
for attempt_num in range(max_retries + 1):
|
for attempt_num in range(max_retries + 1):
|
||||||
attempt = await ollama_client._call_ollama(prompts, json_schema)
|
attempt = await client.call_llm(prompts, json_schema)
|
||||||
raw_output = attempt.raw_output
|
raw_output = attempt.raw_output
|
||||||
|
|
||||||
# _call_ollama validates against the *extraction* schema, which
|
# _call_ollama validates against the *extraction* schema, which
|
||||||
@@ -581,7 +588,7 @@ async def classify_global_event(
|
|||||||
# Try to parse the response
|
# Try to parse the response
|
||||||
try:
|
try:
|
||||||
event = _parse_classification_response(
|
event = _parse_classification_response(
|
||||||
raw_output, document_id, model_name,
|
raw_output, document_id, model_name, provider=provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Persist result to MinIO
|
# Persist result to MinIO
|
||||||
@@ -648,10 +655,10 @@ async def classify_global_event(
|
|||||||
|
|
||||||
# Retry with backoff
|
# Retry with backoff
|
||||||
if attempt_num < max_retries:
|
if attempt_num < max_retries:
|
||||||
delay = ollama_client._base_delay * (
|
delay = client._config.retry_base_delay * (
|
||||||
ollama_client._backoff_multiplier ** attempt_num
|
client._config.retry_backoff_multiplier ** attempt_num
|
||||||
)
|
)
|
||||||
delay = min(delay, ollama_client._max_delay)
|
delay = min(delay, client._config.retry_max_delay)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
|
"Classification attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
|
||||||
attempt_num + 1, max_retries + 1, document_id, last_error, delay,
|
attempt_num + 1, max_retries + 1, document_id, last_error, delay,
|
||||||
|
|||||||
@@ -0,0 +1,135 @@
|
|||||||
|
"""LLM client factory for provider-based routing.
|
||||||
|
|
||||||
|
Returns the appropriate LLM client (OllamaClient or VLLMClient) based on
|
||||||
|
the resolved ``model_provider`` from the agent config. Falls back to
|
||||||
|
OllamaClient for unknown or missing providers.
|
||||||
|
|
||||||
|
Requirements: 3.4, 3.5, 3.6, 9.5
|
||||||
|
Design: LLM Client Factory
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from services.extractor.client import OllamaClient
|
||||||
|
from services.extractor.vllm_client import VLLMClient
|
||||||
|
from services.shared.agent_config import ResolvedAgentConfig
|
||||||
|
from services.shared.config import OllamaConfig, VLLMConfig
|
||||||
|
from services.shared.llm_protocol import LLMClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Providers that map to OllamaClient (including empty / None).
|
||||||
|
_OLLAMA_PROVIDERS = frozenset({"ollama", "", None})
|
||||||
|
|
||||||
|
|
||||||
|
def build_config_from_resolved(
|
||||||
|
resolved: ResolvedAgentConfig,
|
||||||
|
base_ollama: OllamaConfig,
|
||||||
|
base_vllm: VLLMConfig,
|
||||||
|
) -> OllamaConfig | VLLMConfig:
|
||||||
|
"""Build a provider-specific config from a resolved agent config.
|
||||||
|
|
||||||
|
Merges the resolved agent-level overrides (model_name, timeout, retries,
|
||||||
|
max_tokens, context_window) with the base environment config (base_url,
|
||||||
|
retry delays, provider-specific defaults).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resolved: Runtime config resolved from the database.
|
||||||
|
base_ollama: Base OllamaConfig loaded from environment variables.
|
||||||
|
base_vllm: Base VLLMConfig loaded from environment variables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An ``OllamaConfig`` or ``VLLMConfig`` depending on the provider.
|
||||||
|
"""
|
||||||
|
provider = (resolved.model_provider or "").strip().lower()
|
||||||
|
|
||||||
|
if provider == "vllm":
|
||||||
|
return VLLMConfig(
|
||||||
|
base_url=base_vllm.base_url,
|
||||||
|
model=resolved.model_name,
|
||||||
|
timeout=resolved.timeout_seconds,
|
||||||
|
max_retries=resolved.max_retries,
|
||||||
|
retry_base_delay=base_vllm.retry_base_delay,
|
||||||
|
retry_max_delay=base_vllm.retry_max_delay,
|
||||||
|
retry_backoff_multiplier=base_vllm.retry_backoff_multiplier,
|
||||||
|
max_tokens=resolved.max_tokens,
|
||||||
|
temperature=base_vllm.temperature,
|
||||||
|
api_key=base_vllm.api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Default: Ollama config (covers "ollama", "", None, and unknown)
|
||||||
|
if provider not in _OLLAMA_PROVIDERS:
|
||||||
|
logger.warning(
|
||||||
|
"Unknown model_provider %r for agent %s — treating as ollama",
|
||||||
|
resolved.model_provider,
|
||||||
|
resolved.agent_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return OllamaConfig(
|
||||||
|
base_url=base_ollama.base_url,
|
||||||
|
model=resolved.model_name,
|
||||||
|
timeout=resolved.timeout_seconds,
|
||||||
|
max_retries=resolved.max_retries,
|
||||||
|
retry_base_delay=base_ollama.retry_base_delay,
|
||||||
|
retry_max_delay=base_ollama.retry_max_delay,
|
||||||
|
retry_backoff_multiplier=base_ollama.retry_backoff_multiplier,
|
||||||
|
max_tokens=resolved.max_tokens,
|
||||||
|
stall_timeout=base_ollama.stall_timeout,
|
||||||
|
loop_window=base_ollama.loop_window,
|
||||||
|
loop_threshold=base_ollama.loop_threshold,
|
||||||
|
context_window=resolved.context_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_llm_client(
|
||||||
|
resolved: ResolvedAgentConfig | None,
|
||||||
|
ollama_config: OllamaConfig,
|
||||||
|
vllm_config: VLLMConfig,
|
||||||
|
http_client: httpx.AsyncClient | None = None,
|
||||||
|
) -> LLMClient:
|
||||||
|
"""Return the appropriate LLM client based on the resolved provider.
|
||||||
|
|
||||||
|
Provider routing:
|
||||||
|
- ``None`` / ``""`` / ``"ollama"`` → :class:`OllamaClient`
|
||||||
|
- ``"vllm"`` → :class:`VLLMClient`
|
||||||
|
- Unknown value → log warning, fall back to :class:`OllamaClient`
|
||||||
|
|
||||||
|
When *resolved* is ``None`` (DB lookup failed), the base
|
||||||
|
``ollama_config`` is used directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resolved: Resolved agent config (may be ``None``).
|
||||||
|
ollama_config: Base OllamaConfig from environment.
|
||||||
|
vllm_config: Base VLLMConfig from environment.
|
||||||
|
http_client: Optional shared httpx client for testing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An LLM client satisfying the :class:`LLMClient` protocol.
|
||||||
|
"""
|
||||||
|
if resolved is None:
|
||||||
|
logger.info("No resolved agent config — defaulting to OllamaClient")
|
||||||
|
return OllamaClient(ollama_config, http_client=http_client)
|
||||||
|
|
||||||
|
provider = (resolved.model_provider or "").strip().lower()
|
||||||
|
|
||||||
|
if provider == "vllm":
|
||||||
|
cfg = build_config_from_resolved(resolved, ollama_config, vllm_config)
|
||||||
|
logger.info(
|
||||||
|
"Building VLLMClient for agent %s (model=%s)",
|
||||||
|
resolved.agent_id,
|
||||||
|
cfg.model, # type: ignore[union-attr]
|
||||||
|
)
|
||||||
|
return VLLMClient(cfg, http_client=http_client) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
if provider not in _OLLAMA_PROVIDERS:
|
||||||
|
logger.warning(
|
||||||
|
"Unknown model_provider %r for agent %s — falling back to OllamaClient",
|
||||||
|
resolved.model_provider,
|
||||||
|
resolved.agent_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = build_config_from_resolved(resolved, ollama_config, vllm_config)
|
||||||
|
return OllamaClient(cfg, http_client=http_client) # type: ignore[arg-type]
|
||||||
+106
-46
@@ -15,11 +15,13 @@ from services.aggregation.interpolation import (
|
|||||||
filter_low_confidence_events,
|
filter_low_confidence_events,
|
||||||
persist_macro_impact_records,
|
persist_macro_impact_records,
|
||||||
)
|
)
|
||||||
from services.extractor.client import OllamaClient
|
|
||||||
from services.extractor.event_classifier import classify_global_event
|
from services.extractor.event_classifier import classify_global_event
|
||||||
|
from services.extractor.llm_factory import build_config_from_resolved, build_llm_client
|
||||||
|
from services.extractor.vllm_client import check_vllm_health
|
||||||
from services.extractor.worker import persist_extraction
|
from services.extractor.worker import persist_extraction
|
||||||
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
||||||
from services.shared.config import OllamaConfig, load_config
|
from services.shared.config import OllamaConfig, load_config
|
||||||
|
from services.shared.llm_protocol import LLMClient
|
||||||
from services.shared.logging import inject_trace_context, setup_logging
|
from services.shared.logging import inject_trace_context, setup_logging
|
||||||
from services.shared.redis_keys import (
|
from services.shared.redis_keys import (
|
||||||
QUEUE_AGGREGATION,
|
QUEUE_AGGREGATION,
|
||||||
@@ -31,11 +33,22 @@ from services.shared.redis_keys import (
|
|||||||
logger = logging.getLogger("extractor_main")
|
logger = logging.getLogger("extractor_main")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_provider(resolved: ResolvedAgentConfig | None) -> str:
|
||||||
|
"""Return the normalised provider string for a resolved config."""
|
||||||
|
if resolved is None:
|
||||||
|
return "ollama"
|
||||||
|
return (resolved.model_provider or "").strip().lower() or "ollama"
|
||||||
|
|
||||||
|
|
||||||
def _build_ollama_config_from_resolved(
|
def _build_ollama_config_from_resolved(
|
||||||
resolved: ResolvedAgentConfig,
|
resolved: ResolvedAgentConfig,
|
||||||
base_config: OllamaConfig,
|
base_config: OllamaConfig,
|
||||||
) -> OllamaConfig:
|
) -> OllamaConfig:
|
||||||
"""Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings."""
|
"""Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings.
|
||||||
|
|
||||||
|
Kept for backward compatibility — the factory's ``build_config_from_resolved``
|
||||||
|
is now the primary path.
|
||||||
|
"""
|
||||||
return OllamaConfig(
|
return OllamaConfig(
|
||||||
base_url=base_config.base_url,
|
base_url=base_config.base_url,
|
||||||
model=resolved.model_name,
|
model=resolved.model_name,
|
||||||
@@ -239,7 +252,7 @@ async def _process_macro_classification(
|
|||||||
*,
|
*,
|
||||||
pool: asyncpg.Pool,
|
pool: asyncpg.Pool,
|
||||||
minio_client: Minio,
|
minio_client: Minio,
|
||||||
ollama: OllamaClient,
|
ollama: LLMClient,
|
||||||
redis_client: aioredis.Redis,
|
redis_client: aioredis.Redis,
|
||||||
document_id: str,
|
document_id: str,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -258,7 +271,7 @@ async def _process_macro_classification(
|
|||||||
event = await classify_global_event(
|
event = await classify_global_event(
|
||||||
normalized_text=text,
|
normalized_text=text,
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
ollama_client=ollama,
|
client=ollama,
|
||||||
pool=pool,
|
pool=pool,
|
||||||
minio_client=minio_client,
|
minio_client=minio_client,
|
||||||
)
|
)
|
||||||
@@ -329,48 +342,69 @@ async def main() -> None:
|
|||||||
# Resolve extractor config from DB (active variant override + TTL cache)
|
# Resolve extractor config from DB (active variant override + TTL cache)
|
||||||
resolver = AgentConfigResolver(pool, ttl_seconds=60)
|
resolver = AgentConfigResolver(pool, ttl_seconds=60)
|
||||||
resolved_config: ResolvedAgentConfig | None = None
|
resolved_config: ResolvedAgentConfig | None = None
|
||||||
extractor_ollama_config = config.ollama
|
extractor_provider = "ollama"
|
||||||
try:
|
try:
|
||||||
resolved_config = await resolver.resolve("document-extractor")
|
resolved_config = await resolver.resolve("document-extractor")
|
||||||
if resolved_config is not None:
|
if resolved_config is not None:
|
||||||
extractor_ollama_config = _build_ollama_config_from_resolved(
|
extractor_provider = _get_provider(resolved_config)
|
||||||
resolved_config, config.ollama,
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Extractor using resolved config: model=%s variant=%s",
|
"Extractor using resolved config: model=%s variant=%s provider=%s",
|
||||||
resolved_config.model_name, resolved_config.variant_id,
|
resolved_config.model_name, resolved_config.variant_id, extractor_provider,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("No DB config for document-extractor — using env defaults")
|
logger.info("No DB config for document-extractor — using env defaults")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True)
|
logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True)
|
||||||
|
|
||||||
ollama = OllamaClient(extractor_ollama_config)
|
# vLLM health check at startup when provider is vllm (Requirement 7.1–7.3)
|
||||||
|
if extractor_provider == "vllm":
|
||||||
|
healthy = await check_vllm_health(config.vllm.base_url)
|
||||||
|
if not healthy:
|
||||||
|
logger.warning(
|
||||||
|
"vLLM health check failed at startup — falling back to Ollama for extractor",
|
||||||
|
)
|
||||||
|
extractor_provider = "ollama"
|
||||||
|
# Override resolved config provider so factory builds OllamaClient
|
||||||
|
resolved_config = None
|
||||||
|
|
||||||
|
extractor_client: LLMClient = build_llm_client(
|
||||||
|
resolved_config, config.ollama, config.vllm,
|
||||||
|
)
|
||||||
|
|
||||||
# Resolve event classifier config separately (may use different model)
|
# Resolve event classifier config separately (may use different model)
|
||||||
classifier_resolved: ResolvedAgentConfig | None = None
|
classifier_resolved: ResolvedAgentConfig | None = None
|
||||||
classifier_ollama_config = config.ollama
|
classifier_provider = "ollama"
|
||||||
try:
|
try:
|
||||||
classifier_resolved = await resolver.resolve("event-classifier")
|
classifier_resolved = await resolver.resolve("event-classifier")
|
||||||
if classifier_resolved is not None:
|
if classifier_resolved is not None:
|
||||||
classifier_ollama_config = _build_ollama_config_from_resolved(
|
classifier_provider = _get_provider(classifier_resolved)
|
||||||
classifier_resolved, config.ollama,
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Event classifier using resolved config: model=%s variant=%s",
|
"Event classifier using resolved config: model=%s variant=%s provider=%s",
|
||||||
classifier_resolved.model_name, classifier_resolved.variant_id,
|
classifier_resolved.model_name, classifier_resolved.variant_id, classifier_provider,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("No DB config for event-classifier — using extractor config")
|
logger.info("No DB config for event-classifier — using extractor config")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True)
|
logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True)
|
||||||
|
|
||||||
# Use a separate OllamaClient for the classifier if it has a different model
|
# vLLM health check for classifier if it uses vllm and extractor didn't already check
|
||||||
classifier_ollama: OllamaClient
|
if classifier_provider == "vllm" and extractor_provider != "vllm":
|
||||||
if classifier_ollama_config.model != extractor_ollama_config.model:
|
healthy = await check_vllm_health(config.vllm.base_url)
|
||||||
classifier_ollama = OllamaClient(classifier_ollama_config)
|
if not healthy:
|
||||||
|
logger.warning(
|
||||||
|
"vLLM health check failed at startup — falling back to Ollama for classifier",
|
||||||
|
)
|
||||||
|
classifier_provider = "ollama"
|
||||||
|
classifier_resolved = None
|
||||||
|
|
||||||
|
# Build classifier client — share with extractor when configs match
|
||||||
|
classifier_client: LLMClient
|
||||||
|
if classifier_resolved is not None or classifier_provider != extractor_provider:
|
||||||
|
classifier_client = build_llm_client(
|
||||||
|
classifier_resolved, config.ollama, config.vllm,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
classifier_ollama = ollama
|
classifier_client = extractor_client
|
||||||
|
|
||||||
redis_client = aioredis.from_url(config.redis.url)
|
redis_client = aioredis.from_url(config.redis.url)
|
||||||
queue = queue_key(QUEUE_EXTRACTION)
|
queue = queue_key(QUEUE_EXTRACTION)
|
||||||
@@ -441,40 +475,66 @@ async def main() -> None:
|
|||||||
company_id_map = await _build_company_id_map(pool)
|
company_id_map = await _build_company_id_map(pool)
|
||||||
# Re-resolve extractor config (picks up active variant swaps)
|
# Re-resolve extractor config (picks up active variant swaps)
|
||||||
try:
|
try:
|
||||||
resolved_config = await resolver.resolve("document-extractor")
|
new_resolved = await resolver.resolve("document-extractor")
|
||||||
if resolved_config is not None:
|
if new_resolved is not None:
|
||||||
new_ollama_cfg = _build_ollama_config_from_resolved(
|
new_provider = _get_provider(new_resolved)
|
||||||
resolved_config, config.ollama,
|
new_cfg = build_config_from_resolved(
|
||||||
|
new_resolved, config.ollama, config.vllm,
|
||||||
)
|
)
|
||||||
if new_ollama_cfg.model != ollama._config.model:
|
old_provider = extractor_provider
|
||||||
|
provider_changed = new_provider != extractor_provider
|
||||||
|
model_changed = new_cfg.model != extractor_client._config.model
|
||||||
|
|
||||||
|
if provider_changed or model_changed:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Extractor config changed: model=%s variant=%s",
|
"Extractor provider switch: old_provider=%s new_provider=%s "
|
||||||
resolved_config.model_name, resolved_config.variant_id,
|
"model=%s variant=%s",
|
||||||
|
old_provider, new_provider,
|
||||||
|
new_resolved.model_name, new_resolved.variant_id,
|
||||||
)
|
)
|
||||||
await ollama.close()
|
await extractor_client.close()
|
||||||
ollama = OllamaClient(new_ollama_cfg)
|
extractor_client = build_llm_client(
|
||||||
|
new_resolved, config.ollama, config.vllm,
|
||||||
|
)
|
||||||
|
extractor_provider = new_provider
|
||||||
else:
|
else:
|
||||||
ollama._config = new_ollama_cfg
|
# Same provider and model — just update config in-place
|
||||||
|
extractor_client._config = new_cfg # type: ignore[assignment]
|
||||||
|
resolved_config = new_resolved
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to refresh extractor config", exc_info=True)
|
logger.warning("Failed to refresh extractor config", exc_info=True)
|
||||||
|
|
||||||
# Re-resolve event classifier config
|
# Re-resolve event classifier config
|
||||||
try:
|
try:
|
||||||
classifier_resolved = await resolver.resolve("event-classifier")
|
new_cls_resolved = await resolver.resolve("event-classifier")
|
||||||
if classifier_resolved is not None:
|
if new_cls_resolved is not None:
|
||||||
new_cls_cfg = _build_ollama_config_from_resolved(
|
new_cls_provider = _get_provider(new_cls_resolved)
|
||||||
classifier_resolved, config.ollama,
|
new_cls_cfg = build_config_from_resolved(
|
||||||
|
new_cls_resolved, config.ollama, config.vllm,
|
||||||
)
|
)
|
||||||
if new_cls_cfg.model != classifier_ollama._config.model:
|
old_cls_provider = classifier_provider
|
||||||
|
cls_provider_changed = new_cls_provider != classifier_provider
|
||||||
|
cls_model_changed = new_cls_cfg.model != classifier_client._config.model
|
||||||
|
|
||||||
|
if cls_provider_changed or cls_model_changed:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Event classifier config changed: model=%s variant=%s",
|
"Classifier provider switch: old_provider=%s new_provider=%s "
|
||||||
classifier_resolved.model_name, classifier_resolved.variant_id,
|
"model=%s variant=%s",
|
||||||
|
old_cls_provider, new_cls_provider,
|
||||||
|
new_cls_resolved.model_name, new_cls_resolved.variant_id,
|
||||||
)
|
)
|
||||||
if classifier_ollama is not ollama:
|
if classifier_client is not extractor_client:
|
||||||
await classifier_ollama.close()
|
await classifier_client.close()
|
||||||
classifier_ollama = OllamaClient(new_cls_cfg)
|
classifier_client = build_llm_client(
|
||||||
elif classifier_ollama is ollama and new_cls_cfg.model != ollama._config.model:
|
new_cls_resolved, config.ollama, config.vllm,
|
||||||
classifier_ollama = OllamaClient(new_cls_cfg)
|
)
|
||||||
|
classifier_provider = new_cls_provider
|
||||||
|
elif classifier_client is extractor_client and new_cls_cfg.model != extractor_client._config.model:
|
||||||
|
classifier_client = build_llm_client(
|
||||||
|
new_cls_resolved, config.ollama, config.vllm,
|
||||||
|
)
|
||||||
|
classifier_provider = new_cls_provider
|
||||||
|
classifier_resolved = new_cls_resolved
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to refresh event-classifier config", exc_info=True)
|
logger.warning("Failed to refresh event-classifier config", exc_info=True)
|
||||||
|
|
||||||
@@ -490,7 +550,7 @@ async def main() -> None:
|
|||||||
await _process_macro_classification(
|
await _process_macro_classification(
|
||||||
pool=pool,
|
pool=pool,
|
||||||
minio_client=minio_client,
|
minio_client=minio_client,
|
||||||
ollama=classifier_ollama,
|
ollama=classifier_client,
|
||||||
redis_client=redis_client,
|
redis_client=redis_client,
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
text=text,
|
text=text,
|
||||||
@@ -529,7 +589,7 @@ async def main() -> None:
|
|||||||
|
|
||||||
# Pass all tracked tickers so the model can identify any mentioned companies
|
# Pass all tracked tickers so the model can identify any mentioned companies
|
||||||
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
|
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
|
||||||
extraction_response = await ollama.extract(
|
extraction_response = await extractor_client.extract(
|
||||||
extraction_text,
|
extraction_text,
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
known_tickers=all_tickers,
|
known_tickers=all_tickers,
|
||||||
|
|||||||
@@ -0,0 +1,177 @@
|
|||||||
|
"""vLLM client for OpenAI-compatible chat completions.
|
||||||
|
|
||||||
|
Sends structured extraction requests to a remote vLLM server via the
|
||||||
|
``/v1/chat/completions`` endpoint. Reuses the same markdown-fence
|
||||||
|
stripping, JSON repair, and error-string conventions as OllamaClient
|
||||||
|
so that ``_is_retryable()`` works without modification.
|
||||||
|
|
||||||
|
Requirements: 2.1–2.10, 7.1–7.4
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from services.extractor.client import (
|
||||||
|
ExtractionAttempt,
|
||||||
|
_is_retryable,
|
||||||
|
_repair_json,
|
||||||
|
_strip_markdown_fences,
|
||||||
|
)
|
||||||
|
from services.extractor.schemas import validate_extraction
|
||||||
|
from services.shared.config import VLLMConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger("vllm_client")
|
||||||
|
|
||||||
|
|
||||||
|
class VLLMClient:
|
||||||
|
"""Async client for vLLM OpenAI-compatible chat completions.
|
||||||
|
|
||||||
|
Satisfies the ``LLMClient`` protocol defined in
|
||||||
|
``services.shared.llm_protocol``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_config: VLLMConfig
|
||||||
|
_http: httpx.AsyncClient
|
||||||
|
_owns_client: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: VLLMConfig,
|
||||||
|
http_client: httpx.AsyncClient | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._config = config
|
||||||
|
self._owns_client = http_client is None
|
||||||
|
self._http = http_client or httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(config.timeout, read=config.timeout),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# LLMClient protocol
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def call_llm(
|
||||||
|
self,
|
||||||
|
prompts: dict[str, str],
|
||||||
|
json_schema: dict[str, object],
|
||||||
|
document_text: str = "",
|
||||||
|
) -> ExtractionAttempt:
|
||||||
|
"""Send a chat completion request to the vLLM server.
|
||||||
|
|
||||||
|
Builds an OpenAI-compatible payload, posts to
|
||||||
|
``/v1/chat/completions``, and parses the response through the
|
||||||
|
same markdown-fence / JSON-repair pipeline used by OllamaClient.
|
||||||
|
"""
|
||||||
|
attempt = ExtractionAttempt(model=self._config.model)
|
||||||
|
start = time.monotonic()
|
||||||
|
|
||||||
|
headers: dict[str, str] = {}
|
||||||
|
if self._config.api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self._config.api_key}"
|
||||||
|
|
||||||
|
payload: dict[str, object] = {
|
||||||
|
"model": self._config.model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": prompts["system"]},
|
||||||
|
{"role": "user", "content": prompts["user"]},
|
||||||
|
],
|
||||||
|
"max_tokens": self._config.max_tokens,
|
||||||
|
"temperature": self._config.temperature,
|
||||||
|
"response_format": {"type": "json_object"},
|
||||||
|
}
|
||||||
|
|
||||||
|
url = f"{self._config.base_url}/v1/chat/completions"
|
||||||
|
logger.info(
|
||||||
|
"vLLM POST %s model=%s input_chars=%d",
|
||||||
|
url,
|
||||||
|
self._config.model,
|
||||||
|
len(prompts.get("user", "")),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = await self._http.post(url, json=payload, headers=headers)
|
||||||
|
resp.raise_for_status()
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
attempt.error = "timeout"
|
||||||
|
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||||
|
return attempt
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
attempt.error = f"http_{exc.response.status_code}"
|
||||||
|
attempt.retryable = _is_retryable(attempt.error)
|
||||||
|
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||||
|
return attempt
|
||||||
|
except httpx.HTTPError as exc:
|
||||||
|
attempt.error = f"connection_error: {exc}"
|
||||||
|
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||||
|
return attempt
|
||||||
|
|
||||||
|
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||||
|
|
||||||
|
# --- Parse the OpenAI-compatible response ---
|
||||||
|
try:
|
||||||
|
data = resp.json()
|
||||||
|
except Exception:
|
||||||
|
attempt.error = "invalid_response_json"
|
||||||
|
attempt.raw_output = resp.text[:2000]
|
||||||
|
return attempt
|
||||||
|
|
||||||
|
choices = data.get("choices") or []
|
||||||
|
if not choices:
|
||||||
|
attempt.error = "empty_model_response"
|
||||||
|
return attempt
|
||||||
|
|
||||||
|
content = (
|
||||||
|
choices[0].get("message", {}).get("content", "")
|
||||||
|
if isinstance(choices[0], dict)
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
attempt.raw_output = content
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
attempt.error = "empty_model_response"
|
||||||
|
return attempt
|
||||||
|
|
||||||
|
# Strip markdown fences if present
|
||||||
|
content = _strip_markdown_fences(content)
|
||||||
|
|
||||||
|
# Repair malformed JSON
|
||||||
|
content = _repair_json(content)
|
||||||
|
|
||||||
|
# Validate against extraction schema
|
||||||
|
attempt.validation = validate_extraction(content, document_text=document_text)
|
||||||
|
if not attempt.validation.valid:
|
||||||
|
attempt.error = "; ".join(attempt.validation.errors)
|
||||||
|
|
||||||
|
return attempt
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Release the underlying ``httpx.AsyncClient`` if we own it."""
|
||||||
|
if self._owns_client:
|
||||||
|
await self._http.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Standalone health check
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def check_vllm_health(base_url: str, timeout: float = 10.0) -> bool:
|
||||||
|
"""Verify the vLLM server is reachable by querying ``/v1/models``.
|
||||||
|
|
||||||
|
Returns ``True`` when the server responds with HTTP 200, ``False``
|
||||||
|
otherwise. Logs INFO on success and WARNING on failure.
|
||||||
|
|
||||||
|
Requirements: 7.1–7.4
|
||||||
|
"""
|
||||||
|
url = f"{base_url}/v1/models"
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=httpx.Timeout(timeout)) as client:
|
||||||
|
resp = await client.get(url)
|
||||||
|
resp.raise_for_status()
|
||||||
|
logger.info("vLLM health check passed: %s", url)
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("vLLM health check failed for %s: %s", url, exc)
|
||||||
|
return False
|
||||||
@@ -54,6 +54,24 @@ class OllamaConfig:
|
|||||||
context_window: int = 0 # Ollama num_ctx; 0 = use model default
|
context_window: int = 0 # Ollama num_ctx; 0 = use model default
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VLLMConfig:
|
||||||
|
"""Configuration for the remote vLLM inference server.
|
||||||
|
|
||||||
|
Requirements: 3.1, 3.2
|
||||||
|
"""
|
||||||
|
base_url: str = "http://192.168.42.254:8000"
|
||||||
|
model: str = "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
|
||||||
|
timeout: int = 120
|
||||||
|
max_retries: int = 2
|
||||||
|
retry_base_delay: float = 1.0
|
||||||
|
retry_max_delay: float = 10.0
|
||||||
|
retry_backoff_multiplier: float = 2.0
|
||||||
|
max_tokens: int = 32768
|
||||||
|
temperature: float = 0.7
|
||||||
|
api_key: str = "" # Optional, for authenticated vLLM deployments
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrinoConfig:
|
class TrinoConfig:
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
@@ -217,6 +235,7 @@ class AppConfig:
|
|||||||
redis: RedisConfig = field(default_factory=RedisConfig)
|
redis: RedisConfig = field(default_factory=RedisConfig)
|
||||||
minio: MinioConfig = field(default_factory=MinioConfig)
|
minio: MinioConfig = field(default_factory=MinioConfig)
|
||||||
ollama: OllamaConfig = field(default_factory=OllamaConfig)
|
ollama: OllamaConfig = field(default_factory=OllamaConfig)
|
||||||
|
vllm: VLLMConfig = field(default_factory=VLLMConfig)
|
||||||
trino: TrinoConfig = field(default_factory=TrinoConfig)
|
trino: TrinoConfig = field(default_factory=TrinoConfig)
|
||||||
market_data: MarketDataConfig = field(default_factory=MarketDataConfig)
|
market_data: MarketDataConfig = field(default_factory=MarketDataConfig)
|
||||||
broker: BrokerConfig = field(default_factory=BrokerConfig)
|
broker: BrokerConfig = field(default_factory=BrokerConfig)
|
||||||
@@ -260,6 +279,18 @@ def load_config() -> AppConfig:
|
|||||||
retry_max_delay=float(os.getenv("OLLAMA_RETRY_MAX_DELAY", "10.0")),
|
retry_max_delay=float(os.getenv("OLLAMA_RETRY_MAX_DELAY", "10.0")),
|
||||||
retry_backoff_multiplier=float(os.getenv("OLLAMA_RETRY_BACKOFF_MULTIPLIER", "2.0")),
|
retry_backoff_multiplier=float(os.getenv("OLLAMA_RETRY_BACKOFF_MULTIPLIER", "2.0")),
|
||||||
),
|
),
|
||||||
|
vllm=VLLMConfig(
|
||||||
|
base_url=os.getenv("VLLM_BASE_URL", "http://192.168.42.254:8000"),
|
||||||
|
model=os.getenv("VLLM_MODEL", "RedHatAI/Qwen3.6-35B-A3B-NVFP4"),
|
||||||
|
timeout=int(os.getenv("VLLM_TIMEOUT", "120")),
|
||||||
|
max_retries=int(os.getenv("VLLM_MAX_RETRIES", "2")),
|
||||||
|
retry_base_delay=float(os.getenv("VLLM_RETRY_BASE_DELAY", "1.0")),
|
||||||
|
retry_max_delay=float(os.getenv("VLLM_RETRY_MAX_DELAY", "10.0")),
|
||||||
|
retry_backoff_multiplier=float(os.getenv("VLLM_RETRY_BACKOFF_MULTIPLIER", "2.0")),
|
||||||
|
max_tokens=int(os.getenv("VLLM_MAX_TOKENS", "32768")),
|
||||||
|
temperature=float(os.getenv("VLLM_TEMPERATURE", "0.7")),
|
||||||
|
api_key=os.getenv("VLLM_API_KEY", ""),
|
||||||
|
),
|
||||||
trino=TrinoConfig(
|
trino=TrinoConfig(
|
||||||
host=os.getenv("TRINO_HOST", "localhost"),
|
host=os.getenv("TRINO_HOST", "localhost"),
|
||||||
port=int(os.getenv("TRINO_PORT", "8080")),
|
port=int(os.getenv("TRINO_PORT", "8080")),
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
"""LLM client protocol for provider abstraction.
|
||||||
|
|
||||||
|
Defines the structural interface that both OllamaClient and VLLMClient
|
||||||
|
must satisfy, using typing.Protocol for duck-typing compatibility.
|
||||||
|
|
||||||
|
Requirements: 1.1, 1.2
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from services.extractor.client import ExtractionAttempt
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class LLMClient(Protocol):
|
||||||
|
"""Protocol defining the contract for LLM inference clients.
|
||||||
|
|
||||||
|
Both OllamaClient and VLLMClient satisfy this protocol via
|
||||||
|
structural subtyping — no inheritance required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def call_llm(
|
||||||
|
self,
|
||||||
|
prompts: dict[str, str],
|
||||||
|
json_schema: dict[str, object],
|
||||||
|
document_text: str = "",
|
||||||
|
) -> ExtractionAttempt:
|
||||||
|
"""Send a chat completion request and return an extraction attempt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: Dict with 'system' and 'user' prompt strings.
|
||||||
|
json_schema: JSON schema hint for structured output.
|
||||||
|
document_text: Optional raw document text for context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An ExtractionAttempt with raw output, validation, and error info.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Release underlying HTTP resources."""
|
||||||
|
...
|
||||||
@@ -274,19 +274,19 @@ class TestParseClassificationResponse:
|
|||||||
|
|
||||||
class TestClassifyGlobalEvent:
|
class TestClassifyGlobalEvent:
|
||||||
def _make_mock_client(self, raw_output: str, error: str | None = None):
|
def _make_mock_client(self, raw_output: str, error: str | None = None):
|
||||||
"""Create a mock OllamaClient with configurable response."""
|
"""Create a mock LLMClient with configurable response."""
|
||||||
client = MagicMock()
|
client = MagicMock()
|
||||||
client._config = MagicMock()
|
client._config = MagicMock()
|
||||||
client._config.model = "llama3.1:8b"
|
client._config.model = "llama3.1:8b"
|
||||||
client._max_retries = 2
|
client._config.max_retries = 2
|
||||||
client._base_delay = 0.01
|
client._config.retry_base_delay = 0.01
|
||||||
client._max_delay = 0.1
|
client._config.retry_max_delay = 0.1
|
||||||
client._backoff_multiplier = 2.0
|
client._config.retry_backoff_multiplier = 2.0
|
||||||
|
|
||||||
attempt = MagicMock()
|
attempt = MagicMock()
|
||||||
attempt.raw_output = raw_output
|
attempt.raw_output = raw_output
|
||||||
attempt.error = error
|
attempt.error = error
|
||||||
client._call_ollama = AsyncMock(return_value=attempt)
|
client.call_llm = AsyncMock(return_value=attempt)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -314,7 +314,7 @@ class TestClassifyGlobalEvent:
|
|||||||
assert event.severity == "critical"
|
assert event.severity == "critical"
|
||||||
assert event.confidence == 0.9
|
assert event.confidence == 0.9
|
||||||
assert event.source_document_id == "doc-123"
|
assert event.source_document_id == "doc-123"
|
||||||
client._call_ollama.assert_called_once()
|
client.call_llm.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_retries_on_error(self):
|
async def test_retries_on_error(self):
|
||||||
@@ -340,11 +340,11 @@ class TestClassifyGlobalEvent:
|
|||||||
success_attempt.error = None
|
success_attempt.error = None
|
||||||
|
|
||||||
client = self._make_mock_client("")
|
client = self._make_mock_client("")
|
||||||
client._call_ollama = AsyncMock(side_effect=[fail_attempt, success_attempt])
|
client.call_llm = AsyncMock(side_effect=[fail_attempt, success_attempt])
|
||||||
|
|
||||||
event = await classify_global_event("text", "doc-456", client)
|
event = await classify_global_event("text", "doc-456", client)
|
||||||
assert event.severity == "high"
|
assert event.severity == "high"
|
||||||
assert client._call_ollama.call_count == 2
|
assert client.call_llm.call_count == 2
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_raises_after_exhausted_retries(self):
|
async def test_raises_after_exhausted_retries(self):
|
||||||
@@ -353,12 +353,12 @@ class TestClassifyGlobalEvent:
|
|||||||
fail_attempt.error = "timeout"
|
fail_attempt.error = "timeout"
|
||||||
|
|
||||||
client = self._make_mock_client("")
|
client = self._make_mock_client("")
|
||||||
client._call_ollama = AsyncMock(return_value=fail_attempt)
|
client.call_llm = AsyncMock(return_value=fail_attempt)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Event classification failed"):
|
with pytest.raises(ValueError, match="Event classification failed"):
|
||||||
await classify_global_event("text", "doc-789", client)
|
await classify_global_event("text", "doc-789", client)
|
||||||
|
|
||||||
assert client._call_ollama.call_count == 3 # initial + 2 retries
|
assert client.call_llm.call_count == 3 # initial + 2 retries
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_minio_persistence_called(self):
|
async def test_minio_persistence_called(self):
|
||||||
|
|||||||
@@ -0,0 +1,296 @@
|
|||||||
|
"""Property-based tests for the LLM provider abstraction layer.
|
||||||
|
|
||||||
|
Feature: remote-vllm-support
|
||||||
|
|
||||||
|
Uses Hypothesis to validate correctness properties of the provider
|
||||||
|
abstraction: factory routing, error classification consistency,
|
||||||
|
VLLMClient payload structure, JSON repair idempotence, markdown
|
||||||
|
fence stripping round-trip, and VLLMConfig default invariants.
|
||||||
|
|
||||||
|
Requirements: 2.1, 2.3, 2.4, 3.1, 3.4, 3.5, 5.6, 8.1, 9.5
|
||||||
|
Design: Correctness Properties P1–P6
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from hypothesis import given, settings
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
|
from services.extractor.client import (
|
||||||
|
OllamaClient,
|
||||||
|
_is_retryable,
|
||||||
|
_repair_json,
|
||||||
|
_strip_markdown_fences,
|
||||||
|
)
|
||||||
|
from services.extractor.llm_factory import build_llm_client
|
||||||
|
from services.extractor.vllm_client import VLLMClient
|
||||||
|
from services.shared.agent_config import ResolvedAgentConfig
|
||||||
|
from services.shared.config import OllamaConfig, VLLMConfig
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ollama_config() -> OllamaConfig:
|
||||||
|
return OllamaConfig(
|
||||||
|
base_url="http://test-ollama:11434",
|
||||||
|
model="test-ollama-model",
|
||||||
|
timeout=10,
|
||||||
|
retry_base_delay=0.0,
|
||||||
|
retry_max_delay=0.0,
|
||||||
|
retry_backoff_multiplier=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vllm_config() -> VLLMConfig:
|
||||||
|
return VLLMConfig(
|
||||||
|
base_url="http://test-vllm:8000",
|
||||||
|
model="test-vllm-model",
|
||||||
|
timeout=10,
|
||||||
|
max_retries=2,
|
||||||
|
retry_base_delay=0.0,
|
||||||
|
retry_max_delay=0.0,
|
||||||
|
retry_backoff_multiplier=2.0,
|
||||||
|
max_tokens=4096,
|
||||||
|
temperature=0.7,
|
||||||
|
api_key="",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_resolved(provider: str | None) -> ResolvedAgentConfig:
|
||||||
|
return ResolvedAgentConfig(
|
||||||
|
agent_id="agent-1",
|
||||||
|
variant_id=None,
|
||||||
|
model_provider=provider or "",
|
||||||
|
model_name="resolved-model",
|
||||||
|
system_prompt="sys",
|
||||||
|
user_prompt_template="usr",
|
||||||
|
prompt_version="v1",
|
||||||
|
temperature=0.5,
|
||||||
|
max_tokens=8192,
|
||||||
|
context_window=0,
|
||||||
|
input_token_limit=0,
|
||||||
|
token_budget=0,
|
||||||
|
timeout_seconds=60,
|
||||||
|
max_retries=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# 9.1 — Factory routing property
|
||||||
|
# **Validates: Requirements 3.4, 3.5, 9.5**
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@given(st.sampled_from(["ollama", "vllm", "", None]))
|
||||||
|
@settings(max_examples=100)
|
||||||
|
def test_factory_routing_property(provider: str | None):
|
||||||
|
"""For all model_provider in {"ollama", "vllm", "", None}, factory returns correct client type.
|
||||||
|
|
||||||
|
**Validates: Requirements 3.4, 3.5, 9.5**
|
||||||
|
"""
|
||||||
|
resolved = _make_resolved(provider)
|
||||||
|
transport = httpx.MockTransport(lambda req: httpx.Response(200))
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
|
||||||
|
client = build_llm_client(
|
||||||
|
resolved, _make_ollama_config(), _make_vllm_config(), http_client=http
|
||||||
|
)
|
||||||
|
|
||||||
|
if provider == "vllm":
|
||||||
|
assert isinstance(client, VLLMClient), (
|
||||||
|
f"Expected VLLMClient for provider={provider!r}, got {type(client).__name__}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# "ollama", "", None all map to OllamaClient
|
||||||
|
assert isinstance(client, OllamaClient), (
|
||||||
|
f"Expected OllamaClient for provider={provider!r}, got {type(client).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# 9.2 — Error string format consistency property
|
||||||
|
# **Validates: Requirements 5.6**
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@given(st.integers(min_value=100, max_value=599))
|
||||||
|
@settings(max_examples=100)
|
||||||
|
def test_is_retryable_consistency_property(status_code: int):
|
||||||
|
"""For all HTTP status codes (100-599), _is_retryable() classifies them consistently.
|
||||||
|
|
||||||
|
Non-retryable: 400, 401, 403, 404, 422.
|
||||||
|
All other http_{code} errors are retryable.
|
||||||
|
|
||||||
|
**Validates: Requirements 5.6**
|
||||||
|
"""
|
||||||
|
error_str = f"http_{status_code}"
|
||||||
|
result = _is_retryable(error_str)
|
||||||
|
|
||||||
|
non_retryable_codes = {400, 401, 403, 404, 422}
|
||||||
|
|
||||||
|
if status_code in non_retryable_codes:
|
||||||
|
assert result is False, (
|
||||||
|
f"http_{status_code} should be non-retryable but _is_retryable returned True"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert result is True, (
|
||||||
|
f"http_{status_code} should be retryable but _is_retryable returned False"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# 9.3 — VLLMClient request payload structure property
|
||||||
|
# **Validates: Requirements 2.1, 8.1**
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@given(
|
||||||
|
system=st.text(min_size=1),
|
||||||
|
user=st.text(min_size=1),
|
||||||
|
)
|
||||||
|
@settings(max_examples=100)
|
||||||
|
def test_vllm_payload_structure_property(system: str, user: str):
|
||||||
|
"""For all generated prompt dicts, payload contains required OpenAI fields and excludes Ollama-specific fields.
|
||||||
|
|
||||||
|
**Validates: Requirements 2.1, 8.1**
|
||||||
|
"""
|
||||||
|
prompts = {"system": system, "user": user}
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
captured["payload"] = json.loads(request.content)
|
||||||
|
body = {
|
||||||
|
"choices": [
|
||||||
|
{"message": {"role": "assistant", "content": "{}"}}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return httpx.Response(200, json=body)
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
config = _make_vllm_config()
|
||||||
|
client = VLLMClient(config, http_client=http)
|
||||||
|
|
||||||
|
asyncio.run(client.call_llm(prompts, {}))
|
||||||
|
|
||||||
|
payload = captured["payload"]
|
||||||
|
|
||||||
|
# Required OpenAI fields must be present
|
||||||
|
assert "model" in payload, "Payload missing 'model' field"
|
||||||
|
assert "messages" in payload, "Payload missing 'messages' field"
|
||||||
|
assert "max_tokens" in payload, "Payload missing 'max_tokens' field"
|
||||||
|
assert "temperature" in payload, "Payload missing 'temperature' field"
|
||||||
|
|
||||||
|
# Messages must have system and user roles
|
||||||
|
roles = [m["role"] for m in payload["messages"]]
|
||||||
|
assert "system" in roles, "Messages missing 'system' role"
|
||||||
|
assert "user" in roles, "Messages missing 'user' role"
|
||||||
|
|
||||||
|
# Ollama-specific fields must NOT be present
|
||||||
|
assert "think" not in payload, "Payload contains Ollama-specific 'think' field"
|
||||||
|
assert "stream" not in payload, "Payload contains Ollama-specific 'stream' field"
|
||||||
|
assert "options" not in payload, "Payload contains Ollama-specific 'options' field"
|
||||||
|
|
||||||
|
# No nested Ollama options
|
||||||
|
for key in ("num_ctx", "num_predict"):
|
||||||
|
assert key not in payload, f"Payload contains Ollama-specific '{key}' field"
|
||||||
|
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# 9.4 — JSON repair idempotence property
|
||||||
|
# **Validates: Requirements 2.4**
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@given(
|
||||||
|
st.one_of(
|
||||||
|
st.dictionaries(st.text(max_size=20), st.text(max_size=50), max_size=5),
|
||||||
|
st.lists(st.integers(), max_size=10),
|
||||||
|
st.text(max_size=50),
|
||||||
|
st.integers(),
|
||||||
|
st.floats(allow_nan=False, allow_infinity=False),
|
||||||
|
st.booleans(),
|
||||||
|
st.none(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@settings(max_examples=100)
|
||||||
|
def test_json_repair_idempotence_property(value):
|
||||||
|
"""For all valid JSON strings, _repair_json() is idempotent.
|
||||||
|
|
||||||
|
_repair_json(_repair_json(json_str)) == _repair_json(json_str)
|
||||||
|
|
||||||
|
**Validates: Requirements 2.4**
|
||||||
|
"""
|
||||||
|
json_str = json.dumps(value)
|
||||||
|
|
||||||
|
repaired_once = _repair_json(json_str)
|
||||||
|
repaired_twice = _repair_json(repaired_once)
|
||||||
|
|
||||||
|
assert repaired_once == repaired_twice, (
|
||||||
|
f"_repair_json is not idempotent: "
|
||||||
|
f"first={repaired_once!r}, second={repaired_twice!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# The repaired output should be valid JSON
|
||||||
|
json.loads(repaired_once)
|
||||||
|
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# 9.5 — Markdown fence stripping round-trip property
|
||||||
|
# **Validates: Requirements 2.3**
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@given(st.text())
|
||||||
|
@settings(max_examples=100)
|
||||||
|
def test_markdown_fence_stripping_roundtrip_property(s: str):
|
||||||
|
"""For all strings, wrapping in fences then stripping recovers the original.
|
||||||
|
|
||||||
|
The regex trims leading/trailing whitespace around the content inside
|
||||||
|
fences, so the round-trip recovers ``s.strip()``.
|
||||||
|
|
||||||
|
**Validates: Requirements 2.3**
|
||||||
|
"""
|
||||||
|
fenced = f"```json\n{s}\n```"
|
||||||
|
stripped = _strip_markdown_fences(fenced)
|
||||||
|
|
||||||
|
assert stripped == s.strip(), (
|
||||||
|
f"Round-trip failed: original={s!r}, stripped={stripped!r}, expected={s.strip()!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Identity: when no fences are present, the string is returned as-is
|
||||||
|
# (only test strings that don't look like fenced blocks themselves)
|
||||||
|
if not s.strip().startswith("```"):
|
||||||
|
assert _strip_markdown_fences(s) == s
|
||||||
|
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# 9.6 — VLLMConfig defaults property
|
||||||
|
# **Validates: Requirements 3.1**
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@settings(max_examples=100)
|
||||||
|
@given(st.just(None))
|
||||||
|
def test_vllm_config_defaults_property(_):
|
||||||
|
"""For all default-constructed instances, invariants hold.
|
||||||
|
|
||||||
|
timeout > 0, max_retries >= 0, 0 <= temperature <= 2, max_tokens > 0.
|
||||||
|
|
||||||
|
**Validates: Requirements 3.1**
|
||||||
|
"""
|
||||||
|
config = VLLMConfig()
|
||||||
|
|
||||||
|
assert config.timeout > 0, f"timeout must be > 0, got {config.timeout}"
|
||||||
|
assert config.max_retries >= 0, f"max_retries must be >= 0, got {config.max_retries}"
|
||||||
|
assert 0 <= config.temperature <= 2, (
|
||||||
|
f"temperature must be in [0, 2], got {config.temperature}"
|
||||||
|
)
|
||||||
|
assert config.max_tokens > 0, f"max_tokens must be > 0, got {config.max_tokens}"
|
||||||
|
assert config.base_url, "base_url must be non-empty"
|
||||||
|
assert config.model, "model must be non-empty"
|
||||||
@@ -0,0 +1,461 @@
|
|||||||
|
"""Tests for the vLLM client, health check, config, and LLM factory."""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from services.extractor.client import OllamaClient
|
||||||
|
from services.extractor.llm_factory import build_llm_client
|
||||||
|
from services.extractor.vllm_client import VLLMClient, check_vllm_health
|
||||||
|
from services.shared.agent_config import ResolvedAgentConfig
|
||||||
|
from services.shared.config import AppConfig, OllamaConfig, VLLMConfig, load_config
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _valid_extraction_json() -> str:
|
||||||
|
"""Minimal valid extraction result as JSON string."""
|
||||||
|
return json.dumps({
|
||||||
|
"summary": "Apple beat earnings expectations.",
|
||||||
|
"companies": [
|
||||||
|
{
|
||||||
|
"ticker": "AAPL",
|
||||||
|
"company_name": "Apple Inc.",
|
||||||
|
"relevance": 0.95,
|
||||||
|
"sentiment": "positive",
|
||||||
|
"impact_score": 0.7,
|
||||||
|
"impact_horizon": "1d_30d",
|
||||||
|
"catalyst_type": "earnings",
|
||||||
|
"key_facts": ["Revenue up 12%"],
|
||||||
|
"risks": [],
|
||||||
|
"evidence_spans": ["Apple beat expectations"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"macro_themes": ["ai_capex"],
|
||||||
|
"novelty_score": 0.6,
|
||||||
|
"confidence": 0.85,
|
||||||
|
"extraction_warnings": [],
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_response(content: str, status: int = 200) -> httpx.Response:
|
||||||
|
"""Build a fake OpenAI-compatible /v1/chat/completions response."""
|
||||||
|
body = {
|
||||||
|
"choices": [
|
||||||
|
{"message": {"role": "assistant", "content": content}}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return httpx.Response(status, json=body)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vllm_config() -> VLLMConfig:
|
||||||
|
return VLLMConfig(
|
||||||
|
base_url="http://test-vllm:8000",
|
||||||
|
model="test-vllm-model",
|
||||||
|
timeout=10,
|
||||||
|
max_retries=2,
|
||||||
|
retry_base_delay=0.0,
|
||||||
|
retry_max_delay=0.0,
|
||||||
|
retry_backoff_multiplier=2.0,
|
||||||
|
max_tokens=4096,
|
||||||
|
temperature=0.7,
|
||||||
|
api_key="",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ollama_config() -> OllamaConfig:
|
||||||
|
return OllamaConfig(
|
||||||
|
base_url="http://test-ollama:11434",
|
||||||
|
model="test-ollama-model",
|
||||||
|
timeout=10,
|
||||||
|
retry_base_delay=0.0,
|
||||||
|
retry_max_delay=0.0,
|
||||||
|
retry_backoff_multiplier=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_prompts() -> dict[str, str]:
|
||||||
|
return {"system": "You are a helpful assistant.", "user": "Extract info."}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_resolved(provider: str = "vllm") -> ResolvedAgentConfig:
|
||||||
|
return ResolvedAgentConfig(
|
||||||
|
agent_id="agent-1",
|
||||||
|
variant_id=None,
|
||||||
|
model_provider=provider,
|
||||||
|
model_name="resolved-model",
|
||||||
|
system_prompt="sys",
|
||||||
|
user_prompt_template="usr",
|
||||||
|
prompt_version="v1",
|
||||||
|
temperature=0.5,
|
||||||
|
max_tokens=8192,
|
||||||
|
context_window=0,
|
||||||
|
input_token_limit=0,
|
||||||
|
token_budget=0,
|
||||||
|
timeout_seconds=60,
|
||||||
|
max_retries=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# Task 7: Unit Tests for VLLMClient
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
|
||||||
|
# 7.1 — VLLMClient sends correct payload to /v1/chat/completions
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vllm_sends_correct_payload():
|
||||||
|
"""VLLMClient sends POST to /v1/chat/completions with correct OpenAI payload."""
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
captured["url"] = str(request.url)
|
||||||
|
captured["payload"] = json.loads(request.content)
|
||||||
|
return _openai_response(_valid_extraction_json())
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
config = _make_vllm_config()
|
||||||
|
client = VLLMClient(config, http_client=http)
|
||||||
|
|
||||||
|
await client.call_llm(_make_prompts(), {})
|
||||||
|
|
||||||
|
assert captured["url"] == "http://test-vllm:8000/v1/chat/completions"
|
||||||
|
payload = captured["payload"]
|
||||||
|
assert payload["model"] == "test-vllm-model"
|
||||||
|
assert len(payload["messages"]) == 2
|
||||||
|
assert payload["messages"][0]["role"] == "system"
|
||||||
|
assert payload["messages"][1]["role"] == "user"
|
||||||
|
assert payload["max_tokens"] == 4096
|
||||||
|
assert payload["temperature"] == 0.7
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 7.2 — VLLMClient extracts content from choices[0].message.content
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vllm_extracts_content_from_choices():
|
||||||
|
"""VLLMClient extracts content from choices[0].message.content."""
|
||||||
|
transport = httpx.MockTransport(
|
||||||
|
lambda req: _openai_response(_valid_extraction_json())
|
||||||
|
)
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||||
|
|
||||||
|
attempt = await client.call_llm(_make_prompts(), {})
|
||||||
|
|
||||||
|
assert attempt.raw_output == _valid_extraction_json()
|
||||||
|
assert attempt.error is None
|
||||||
|
assert attempt.validation is not None
|
||||||
|
assert attempt.validation.valid
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 7.3 — VLLMClient handles empty choices array → empty_model_response
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vllm_empty_choices():
|
||||||
|
"""Empty choices array returns empty_model_response error."""
|
||||||
|
body = {"choices": []}
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(
|
||||||
|
lambda req: httpx.Response(200, json=body)
|
||||||
|
)
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||||
|
|
||||||
|
attempt = await client.call_llm(_make_prompts(), {})
|
||||||
|
|
||||||
|
assert attempt.error == "empty_model_response"
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 7.4 — VLLMClient handles HTTP timeout → timeout error
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vllm_timeout():
|
||||||
|
"""HTTP timeout returns 'timeout' error."""
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
raise httpx.ReadTimeout("timed out")
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||||
|
|
||||||
|
attempt = await client.call_llm(_make_prompts(), {})
|
||||||
|
|
||||||
|
assert attempt.error == "timeout"
|
||||||
|
assert attempt.duration_ms >= 0
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 7.5 — VLLMClient handles HTTP 500 → http_500 retryable error
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vllm_http_500():
|
||||||
|
"""HTTP 500 returns 'http_500' error marked as retryable."""
|
||||||
|
transport = httpx.MockTransport(
|
||||||
|
lambda req: httpx.Response(500, text="Internal Server Error")
|
||||||
|
)
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||||
|
|
||||||
|
attempt = await client.call_llm(_make_prompts(), {})
|
||||||
|
|
||||||
|
assert attempt.error == "http_500"
|
||||||
|
assert attempt.retryable is True
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 7.6 — VLLMClient handles HTTP 400 → http_400 non-retryable error
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vllm_http_400():
|
||||||
|
"""HTTP 400 returns 'http_400' error marked as non-retryable."""
|
||||||
|
transport = httpx.MockTransport(
|
||||||
|
lambda req: httpx.Response(400, text="Bad Request")
|
||||||
|
)
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||||
|
|
||||||
|
attempt = await client.call_llm(_make_prompts(), {})
|
||||||
|
|
||||||
|
assert attempt.error == "http_400"
|
||||||
|
assert attempt.retryable is False
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 7.7 — VLLMClient handles connection error → connection_error: ...
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vllm_connection_error():
|
||||||
|
"""Connection error returns 'connection_error: ...' error string."""
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
raise httpx.ConnectError("Connection refused")
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||||
|
|
||||||
|
attempt = await client.call_llm(_make_prompts(), {})
|
||||||
|
|
||||||
|
assert attempt.error is not None
|
||||||
|
assert attempt.error.startswith("connection_error:")
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 7.8 — VLLMClient applies markdown fence stripping and JSON repair
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vllm_markdown_fence_stripping_and_json_repair():
|
||||||
|
"""VLLMClient strips markdown fences and repairs JSON."""
|
||||||
|
# Wrap valid JSON in markdown fences
|
||||||
|
fenced = f"```json\n{_valid_extraction_json()}\n```"
|
||||||
|
transport = httpx.MockTransport(
|
||||||
|
lambda req: _openai_response(fenced)
|
||||||
|
)
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
client = VLLMClient(_make_vllm_config(), http_client=http)
|
||||||
|
|
||||||
|
attempt = await client.call_llm(_make_prompts(), {})
|
||||||
|
|
||||||
|
# Should succeed after stripping fences
|
||||||
|
assert attempt.error is None
|
||||||
|
assert attempt.validation is not None
|
||||||
|
assert attempt.validation.valid
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 7.9 — VLLMClient includes temperature and response_format in payload
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_vllm_payload_includes_temperature_and_response_format():
|
||||||
|
"""Payload includes temperature and response_format fields."""
|
||||||
|
captured: dict = {}
|
||||||
|
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
captured["payload"] = json.loads(request.content)
|
||||||
|
return _openai_response(_valid_extraction_json())
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
config = _make_vllm_config()
|
||||||
|
config.temperature = 0.3
|
||||||
|
client = VLLMClient(config, http_client=http)
|
||||||
|
|
||||||
|
await client.call_llm(_make_prompts(), {})
|
||||||
|
|
||||||
|
assert captured["payload"]["temperature"] == 0.3
|
||||||
|
assert captured["payload"]["response_format"] == {"type": "json_object"}
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 7.10 — Health check success returns True and logs INFO
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_success(caplog):
|
||||||
|
"""check_vllm_health returns True and logs INFO on success."""
|
||||||
|
transport = httpx.MockTransport(
|
||||||
|
lambda req: httpx.Response(200, json={"data": [{"id": "model-1"}]})
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("services.extractor.vllm_client.httpx.AsyncClient", return_value=httpx.AsyncClient(transport=transport)):
|
||||||
|
with caplog.at_level(logging.INFO, logger="vllm_client"):
|
||||||
|
result = await check_vllm_health("http://test-vllm:8000")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert any("health check passed" in r.message for r in caplog.records)
|
||||||
|
|
||||||
|
|
||||||
|
# 7.11 — Health check failure returns False and logs WARNING
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_failure(caplog):
|
||||||
|
"""check_vllm_health returns False and logs WARNING on failure."""
|
||||||
|
def handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
raise httpx.ConnectError("Connection refused")
|
||||||
|
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
|
||||||
|
with patch("services.extractor.vllm_client.httpx.AsyncClient", return_value=httpx.AsyncClient(transport=transport)):
|
||||||
|
with caplog.at_level(logging.WARNING, logger="vllm_client"):
|
||||||
|
result = await check_vllm_health("http://unreachable:8000")
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert any("health check failed" in r.message for r in caplog.records)
|
||||||
|
|
||||||
|
|
||||||
|
# 7.12 — OllamaClient.call_llm() delegates to _call_ollama()
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ollama_call_llm_delegates():
|
||||||
|
"""OllamaClient.call_llm() delegates to _call_ollama()."""
|
||||||
|
transport = httpx.MockTransport(
|
||||||
|
lambda req: httpx.Response(
|
||||||
|
200,
|
||||||
|
json={"message": {"role": "assistant", "content": _valid_extraction_json()}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
config = _make_ollama_config()
|
||||||
|
client = OllamaClient(config, http_client=http)
|
||||||
|
|
||||||
|
prompts = _make_prompts()
|
||||||
|
schema = {}
|
||||||
|
|
||||||
|
# call_llm should produce the same result as _call_ollama
|
||||||
|
result_llm = await client.call_llm(prompts, schema)
|
||||||
|
# Both should succeed with valid extraction JSON
|
||||||
|
assert result_llm.error is None
|
||||||
|
assert result_llm.validation is not None
|
||||||
|
assert result_llm.validation.valid
|
||||||
|
assert result_llm.model == config.model
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 7.13 — VLLMConfig loading from environment variables
|
||||||
|
def test_vllm_config_from_env(monkeypatch):
|
||||||
|
"""VLLMConfig fields are loaded from VLLM_* environment variables."""
|
||||||
|
monkeypatch.setenv("VLLM_BASE_URL", "http://custom:9000")
|
||||||
|
monkeypatch.setenv("VLLM_MODEL", "custom-model")
|
||||||
|
monkeypatch.setenv("VLLM_TIMEOUT", "300")
|
||||||
|
monkeypatch.setenv("VLLM_MAX_RETRIES", "5")
|
||||||
|
monkeypatch.setenv("VLLM_TEMPERATURE", "0.9")
|
||||||
|
monkeypatch.setenv("VLLM_API_KEY", "secret-key")
|
||||||
|
monkeypatch.setenv("VLLM_MAX_TOKENS", "16384")
|
||||||
|
|
||||||
|
cfg = load_config()
|
||||||
|
|
||||||
|
assert cfg.vllm.base_url == "http://custom:9000"
|
||||||
|
assert cfg.vllm.model == "custom-model"
|
||||||
|
assert cfg.vllm.timeout == 300
|
||||||
|
assert cfg.vllm.max_retries == 5
|
||||||
|
assert cfg.vllm.temperature == 0.9
|
||||||
|
assert cfg.vllm.api_key == "secret-key"
|
||||||
|
assert cfg.vllm.max_tokens == 16384
|
||||||
|
|
||||||
|
|
||||||
|
# 7.14 — AppConfig includes vllm field with correct defaults
|
||||||
|
def test_appconfig_vllm_defaults():
|
||||||
|
"""AppConfig includes a vllm field with VLLMConfig defaults."""
|
||||||
|
cfg = AppConfig()
|
||||||
|
|
||||||
|
assert hasattr(cfg, "vllm")
|
||||||
|
assert isinstance(cfg.vllm, VLLMConfig)
|
||||||
|
assert cfg.vllm.base_url == "http://192.168.42.254:8000"
|
||||||
|
assert cfg.vllm.model == "RedHatAI/Qwen3.6-35B-A3B-NVFP4"
|
||||||
|
assert cfg.vllm.timeout == 120
|
||||||
|
assert cfg.vllm.max_retries == 2
|
||||||
|
assert cfg.vllm.temperature == 0.7
|
||||||
|
assert cfg.vllm.max_tokens == 32768
|
||||||
|
assert cfg.vllm.api_key == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ===================================================================
|
||||||
|
# Task 8: Unit Tests for LLM Factory
|
||||||
|
# ===================================================================
|
||||||
|
|
||||||
|
|
||||||
|
# 8.1 — Factory returns OllamaClient when provider is "ollama"
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_factory_ollama_provider():
|
||||||
|
"""build_llm_client returns OllamaClient when provider is 'ollama'."""
|
||||||
|
resolved = _make_resolved(provider="ollama")
|
||||||
|
transport = httpx.MockTransport(lambda req: httpx.Response(200))
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
|
||||||
|
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
|
||||||
|
|
||||||
|
assert isinstance(client, OllamaClient)
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 8.2 — Factory returns VLLMClient when provider is "vllm"
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_factory_vllm_provider():
|
||||||
|
"""build_llm_client returns VLLMClient when provider is 'vllm'."""
|
||||||
|
resolved = _make_resolved(provider="vllm")
|
||||||
|
transport = httpx.MockTransport(lambda req: httpx.Response(200))
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
|
||||||
|
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
|
||||||
|
|
||||||
|
assert isinstance(client, VLLMClient)
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 8.3 — Factory returns OllamaClient when provider is empty string (default)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_factory_empty_provider_defaults_to_ollama():
|
||||||
|
"""build_llm_client returns OllamaClient when provider is empty string."""
|
||||||
|
resolved = _make_resolved(provider="")
|
||||||
|
transport = httpx.MockTransport(lambda req: httpx.Response(200))
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
|
||||||
|
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
|
||||||
|
|
||||||
|
assert isinstance(client, OllamaClient)
|
||||||
|
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# 8.4 — Factory returns OllamaClient with warning when provider is unknown
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_factory_unknown_provider_warns_and_falls_back(caplog):
|
||||||
|
"""build_llm_client logs warning and returns OllamaClient for unknown provider."""
|
||||||
|
resolved = _make_resolved(provider="unknown-provider")
|
||||||
|
transport = httpx.MockTransport(lambda req: httpx.Response(200))
|
||||||
|
http = httpx.AsyncClient(transport=transport)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
client = build_llm_client(resolved, _make_ollama_config(), _make_vllm_config(), http_client=http)
|
||||||
|
|
||||||
|
assert isinstance(client, OllamaClient)
|
||||||
|
assert any("unknown" in r.message.lower() for r in caplog.records)
|
||||||
|
|
||||||
|
await client.close()
|
||||||
Reference in New Issue
Block a user