Files
stonks-oracle/tests/test_pbt_llm_provider.py
T
Celes Renata 117b693b19 feat: add remote vLLM support with provider abstraction layer
- LLMClient Protocol for provider-agnostic inference
- VLLMClient for OpenAI-compatible /v1/chat/completions API
- LLM client factory with provider routing (ollama/vllm)
- VLLMConfig with VLLM_* environment variable loading
- Updated extractor worker with health check and provider switching
- Updated event classifier to use LLMClient protocol
- Helm values for vLLM configuration
- 18 unit tests + 6 property-based tests
- Full backward compatibility preserved
2026-04-23 08:17:23 +00:00

297 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Property-based tests for the LLM provider abstraction layer.

Feature: remote-vllm-support

Uses Hypothesis to validate correctness properties of the provider
abstraction: factory routing, error classification consistency,
VLLMClient payload structure, JSON repair idempotence, markdown
fence stripping round-trip, and VLLMConfig default invariants.

Requirements: 2.1, 2.3, 2.4, 3.1, 3.4, 3.5, 5.6, 8.1, 9.5
Design: Correctness Properties P1-P6
"""
from __future__ import annotations
import asyncio
import json
import httpx
from hypothesis import given, settings
from hypothesis import strategies as st
from services.extractor.client import (
OllamaClient,
_is_retryable,
_repair_json,
_strip_markdown_fences,
)
from services.extractor.llm_factory import build_llm_client
from services.extractor.vllm_client import VLLMClient
from services.shared.agent_config import ResolvedAgentConfig
from services.shared.config import OllamaConfig, VLLMConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ollama_config() -> OllamaConfig:
    """Build an OllamaConfig aimed at a fake host for use in tests.

    Both retry delays are set to 0.0 (presumably so any retry back-off
    exercised by a test does not slow the suite down).
    """
    config = OllamaConfig(
        base_url="http://test-ollama:11434",
        model="test-ollama-model",
        timeout=10,
        retry_backoff_multiplier=2.0,
        retry_base_delay=0.0,
        retry_max_delay=0.0,
    )
    return config
def _make_vllm_config() -> VLLMConfig:
    """Build a VLLMConfig aimed at a fake vLLM host for use in tests.

    Uses two retries, zero retry delays, a 4096-token completion cap,
    temperature 0.7 and an empty API key.
    """
    config = VLLMConfig(
        base_url="http://test-vllm:8000",
        model="test-vllm-model",
        api_key="",
        timeout=10,
        max_tokens=4096,
        temperature=0.7,
        max_retries=2,
        retry_backoff_multiplier=2.0,
        retry_base_delay=0.0,
        retry_max_delay=0.0,
    )
    return config
def _make_resolved(provider: str | None) -> ResolvedAgentConfig:
    """Return a ResolvedAgentConfig whose only varying field is the provider.

    ``None`` is normalised to the empty string before it is stored in
    ``model_provider``. Every other field is a fixed test value.
    """
    model_provider = provider or ""
    return ResolvedAgentConfig(
        agent_id="agent-1",
        variant_id=None,
        model_provider=model_provider,
        model_name="resolved-model",
        prompt_version="v1",
        system_prompt="sys",
        user_prompt_template="usr",
        temperature=0.5,
        max_tokens=8192,
        context_window=0,
        input_token_limit=0,
        token_budget=0,
        timeout_seconds=60,
        max_retries=2,
    )
# ===================================================================
# 9.1 — Factory routing property
# **Validates: Requirements 3.4, 3.5, 9.5**
# ===================================================================
@given(st.sampled_from(["ollama", "vllm", "", None]))
@settings(max_examples=100)
def test_factory_routing_property(provider: str | None):
    """For all model_provider in {"ollama", "vllm", "", None}, factory returns correct client type.

    ``"vllm"`` must produce a VLLMClient. ``"ollama"``, ``""`` and ``None``
    must all produce an OllamaClient.

    **Validates: Requirements 3.4, 3.5, 9.5**
    """
    resolved = _make_resolved(provider)
    transport = httpx.MockTransport(lambda req: httpx.Response(200))
    http = httpx.AsyncClient(transport=transport)
    try:
        client = build_llm_client(
            resolved, _make_ollama_config(), _make_vllm_config(), http_client=http
        )
        # "ollama", "", None all map to OllamaClient
        expected = VLLMClient if provider == "vllm" else OllamaClient
        assert isinstance(client, expected), (
            f"Expected {expected.__name__} for provider={provider!r}, "
            f"got {type(client).__name__}"
        )
    finally:
        # Close the AsyncClient explicitly so Hypothesis examples do not each
        # leave an unclosed client behind. aclose() is a coroutine, so it
        # needs a short-lived event loop.
        asyncio.run(http.aclose())
# ===================================================================
# 9.2 — Error string format consistency property
# **Validates: Requirements 5.6**
# ===================================================================
@given(st.integers(min_value=100, max_value=599))
@settings(max_examples=100)
def test_is_retryable_consistency_property(status_code: int):
    """For all HTTP status codes (100-599), _is_retryable() classifies them consistently.

    Non-retryable: 400, 401, 403, 404, 422.
    All other http_{code} errors are retryable.

    **Validates: Requirements 5.6**
    """
    non_retryable_codes = frozenset({400, 401, 403, 404, 422})
    error_str = f"http_{status_code}"
    result = _is_retryable(error_str)
    expected = status_code not in non_retryable_codes
    kind = "retryable" if expected else "non-retryable"
    # Identity comparison rejects any non-bool return (e.g. None or 0), and
    # the message reports the actual value instead of assuming the opposite bool.
    assert result is expected, (
        f"{error_str} should be {kind} but _is_retryable returned {result!r}"
    )
# ===================================================================
# 9.3 — VLLMClient request payload structure property
# **Validates: Requirements 2.1, 8.1**
# ===================================================================
@given(
    system=st.text(min_size=1),
    user=st.text(min_size=1),
)
@settings(max_examples=100)
def test_vllm_payload_structure_property(system: str, user: str):
    """For all generated prompt dicts, payload contains required OpenAI fields and excludes Ollama-specific fields.

    **Validates: Requirements 2.1, 8.1**
    """
    prompts = {"system": system, "user": user}
    captured: dict = {}

    def handler(request: httpx.Request) -> httpx.Response:
        # Record the outgoing JSON body, then reply with a minimal
        # chat-completions response whose message content is "{}".
        captured["payload"] = json.loads(request.content)
        body = {
            "choices": [
                {"message": {"role": "assistant", "content": "{}"}}
            ],
        }
        return httpx.Response(200, json=body)

    async def _call() -> None:
        # ``async with`` closes the AsyncClient even if call_llm raises, so
        # Hypothesis examples do not accumulate unclosed clients.
        transport = httpx.MockTransport(handler)
        async with httpx.AsyncClient(transport=transport) as http:
            client = VLLMClient(_make_vllm_config(), http_client=http)
            await client.call_llm(prompts, {})

    asyncio.run(_call())

    assert "payload" in captured, "VLLMClient.call_llm did not send an HTTP request"
    payload = captured["payload"]

    # Required OpenAI fields must be present
    assert "model" in payload, "Payload missing 'model' field"
    assert "messages" in payload, "Payload missing 'messages' field"
    assert "max_tokens" in payload, "Payload missing 'max_tokens' field"
    assert "temperature" in payload, "Payload missing 'temperature' field"

    # Messages must have system and user roles
    roles = [m["role"] for m in payload["messages"]]
    assert "system" in roles, "Messages missing 'system' role"
    assert "user" in roles, "Messages missing 'user' role"

    # Ollama-specific fields must NOT be present
    assert "think" not in payload, "Payload contains Ollama-specific 'think' field"
    assert "stream" not in payload, "Payload contains Ollama-specific 'stream' field"
    assert "options" not in payload, "Payload contains Ollama-specific 'options' field"

    # Ollama option keys must not leak into the top level of the payload either
    for key in ("num_ctx", "num_predict"):
        assert key not in payload, f"Payload contains Ollama-specific '{key}' field"
# ===================================================================
# 9.4 — JSON repair idempotence property
# **Validates: Requirements 2.4**
# ===================================================================
@given(
    st.one_of(
        st.dictionaries(st.text(max_size=20), st.text(max_size=50), max_size=5),
        st.lists(st.integers(), max_size=10),
        st.text(max_size=50),
        st.integers(),
        st.floats(allow_nan=False, allow_infinity=False),
        st.booleans(),
        st.none(),
    )
)
@settings(max_examples=100)
def test_json_repair_idempotence_property(value):
    """Check that ``_repair_json`` is a fixed point on well-formed JSON.

    Each generated value is serialised with ``json.dumps`` and repaired
    twice. The second repair must leave the first result unchanged:
    ``_repair_json(_repair_json(s)) == _repair_json(s)``. The repaired text
    must also still parse as JSON.

    **Validates: Requirements 2.4**
    """
    source = json.dumps(value)
    first = _repair_json(source)
    second = _repair_json(first)

    failure = (
        f"_repair_json is not idempotent: "
        f"first={first!r}, second={second!r}"
    )
    assert first == second, failure

    # Raises json.JSONDecodeError (failing the test) if the repair broke the JSON.
    json.loads(first)
# ===================================================================
# 9.5 — Markdown fence stripping round-trip property
# **Validates: Requirements 2.3**
# ===================================================================
@given(st.text())
@settings(max_examples=100)
def test_markdown_fence_stripping_roundtrip_property(s: str):
    """For all strings, wrapping in fences then stripping recovers the original.

    The regex trims leading/trailing whitespace around the content inside
    fences, so the round-trip recovers ``s.strip()``.

    **Validates: Requirements 2.3**
    """
    fenced = f"```json\n{s}\n```"
    stripped = _strip_markdown_fences(fenced)
    assert stripped == s.strip(), (
        f"Round-trip failed: original={s!r}, stripped={stripped!r}, expected={s.strip()!r}"
    )

    # Identity: when no fence markers appear anywhere in the string, it is
    # returned as-is. Checking only the stripped prefix would still admit
    # strings with a ``` run later on, which a fence regex may legitimately
    # match, making this assertion flaky.
    if "```" not in s:
        assert _strip_markdown_fences(s) == s
# ===================================================================
# 9.6 — VLLMConfig defaults property
# **Validates: Requirements 3.1**
# ===================================================================
@given(st.just(None))
@settings(max_examples=100)
def test_vllm_config_defaults_property(_):
    """Check invariants on a VLLMConfig built with no arguments.

    Expected: timeout > 0, max_retries >= 0, 0 <= temperature <= 2,
    max_tokens > 0, and non-empty base_url and model.

    **Validates: Requirements 3.1**
    """
    cfg = VLLMConfig()

    timeout = cfg.timeout
    assert timeout > 0, f"timeout must be > 0, got {timeout}"

    retries = cfg.max_retries
    assert retries >= 0, f"max_retries must be >= 0, got {retries}"

    temperature = cfg.temperature
    assert 0 <= temperature <= 2, (
        f"temperature must be in [0, 2], got {temperature}"
    )

    max_tokens = cfg.max_tokens
    assert max_tokens > 0, f"max_tokens must be > 0, got {max_tokens}"

    assert cfg.base_url, "base_url must be non-empty"
    assert cfg.model, "model must be non-empty"