"""Property-based tests for the LLM provider abstraction layer.

Feature: remote-vllm-support

Hypothesis is used to check the correctness properties that the provider
abstraction must satisfy:

* factory routing (provider string -> concrete client class),
* consistent retryable / non-retryable error classification,
* VLLMClient request payload structure,
* idempotence of JSON repair,
* markdown-fence stripping round-trip,
* VLLMConfig default invariants.

Requirements: 2.1, 2.3, 2.4, 3.1, 3.4, 3.5, 5.6, 8.1, 9.5
Design: Correctness Properties P1–P6
"""

from __future__ import annotations

import asyncio
import json

import httpx
from hypothesis import given, settings
from hypothesis import strategies as st

from services.extractor.client import (
    OllamaClient,
    _is_retryable,
    _repair_json,
    _strip_markdown_fences,
)
from services.extractor.llm_factory import build_llm_client
from services.extractor.vllm_client import VLLMClient
from services.shared.agent_config import ResolvedAgentConfig
from services.shared.config import OllamaConfig, VLLMConfig

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_ollama_config() -> OllamaConfig:
    """Return an OllamaConfig aimed at a fake host with all retry delays set to 0."""
    return OllamaConfig(
        model="test-ollama-model",
        base_url="http://test-ollama:11434",
        timeout=10,
        retry_backoff_multiplier=2.0,
        retry_base_delay=0.0,
        retry_max_delay=0.0,
    )


def _make_vllm_config() -> VLLMConfig:
    """Return a VLLMConfig aimed at a fake host with all retry delays set to 0.

    The config uses 2 retries, 4096 max tokens, temperature 0.7 and an empty API key.
    """
    fields = dict(
        model="test-vllm-model",
        base_url="http://test-vllm:8000",
        api_key="",
        timeout=10,
        max_retries=2,
        retry_backoff_multiplier=2.0,
        retry_base_delay=0.0,
        retry_max_delay=0.0,
        temperature=0.7,
        max_tokens=4096,
    )
    return VLLMConfig(**fields)


def _make_resolved(provider: str | None) -> ResolvedAgentConfig:
    """Return a minimal ResolvedAgentConfig whose ``model_provider`` is *provider*.

    A falsy *provider* (``None`` or ``""``) is stored as the empty string. As a
    result, ``None`` and ``""`` produce identical configs.
    """
    normalized_provider = provider if provider else ""
    return ResolvedAgentConfig(
        agent_id="agent-1",
        variant_id=None,
        model_provider=normalized_provider,
        model_name="resolved-model",
        prompt_version="v1",
        system_prompt="sys",
        user_prompt_template="usr",
        temperature=0.5,
        max_tokens=8192,
        context_window=0,
        input_token_limit=0,
        token_budget=0,
        timeout_seconds=60,
        max_retries=2,
    )
# ===================================================================
# 9.1 — Factory routing property
# **Validates: Requirements 3.4, 3.5, 9.5**
# ===================================================================


@given(st.sampled_from(["ollama", "vllm", "", None]))
@settings(max_examples=100)
def test_factory_routing_property(provider: str | None):
    """For all model_provider in {"ollama", "vllm", "", None}, factory returns correct client type.

    ``_make_resolved`` stores a falsy provider (``None`` or ``""``) as ``""``.
    The ``None`` example therefore reaches the factory as an empty string.

    **Validates: Requirements 3.4, 3.5, 9.5**
    """
    resolved = _make_resolved(provider)
    transport = httpx.MockTransport(lambda req: httpx.Response(200))
    http = httpx.AsyncClient(transport=transport)
    try:
        client = build_llm_client(
            resolved, _make_ollama_config(), _make_vllm_config(), http_client=http
        )
    finally:
        # No request is sent in this test. Even so, the AsyncClient opened for
        # each example should be closed explicitly rather than leaked.
        asyncio.run(http.aclose())

    if provider == "vllm":
        assert isinstance(client, VLLMClient), (
            f"Expected VLLMClient for provider={provider!r}, got {type(client).__name__}"
        )
    else:
        # "ollama", "", None all map to OllamaClient
        assert isinstance(client, OllamaClient), (
            f"Expected OllamaClient for provider={provider!r}, got {type(client).__name__}"
        )


# ===================================================================
# 9.2 — Error string format consistency property
# **Validates: Requirements 5.6**
# ===================================================================

# Status codes that _is_retryable must treat as permanent (non-retryable) failures.
_NON_RETRYABLE_STATUS_CODES = frozenset({400, 401, 403, 404, 422})


@given(st.integers(min_value=100, max_value=599))
@settings(max_examples=100)
def test_is_retryable_consistency_property(status_code: int):
    """For all HTTP status codes (100-599), _is_retryable() classifies them consistently.

    Non-retryable: 400, 401, 403, 404, 422.
    All other http_{code} errors are retryable.

    **Validates: Requirements 5.6**
    """
    error_str = f"http_{status_code}"
    result = _is_retryable(error_str)

    expected = status_code not in _NON_RETRYABLE_STATUS_CODES
    kind = "retryable" if expected else "non-retryable"
    # Report the actual return value. The helper might return a non-bool, and
    # "returned True/False" would then be misleading.
    assert result is expected, (
        f"{error_str} should be {kind} but _is_retryable returned {result!r}"
    )


# ===================================================================
# 9.3 — VLLMClient request payload structure property
# **Validates: Requirements 2.1, 8.1**
# ===================================================================


@given(
    system=st.text(min_size=1),
    user=st.text(min_size=1),
)
@settings(max_examples=100)
def test_vllm_payload_structure_property(system: str, user: str):
    """For all generated prompt dicts, payload contains required OpenAI fields and excludes Ollama-specific fields.

    **Validates: Requirements 2.1, 8.1**
    """
    prompts = {"system": system, "user": user}
    captured: dict = {}

    def handler(request: httpx.Request) -> httpx.Response:
        captured["payload"] = json.loads(request.content)
        body = {
            "choices": [
                {"message": {"role": "assistant", "content": "{}"}}
            ],
        }
        return httpx.Response(200, json=body)

    transport = httpx.MockTransport(handler)
    config = _make_vllm_config()

    async def _invoke() -> None:
        # ``async with`` guarantees the AsyncClient is closed, even if
        # call_llm raises.
        async with httpx.AsyncClient(transport=transport) as http:
            client = VLLMClient(config, http_client=http)
            await client.call_llm(prompts, {})

    asyncio.run(_invoke())

    assert "payload" in captured, "VLLMClient.call_llm did not send an HTTP request"
    payload = captured["payload"]

    # Required OpenAI fields must be present
    assert "model" in payload, "Payload missing 'model' field"
    assert "messages" in payload, "Payload missing 'messages' field"
    assert "max_tokens" in payload, "Payload missing 'max_tokens' field"
    assert "temperature" in payload, "Payload missing 'temperature' field"

    # Messages must have system and user roles
    roles = [m["role"] for m in payload["messages"]]
    assert "system" in roles, "Messages missing 'system' role"
    assert "user" in roles, "Messages missing 'user' role"

    # Ollama-specific fields must NOT be present
    assert "think" not in payload, "Payload contains Ollama-specific 'think' field"
    assert "stream" not in payload, "Payload contains Ollama-specific 'stream' field"
    assert "options" not in payload, "Payload contains Ollama-specific 'options' field"

    # Ollama option names must not leak to the top level of the payload either.
    for key in ("num_ctx", "num_predict"):
        assert key not in payload, f"Payload contains Ollama-specific '{key}' field"


# ===================================================================
# 9.4 — JSON repair idempotence property
# **Validates: Requirements 2.4**
# ===================================================================


@given(
    st.one_of(
        st.dictionaries(st.text(max_size=20), st.text(max_size=50), max_size=5),
        st.lists(st.integers(), max_size=10),
        st.text(max_size=50),
        st.integers(),
        st.floats(allow_nan=False, allow_infinity=False),
        st.booleans(),
        st.none(),
    )
)
@settings(max_examples=100)
def test_json_repair_idempotence_property(value):
    """For all valid JSON strings, _repair_json() is idempotent.

    _repair_json(_repair_json(json_str)) == _repair_json(json_str)

    **Validates: Requirements 2.4**
    """
    json_str = json.dumps(value)
    repaired_once = _repair_json(json_str)
    repaired_twice = _repair_json(repaired_once)
    assert repaired_once == repaired_twice, (
        f"_repair_json is not idempotent: "
        f"first={repaired_once!r}, second={repaired_twice!r}"
    )
    # The repaired output should be valid JSON
    json.loads(repaired_once)


# ===================================================================
# 9.5 — Markdown fence stripping round-trip property
# **Validates: Requirements 2.3**
# ===================================================================


# A payload that itself contains a ``` marker makes the fenced string
# ambiguous, because the inner marker could close the block early. Such
# payloads are therefore excluded from the property.
@given(st.text().filter(lambda t: "```" not in t))
@settings(max_examples=100)
def test_markdown_fence_stripping_roundtrip_property(s: str):
    """For all fence-free strings, wrapping in fences then stripping recovers the original.

    The regex trims leading/trailing whitespace around the content inside
    fences, so the round-trip recovers ``s.strip()``. Inputs that already
    contain a ``` marker are filtered out, because their fenced form has no
    single correct answer.

    **Validates: Requirements 2.3**
    """
    fenced = f"```json\n{s}\n```"
    stripped = _strip_markdown_fences(fenced)
    assert stripped == s.strip(), (
        f"Round-trip failed: original={s!r}, stripped={stripped!r}, expected={s.strip()!r}"
    )

    # Identity: ``s`` contains no fence marker at all, so it must be returned as-is.
    assert _strip_markdown_fences(s) == s


# ===================================================================
# 9.6 — VLLMConfig defaults property
# **Validates: Requirements 3.1**
# ===================================================================


@settings(max_examples=100)
@given(st.just(None))
def test_vllm_config_defaults_property(_):
    """For all default-constructed instances, invariants hold.

    timeout > 0, max_retries >= 0, 0 <= temperature <= 2, max_tokens > 0.

    **Validates: Requirements 3.1**
    """
    config = VLLMConfig()
    assert config.timeout > 0, f"timeout must be > 0, got {config.timeout}"
    assert config.max_retries >= 0, f"max_retries must be >= 0, got {config.max_retries}"
    assert 0 <= config.temperature <= 2, (
        f"temperature must be in [0, 2], got {config.temperature}"
    )
    assert config.max_tokens > 0, f"max_tokens must be > 0, got {config.max_tokens}"
    assert config.base_url, "base_url must be non-empty"
    assert config.model, "model must be non-empty"