c85c0068a2
- Replace all datetime.utcnow() with datetime.now(tz=timezone.utc) across 8 files - Fix 12 failing tests to match current implementation behavior - Fix pytest_plugins in non-top-level conftest (moved to root conftest.py) - Auto-fix 189 lint issues (import sorting, unused imports) - Add CI/CD pipeline infrastructure (ARC, ArgoCD, Kargo manifests) - Add values-beta.yaml and values-paper.yaml for staged deployments - Update GitHub Actions workflow to use self-hosted-gremlin runners - Add integration-test job to CI pipeline Result: 1596 passed, 0 failed, 0 warnings
391 lines
12 KiB
Python
391 lines
12 KiB
Python
"""Tests for the Ollama client wrapper."""
|
|
import json
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from services.extractor.client import (
|
|
OllamaClient,
|
|
_compute_backoff,
|
|
_is_retryable,
|
|
)
|
|
from services.shared.config import OllamaConfig
|
|
|
|
|
|
def _valid_extraction_json() -> str:
|
|
"""Minimal valid extraction result as JSON string."""
|
|
return json.dumps({
|
|
"summary": "Apple beat earnings expectations.",
|
|
"companies": [
|
|
{
|
|
"ticker": "AAPL",
|
|
"company_name": "Apple Inc.",
|
|
"relevance": 0.95,
|
|
"sentiment": "positive",
|
|
"impact_score": 0.7,
|
|
"impact_horizon": "1d_30d",
|
|
"catalyst_type": "earnings",
|
|
"key_facts": ["Revenue up 12%"],
|
|
"risks": [],
|
|
"evidence_spans": ["Apple beat expectations"],
|
|
}
|
|
],
|
|
"macro_themes": ["ai_capex"],
|
|
"novelty_score": 0.6,
|
|
"confidence": 0.85,
|
|
"extraction_warnings": [],
|
|
})
|
|
|
|
|
|
def _ollama_response(content: str) -> httpx.Response:
|
|
"""Build a fake Ollama /api/chat response."""
|
|
body = {"message": {"role": "assistant", "content": content}}
|
|
return httpx.Response(200, json=body)
|
|
|
|
|
|
def _make_config() -> OllamaConfig:
|
|
return OllamaConfig(
|
|
base_url="http://test:11434",
|
|
model="test-model",
|
|
timeout=10,
|
|
retry_base_delay=0.0,
|
|
retry_max_delay=0.0,
|
|
retry_backoff_multiplier=2.0,
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_success():
|
|
"""Successful extraction on first attempt."""
|
|
transport = httpx.MockTransport(
|
|
lambda req: _ollama_response(_valid_extraction_json())
|
|
)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), http_client=http)
|
|
|
|
resp = await client.extract(
|
|
document_text="Apple reported record Q4 earnings.",
|
|
document_type="article",
|
|
document_id="doc-1",
|
|
)
|
|
|
|
assert resp.success
|
|
assert resp.result is not None
|
|
assert resp.result.companies[0].ticker == "AAPL"
|
|
assert len(resp.attempts) == 1
|
|
assert resp.attempts[0].error is None
|
|
assert resp.model == "test-model"
|
|
assert resp.prompt_metadata["prompt_version"]
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_retry_on_invalid_json():
|
|
"""Client retries when model returns invalid JSON, then succeeds."""
|
|
call_count = 0
|
|
|
|
def handler(request: httpx.Request) -> httpx.Response:
|
|
nonlocal call_count
|
|
call_count += 1
|
|
if call_count == 1:
|
|
return _ollama_response("not valid json {{{")
|
|
return _ollama_response(_valid_extraction_json())
|
|
|
|
transport = httpx.MockTransport(handler)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), max_retries=2, http_client=http)
|
|
|
|
resp = await client.extract(document_text="test", document_type="article")
|
|
|
|
assert resp.success
|
|
assert len(resp.attempts) == 2
|
|
assert resp.attempts[0].error is not None
|
|
assert resp.attempts[1].error is None
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_all_retries_exhausted():
|
|
"""All retries fail — response indicates failure with all attempts recorded."""
|
|
transport = httpx.MockTransport(
|
|
lambda req: _ollama_response("bad output")
|
|
)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), max_retries=1, http_client=http)
|
|
|
|
resp = await client.extract(document_text="test", document_type="article")
|
|
|
|
assert not resp.success
|
|
assert resp.result is None
|
|
assert len(resp.attempts) == 2 # initial + 1 retry
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_http_timeout():
|
|
"""HTTP timeout is captured as an error."""
|
|
def handler(request: httpx.Request) -> httpx.Response:
|
|
raise httpx.ReadTimeout("timed out")
|
|
|
|
transport = httpx.MockTransport(handler)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), max_retries=0, http_client=http)
|
|
|
|
resp = await client.extract(document_text="test", document_type="article")
|
|
|
|
assert not resp.success
|
|
assert resp.attempts[0].error == "timeout"
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_http_500():
|
|
"""HTTP 500 is captured as an error."""
|
|
transport = httpx.MockTransport(
|
|
lambda req: httpx.Response(500, text="Internal Server Error")
|
|
)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), max_retries=0, http_client=http)
|
|
|
|
resp = await client.extract(document_text="test", document_type="article")
|
|
|
|
assert not resp.success
|
|
assert "500" in (resp.attempts[0].error or "")
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_empty_model_response():
|
|
"""Empty content from model is treated as an error."""
|
|
transport = httpx.MockTransport(
|
|
lambda req: _ollama_response("")
|
|
)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), max_retries=0, http_client=http)
|
|
|
|
resp = await client.extract(document_text="test", document_type="article")
|
|
|
|
assert not resp.success
|
|
assert resp.attempts[0].error == "empty_model_response"
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_schema_validation_failure():
|
|
"""Model returns valid JSON but missing required fields — normalization fills defaults."""
|
|
bad_extraction = json.dumps({"summary": "test"}) # missing companies, etc.
|
|
transport = httpx.MockTransport(
|
|
lambda req: _ollama_response(bad_extraction)
|
|
)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), max_retries=0, http_client=http)
|
|
|
|
resp = await client.extract(document_text="test", document_type="article")
|
|
|
|
# Normalization fills missing fields with defaults, so validation passes
|
|
assert resp.success
|
|
assert resp.result is not None
|
|
assert resp.attempts[0].validation is not None
|
|
assert resp.attempts[0].validation.valid
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_with_known_tickers():
|
|
"""Known tickers are passed through to the prompt builder."""
|
|
transport = httpx.MockTransport(
|
|
lambda req: _ollama_response(_valid_extraction_json())
|
|
)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), http_client=http)
|
|
|
|
resp = await client.extract(
|
|
document_text="test",
|
|
document_type="article",
|
|
known_tickers=["AAPL", "MSFT"],
|
|
)
|
|
|
|
assert resp.success
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_sends_structured_format():
|
|
"""The request payload includes think=False and stream=False (no format key due to Ollama bug #14645)."""
|
|
captured_payload: dict[str, object] = {}
|
|
|
|
def handler(request: httpx.Request) -> httpx.Response:
|
|
captured_payload.update(json.loads(request.content))
|
|
return _ollama_response(_valid_extraction_json())
|
|
|
|
transport = httpx.MockTransport(handler)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), http_client=http)
|
|
|
|
await client.extract(document_text="test", document_type="article")
|
|
|
|
# format key is intentionally omitted (Ollama bug #14645 with think=false)
|
|
assert "format" not in captured_payload
|
|
assert captured_payload["think"] is False
|
|
assert captured_payload["stream"] is False
|
|
assert captured_payload["model"] == "test-model"
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_non_retryable_http_400_stops_immediately():
|
|
"""HTTP 400 is non-retryable — client stops after first attempt."""
|
|
call_count = 0
|
|
|
|
def handler(request: httpx.Request) -> httpx.Response:
|
|
nonlocal call_count
|
|
call_count += 1
|
|
return httpx.Response(400, text="Bad Request")
|
|
|
|
transport = httpx.MockTransport(handler)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), max_retries=3, http_client=http)
|
|
|
|
resp = await client.extract(document_text="test", document_type="article")
|
|
|
|
assert not resp.success
|
|
assert len(resp.attempts) == 1 # no retries for 400
|
|
assert resp.attempts[0].error == "http_400"
|
|
assert not resp.attempts[0].retryable
|
|
assert call_count == 1
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_retryable_http_500_retries():
|
|
"""HTTP 500 is retryable — client retries up to max_retries."""
|
|
call_count = 0
|
|
|
|
def handler(request: httpx.Request) -> httpx.Response:
|
|
nonlocal call_count
|
|
call_count += 1
|
|
if call_count <= 2:
|
|
return httpx.Response(500, text="Internal Server Error")
|
|
return _ollama_response(_valid_extraction_json())
|
|
|
|
transport = httpx.MockTransport(handler)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), max_retries=3, http_client=http)
|
|
|
|
resp = await client.extract(document_text="test", document_type="article")
|
|
|
|
assert resp.success
|
|
assert len(resp.attempts) == 3
|
|
assert resp.attempts[0].retryable is True
|
|
assert resp.attempts[1].retryable is True
|
|
assert call_count == 3
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_retryable_field_on_success():
|
|
"""Successful attempt has retryable=True (default)."""
|
|
transport = httpx.MockTransport(
|
|
lambda req: _ollama_response(_valid_extraction_json())
|
|
)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(_make_config(), http_client=http)
|
|
|
|
resp = await client.extract(document_text="test", document_type="article")
|
|
|
|
assert resp.success
|
|
assert resp.attempts[0].retryable is True
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_backoff_is_called_between_retries():
|
|
"""asyncio.sleep is called with increasing delays between retries."""
|
|
config = OllamaConfig(
|
|
base_url="http://test:11434",
|
|
model="test-model",
|
|
timeout=10,
|
|
retry_base_delay=1.0,
|
|
retry_max_delay=10.0,
|
|
retry_backoff_multiplier=2.0,
|
|
)
|
|
transport = httpx.MockTransport(
|
|
lambda req: _ollama_response("bad output")
|
|
)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(config, max_retries=2, http_client=http)
|
|
|
|
with patch("services.extractor.client.asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
|
resp = await client.extract(document_text="test", document_type="article")
|
|
|
|
assert not resp.success
|
|
assert len(resp.attempts) == 3 # initial + 2 retries
|
|
assert mock_sleep.call_count == 2
|
|
# First backoff: 1.0 * 2^0 = 1.0
|
|
assert mock_sleep.call_args_list[0].args[0] == pytest.approx(1.0)
|
|
# Second backoff: 1.0 * 2^1 = 2.0
|
|
assert mock_sleep.call_args_list[1].args[0] == pytest.approx(2.0)
|
|
|
|
await client.close()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_extract_uses_config_max_retries():
|
|
"""Client uses max_retries from config when not overridden."""
|
|
config = OllamaConfig(
|
|
base_url="http://test:11434",
|
|
model="test-model",
|
|
timeout=10,
|
|
max_retries=1,
|
|
retry_base_delay=0.0,
|
|
)
|
|
transport = httpx.MockTransport(
|
|
lambda req: _ollama_response("bad output")
|
|
)
|
|
http = httpx.AsyncClient(transport=transport)
|
|
client = OllamaClient(config, http_client=http)
|
|
|
|
resp = await client.extract(document_text="test", document_type="article")
|
|
|
|
assert not resp.success
|
|
assert len(resp.attempts) == 2 # initial + 1 retry from config
|
|
|
|
await client.close()
|
|
|
|
|
|
def test_compute_backoff():
|
|
"""Backoff computation respects multiplier and max delay."""
|
|
assert _compute_backoff(0, 1.0, 10.0, 2.0) == 1.0
|
|
assert _compute_backoff(1, 1.0, 10.0, 2.0) == 2.0
|
|
assert _compute_backoff(2, 1.0, 10.0, 2.0) == 4.0
|
|
assert _compute_backoff(3, 1.0, 10.0, 2.0) == 8.0
|
|
assert _compute_backoff(4, 1.0, 10.0, 2.0) == 10.0 # capped at max
|
|
|
|
|
|
def test_is_retryable():
|
|
"""Error classification for retry decisions."""
|
|
assert _is_retryable("timeout") is True
|
|
assert _is_retryable("http_500") is True
|
|
assert _is_retryable("connection_error: refused") is True
|
|
assert _is_retryable("empty_model_response") is True
|
|
assert _is_retryable("invalid_response_json") is True
|
|
assert _is_retryable("http_400") is False
|
|
assert _is_retryable("http_401") is False
|
|
assert _is_retryable("http_403") is False
|
|
assert _is_retryable("http_404") is False
|
|
assert _is_retryable("http_422") is False
|
|
assert _is_retryable(None) is False
|