Files
stonks-oracle/tests/test_ollama_client.py
T

389 lines
12 KiB
Python

"""Tests for the Ollama client wrapper."""
import json
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from services.extractor.client import (
ExtractionResponse,
OllamaClient,
_compute_backoff,
_is_retryable,
)
from services.shared.config import OllamaConfig
def _valid_extraction_json() -> str:
"""Minimal valid extraction result as JSON string."""
return json.dumps({
"summary": "Apple beat earnings expectations.",
"companies": [
{
"ticker": "AAPL",
"company_name": "Apple Inc.",
"relevance": 0.95,
"sentiment": "positive",
"impact_score": 0.7,
"impact_horizon": "1d_30d",
"catalyst_type": "earnings",
"key_facts": ["Revenue up 12%"],
"risks": [],
"evidence_spans": ["Apple beat expectations"],
}
],
"macro_themes": ["ai_capex"],
"novelty_score": 0.6,
"confidence": 0.85,
"extraction_warnings": [],
})
def _ollama_response(content: str) -> httpx.Response:
    """Wrap *content* in a fake HTTP 200 Ollama ``/api/chat`` reply.

    The body has the shape ``{"message": {"role": "assistant", "content": ...}}``.
    """
    assistant_message = {"role": "assistant", "content": content}
    return httpx.Response(status_code=200, json={"message": assistant_message})
def _make_config() -> OllamaConfig:
    """Build the default test ``OllamaConfig`` pointing at a fake host.

    Base and max retry delays are both 0.0, so any backoff between retries
    is zero and retry tests do not actually wait.
    """
    settings = {
        "base_url": "http://test:11434",
        "model": "test-model",
        "timeout": 10,
        "retry_base_delay": 0.0,
        "retry_max_delay": 0.0,
        "retry_backoff_multiplier": 2.0,
    }
    return OllamaConfig(**settings)
@pytest.mark.asyncio
async def test_extract_success():
    """Successful extraction on first attempt.

    Checks the parsed result, that exactly one error-free attempt was
    recorded, the reported model name, and that prompt metadata carries a
    prompt version.
    """
    transport = httpx.MockTransport(
        lambda req: _ollama_response(_valid_extraction_json())
    )
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), http_client=http)
    try:
        resp = await client.extract(
            document_text="Apple reported record Q4 earnings.",
            document_type="article",
            document_id="doc-1",
        )
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert resp.success
    assert resp.result is not None
    assert resp.result.companies[0].ticker == "AAPL"
    assert len(resp.attempts) == 1
    assert resp.attempts[0].error is None
    assert resp.model == "test-model"
    assert resp.prompt_metadata["prompt_version"]
@pytest.mark.asyncio
async def test_extract_retry_on_invalid_json():
    """Client retries when model returns invalid JSON, then succeeds.

    The first HTTP call returns unparseable content and the second a valid
    payload, so exactly two attempts and two HTTP calls are expected.
    """
    call_count = 0

    def handler(request: httpx.Request) -> httpx.Response:
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            return _ollama_response("not valid json {{{")
        return _ollama_response(_valid_extraction_json())

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), max_retries=2, http_client=http)
    try:
        resp = await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert resp.success
    assert len(resp.attempts) == 2
    assert resp.attempts[0].error is not None
    assert resp.attempts[1].error is None
    # max_retries=2 would allow a third call; retrying must stop at the first success.
    assert call_count == 2
@pytest.mark.asyncio
async def test_extract_all_retries_exhausted():
    """All retries fail — response indicates failure with all attempts recorded.

    Every attempt receives non-JSON model output, so each recorded attempt
    should carry an error.
    """
    transport = httpx.MockTransport(
        lambda req: _ollama_response("bad output")
    )
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), max_retries=1, http_client=http)
    try:
        resp = await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert not resp.success
    assert resp.result is None
    assert len(resp.attempts) == 2  # initial + 1 retry
    assert all(attempt.error is not None for attempt in resp.attempts)
@pytest.mark.asyncio
async def test_extract_http_timeout():
    """HTTP timeout is captured as an error.

    With ``max_retries=0`` only a single attempt is made, and its error is
    recorded as ``"timeout"``.
    """
    def handler(request: httpx.Request) -> httpx.Response:
        # Attach the request, as a real transport does when it raises.
        raise httpx.ReadTimeout("timed out", request=request)

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), max_retries=0, http_client=http)
    try:
        resp = await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert not resp.success
    assert len(resp.attempts) == 1
    assert resp.attempts[0].error == "timeout"
@pytest.mark.asyncio
async def test_extract_http_500():
    """HTTP 500 is captured as an error.

    With ``max_retries=0`` the retryable 500 still yields only one attempt.
    """
    transport = httpx.MockTransport(
        lambda req: httpx.Response(500, text="Internal Server Error")
    )
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), max_retries=0, http_client=http)
    try:
        resp = await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert not resp.success
    assert len(resp.attempts) == 1
    assert "500" in (resp.attempts[0].error or "")
@pytest.mark.asyncio
async def test_extract_empty_model_response():
    """Empty content from model is treated as an error.

    The single attempt (``max_retries=0``) records ``"empty_model_response"``.
    """
    transport = httpx.MockTransport(
        lambda req: _ollama_response("")
    )
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), max_retries=0, http_client=http)
    try:
        resp = await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert not resp.success
    assert len(resp.attempts) == 1
    assert resp.attempts[0].error == "empty_model_response"
@pytest.mark.asyncio
async def test_extract_schema_validation_failure():
    """Model returns valid JSON but missing required fields.

    The attempt must carry a failing validation report, and no result may be
    returned.
    """
    bad_extraction = json.dumps({"summary": "test"})  # missing companies, etc.
    transport = httpx.MockTransport(
        lambda req: _ollama_response(bad_extraction)
    )
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), max_retries=0, http_client=http)
    try:
        resp = await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert not resp.success
    assert resp.result is None
    assert resp.attempts[0].validation is not None
    assert not resp.attempts[0].validation.valid
@pytest.mark.asyncio
async def test_extract_with_known_tickers():
    """Known tickers are passed through to the prompt builder.

    The request body sent to Ollama is captured so the test checks that the
    tickers actually reach the model, not merely that extraction succeeds.
    """
    captured_bodies: list[bytes] = []

    def handler(request: httpx.Request) -> httpx.Response:
        captured_bodies.append(request.content)
        return _ollama_response(_valid_extraction_json())

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), http_client=http)
    try:
        resp = await client.extract(
            document_text="test",
            document_type="article",
            known_tickers=["AAPL", "MSFT"],
        )
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert resp.success
    assert len(captured_bodies) == 1
    # NOTE(review): assumes the prompt builder renders known tickers into the
    # chat messages. Neither ticker appears in the document text ("test"), so
    # their presence in the request can only come from known_tickers.
    assert b"AAPL" in captured_bodies[0]
    assert b"MSFT" in captured_bodies[0]
@pytest.mark.asyncio
async def test_extract_sends_structured_format():
    """The request payload includes the JSON schema in the format field.

    It also checks that streaming is disabled and that the configured model
    name is sent.
    """
    captured_payload: dict[str, object] = {}

    def handler(request: httpx.Request) -> httpx.Response:
        captured_payload.update(json.loads(request.content))
        return _ollama_response(_valid_extraction_json())

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), http_client=http)
    try:
        await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert "format" in captured_payload
    assert isinstance(captured_payload["format"], dict)
    assert captured_payload["stream"] is False
    assert captured_payload["model"] == "test-model"
@pytest.mark.asyncio
async def test_extract_non_retryable_http_400_stops_immediately():
    """HTTP 400 is non-retryable — client stops after first attempt.

    ``max_retries=3`` would allow four calls, so a single HTTP call proves the
    client did not retry.
    """
    call_count = 0

    def handler(request: httpx.Request) -> httpx.Response:
        nonlocal call_count
        call_count += 1
        return httpx.Response(400, text="Bad Request")

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), max_retries=3, http_client=http)
    try:
        resp = await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert not resp.success
    assert len(resp.attempts) == 1  # no retries for 400
    assert resp.attempts[0].error == "http_400"
    assert not resp.attempts[0].retryable
    assert call_count == 1
@pytest.mark.asyncio
async def test_extract_retryable_http_500_retries():
    """HTTP 500 is retryable — client retries up to max_retries.

    The first two calls return 500 and the third succeeds, so three attempts
    are recorded and only the last one is error-free.
    """
    call_count = 0

    def handler(request: httpx.Request) -> httpx.Response:
        nonlocal call_count
        call_count += 1
        if call_count <= 2:
            return httpx.Response(500, text="Internal Server Error")
        return _ollama_response(_valid_extraction_json())

    transport = httpx.MockTransport(handler)
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), max_retries=3, http_client=http)
    try:
        resp = await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert resp.success
    assert len(resp.attempts) == 3
    assert resp.attempts[0].retryable is True
    assert resp.attempts[1].retryable is True
    assert resp.attempts[2].error is None
    assert call_count == 3
@pytest.mark.asyncio
async def test_extract_retryable_field_on_success():
    """Successful attempt has retryable=True (default)."""
    transport = httpx.MockTransport(
        lambda req: _ollama_response(_valid_extraction_json())
    )
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(_make_config(), http_client=http)
    try:
        resp = await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert resp.success
    assert resp.attempts[0].retryable is True
@pytest.mark.asyncio
async def test_extract_backoff_is_called_between_retries():
    """asyncio.sleep is awaited with increasing delays between retries.

    The client module's ``asyncio.sleep`` is replaced with an ``AsyncMock``,
    so the test runs instantly while recording the requested delays.
    """
    config = OllamaConfig(
        base_url="http://test:11434",
        model="test-model",
        timeout=10,
        retry_base_delay=1.0,
        retry_max_delay=10.0,
        retry_backoff_multiplier=2.0,
    )
    transport = httpx.MockTransport(
        lambda req: _ollama_response("bad output")
    )
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(config, max_retries=2, http_client=http)
    try:
        with patch(
            "services.extractor.client.asyncio.sleep", new_callable=AsyncMock
        ) as mock_sleep:
            resp = await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert not resp.success
    assert len(resp.attempts) == 3  # initial + 2 retries
    assert mock_sleep.call_count == 2
    # call_count alone would pass if the sleep coroutine were created but never awaited.
    assert mock_sleep.await_count == 2
    # First backoff: 1.0 * 2^0 = 1.0
    assert mock_sleep.call_args_list[0].args[0] == pytest.approx(1.0)
    # Second backoff: 1.0 * 2^1 = 2.0
    assert mock_sleep.call_args_list[1].args[0] == pytest.approx(2.0)
@pytest.mark.asyncio
async def test_extract_uses_config_max_retries():
    """Client uses max_retries from config when not overridden.

    No ``max_retries`` is passed to ``OllamaClient``, so the config's value
    of 1 should produce exactly two attempts.
    """
    config = OllamaConfig(
        base_url="http://test:11434",
        model="test-model",
        timeout=10,
        max_retries=1,
        retry_base_delay=0.0,
    )
    transport = httpx.MockTransport(
        lambda req: _ollama_response("bad output")
    )
    http = httpx.AsyncClient(transport=transport)
    client = OllamaClient(config, http_client=http)
    try:
        resp = await client.extract(document_text="test", document_type="article")
    finally:
        # Close unconditionally so a failing call does not leak the client.
        await client.close()
    assert not resp.success
    assert len(resp.attempts) == 2  # initial + 1 retry from config
def test_compute_backoff():
    """Backoff grows by the multiplier per attempt and is clamped at the max delay."""
    base_delay, max_delay, multiplier = 1.0, 10.0, 2.0
    # Attempt 4 would be 16.0 uncapped; it must be clamped to max_delay.
    expected_delays = [1.0, 2.0, 4.0, 8.0, 10.0]
    for attempt, expected in enumerate(expected_delays):
        assert _compute_backoff(attempt, base_delay, max_delay, multiplier) == expected
def test_is_retryable():
    """Transient failures are retryable; 4xx client errors and ``None`` are not."""
    transient_errors = (
        "timeout",
        "http_500",
        "connection_error: refused",
        "empty_model_response",
        "invalid_response_json",
    )
    permanent_errors = (
        "http_400",
        "http_401",
        "http_403",
        "http_404",
        "http_422",
        None,
    )
    for error in transient_errors:
        assert _is_retryable(error) is True
    for error in permanent_errors:
        assert _is_retryable(error) is False