"""Tests for the Ollama client wrapper.""" import json from unittest.mock import AsyncMock, patch import httpx import pytest from services.extractor.client import ( OllamaClient, _compute_backoff, _is_retryable, ) from services.shared.config import OllamaConfig def _valid_extraction_json() -> str: """Minimal valid extraction result as JSON string.""" return json.dumps({ "summary": "Apple beat earnings expectations.", "companies": [ { "ticker": "AAPL", "company_name": "Apple Inc.", "relevance": 0.95, "sentiment": "positive", "impact_score": 0.7, "impact_horizon": "1d_30d", "catalyst_type": "earnings", "key_facts": ["Revenue up 12%"], "risks": [], "evidence_spans": ["Apple beat expectations"], } ], "macro_themes": ["ai_capex"], "novelty_score": 0.6, "confidence": 0.85, "extraction_warnings": [], }) def _ollama_response(content: str) -> httpx.Response: """Build a fake Ollama /api/chat response.""" body = {"message": {"role": "assistant", "content": content}} return httpx.Response(200, json=body) def _make_config() -> OllamaConfig: return OllamaConfig( base_url="http://test:11434", model="test-model", timeout=10, retry_base_delay=0.0, retry_max_delay=0.0, retry_backoff_multiplier=2.0, ) @pytest.mark.asyncio async def test_extract_success(): """Successful extraction on first attempt.""" transport = httpx.MockTransport( lambda req: _ollama_response(_valid_extraction_json()) ) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), http_client=http) resp = await client.extract( document_text="Apple reported record Q4 earnings.", document_type="article", document_id="doc-1", ) assert resp.success assert resp.result is not None assert resp.result.companies[0].ticker == "AAPL" assert len(resp.attempts) == 1 assert resp.attempts[0].error is None assert resp.model == "test-model" assert resp.prompt_metadata["prompt_version"] await client.close() @pytest.mark.asyncio async def test_extract_retry_on_invalid_json(): """Client retries when model returns invalid JSON, then succeeds.""" call_count = 0 def handler(request: httpx.Request) -> httpx.Response: nonlocal call_count call_count += 1 if call_count == 1: return _ollama_response("not valid json {{{") return _ollama_response(_valid_extraction_json()) transport = httpx.MockTransport(handler) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), max_retries=2, http_client=http) resp = await client.extract(document_text="test", document_type="article") assert resp.success assert len(resp.attempts) == 2 assert resp.attempts[0].error is not None assert resp.attempts[1].error is None await client.close() @pytest.mark.asyncio async def test_extract_all_retries_exhausted(): """All retries fail — response indicates failure with all attempts recorded.""" transport = httpx.MockTransport( lambda req: _ollama_response("bad output") ) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), max_retries=1, http_client=http) resp = await client.extract(document_text="test", document_type="article") assert not resp.success assert resp.result is None assert len(resp.attempts) == 2 # initial + 1 retry await client.close() @pytest.mark.asyncio async def test_extract_http_timeout(): """HTTP timeout is captured as an error.""" def handler(request: httpx.Request) -> httpx.Response: raise httpx.ReadTimeout("timed out") transport = httpx.MockTransport(handler) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), max_retries=0, http_client=http) resp = await client.extract(document_text="test", document_type="article") assert not resp.success assert resp.attempts[0].error == "timeout" await client.close() @pytest.mark.asyncio async def test_extract_http_500(): """HTTP 500 is captured as an error.""" transport = httpx.MockTransport( lambda req: httpx.Response(500, text="Internal Server Error") ) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), max_retries=0, http_client=http) resp = await client.extract(document_text="test", document_type="article") assert not resp.success assert "500" in (resp.attempts[0].error or "") await client.close() @pytest.mark.asyncio async def test_extract_empty_model_response(): """Empty content from model is treated as an error.""" transport = httpx.MockTransport( lambda req: _ollama_response("") ) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), max_retries=0, http_client=http) resp = await client.extract(document_text="test", document_type="article") assert not resp.success assert resp.attempts[0].error == "empty_model_response" await client.close() @pytest.mark.asyncio async def test_extract_schema_validation_failure(): """Model returns valid JSON but missing required fields — normalization fills defaults.""" bad_extraction = json.dumps({"summary": "test"}) # missing companies, etc. transport = httpx.MockTransport( lambda req: _ollama_response(bad_extraction) ) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), max_retries=0, http_client=http) resp = await client.extract(document_text="test", document_type="article") # Normalization fills missing fields with defaults, so validation passes assert resp.success assert resp.result is not None assert resp.attempts[0].validation is not None assert resp.attempts[0].validation.valid await client.close() @pytest.mark.asyncio async def test_extract_with_known_tickers(): """Known tickers are passed through to the prompt builder.""" transport = httpx.MockTransport( lambda req: _ollama_response(_valid_extraction_json()) ) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), http_client=http) resp = await client.extract( document_text="test", document_type="article", known_tickers=["AAPL", "MSFT"], ) assert resp.success await client.close() @pytest.mark.asyncio async def test_extract_sends_structured_format(): """The request payload includes think=False and stream=False (no format key due to Ollama bug #14645).""" captured_payload: dict[str, object] = {} def handler(request: httpx.Request) -> httpx.Response: captured_payload.update(json.loads(request.content)) return _ollama_response(_valid_extraction_json()) transport = httpx.MockTransport(handler) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), http_client=http) await client.extract(document_text="test", document_type="article") # format key is intentionally omitted (Ollama bug #14645 with think=false) assert "format" not in captured_payload assert captured_payload["think"] is False assert captured_payload["stream"] is False assert captured_payload["model"] == "test-model" await client.close() @pytest.mark.asyncio async def test_extract_non_retryable_http_400_stops_immediately(): """HTTP 400 is non-retryable — client stops after first attempt.""" call_count = 0 def handler(request: httpx.Request) -> httpx.Response: nonlocal call_count call_count += 1 return httpx.Response(400, text="Bad Request") transport = httpx.MockTransport(handler) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), max_retries=3, http_client=http) resp = await client.extract(document_text="test", document_type="article") assert not resp.success assert len(resp.attempts) == 1 # no retries for 400 assert resp.attempts[0].error == "http_400" assert not resp.attempts[0].retryable assert call_count == 1 await client.close() @pytest.mark.asyncio async def test_extract_retryable_http_500_retries(): """HTTP 500 is retryable — client retries up to max_retries.""" call_count = 0 def handler(request: httpx.Request) -> httpx.Response: nonlocal call_count call_count += 1 if call_count <= 2: return httpx.Response(500, text="Internal Server Error") return _ollama_response(_valid_extraction_json()) transport = httpx.MockTransport(handler) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), max_retries=3, http_client=http) resp = await client.extract(document_text="test", document_type="article") assert resp.success assert len(resp.attempts) == 3 assert resp.attempts[0].retryable is True assert resp.attempts[1].retryable is True assert call_count == 3 await client.close() @pytest.mark.asyncio async def test_extract_retryable_field_on_success(): """Successful attempt has retryable=True (default).""" transport = httpx.MockTransport( lambda req: _ollama_response(_valid_extraction_json()) ) http = httpx.AsyncClient(transport=transport) client = OllamaClient(_make_config(), http_client=http) resp = await client.extract(document_text="test", document_type="article") assert resp.success assert resp.attempts[0].retryable is True await client.close() @pytest.mark.asyncio async def test_extract_backoff_is_called_between_retries(): """asyncio.sleep is called with increasing delays between retries.""" config = OllamaConfig( base_url="http://test:11434", model="test-model", timeout=10, retry_base_delay=1.0, retry_max_delay=10.0, retry_backoff_multiplier=2.0, ) transport = httpx.MockTransport( lambda req: _ollama_response("bad output") ) http = httpx.AsyncClient(transport=transport) client = OllamaClient(config, max_retries=2, http_client=http) with patch("services.extractor.client.asyncio.sleep", new_callable=AsyncMock) as mock_sleep: resp = await client.extract(document_text="test", document_type="article") assert not resp.success assert len(resp.attempts) == 3 # initial + 2 retries assert mock_sleep.call_count == 2 # First backoff: 1.0 * 2^0 = 1.0 assert mock_sleep.call_args_list[0].args[0] == pytest.approx(1.0) # Second backoff: 1.0 * 2^1 = 2.0 assert mock_sleep.call_args_list[1].args[0] == pytest.approx(2.0) await client.close() @pytest.mark.asyncio async def test_extract_uses_config_max_retries(): """Client uses max_retries from config when not overridden.""" config = OllamaConfig( base_url="http://test:11434", model="test-model", timeout=10, max_retries=1, retry_base_delay=0.0, ) transport = httpx.MockTransport( lambda req: _ollama_response("bad output") ) http = httpx.AsyncClient(transport=transport) client = OllamaClient(config, http_client=http) resp = await client.extract(document_text="test", document_type="article") assert not resp.success assert len(resp.attempts) == 2 # initial + 1 retry from config await client.close() def test_compute_backoff(): """Backoff computation respects multiplier and max delay.""" assert _compute_backoff(0, 1.0, 10.0, 2.0) == 1.0 assert _compute_backoff(1, 1.0, 10.0, 2.0) == 2.0 assert _compute_backoff(2, 1.0, 10.0, 2.0) == 4.0 assert _compute_backoff(3, 1.0, 10.0, 2.0) == 8.0 assert _compute_backoff(4, 1.0, 10.0, 2.0) == 10.0 # capped at max def test_is_retryable(): """Error classification for retry decisions.""" assert _is_retryable("timeout") is True assert _is_retryable("http_500") is True assert _is_retryable("connection_error: refused") is True assert _is_retryable("empty_model_response") is True assert _is_retryable("invalid_response_json") is True assert _is_retryable("http_400") is False assert _is_retryable("http_401") is False assert _is_retryable("http_403") is False assert _is_retryable("http_404") is False assert _is_retryable("http_422") is False assert _is_retryable(None) is False