"""Ollama client wrapper using structured output format. Sends documents to a local Ollama instance via the /api/chat endpoint with the ``format`` parameter set to the extraction JSON schema, ensuring the model returns schema-compliant JSON. Includes retry logic for invalid or incomplete model responses with exponential backoff, error classification, and full audit preservation. Requirements: 5.1, 5.2, 5.4 """ from __future__ import annotations import asyncio import json import logging import time from dataclasses import dataclass, field import httpx from services.extractor.prompts import ( build_extraction_prompt, get_json_schema, get_prompt_metadata, ) from services.extractor.schemas import ExtractionResult, ValidationReport, validate_extraction from services.shared.config import OllamaConfig logger = logging.getLogger("ollama_client") # Errors that should NOT be retried — the request itself is bad. _NON_RETRYABLE_ERRORS = frozenset({ "http_400", "http_401", "http_403", "http_404", "http_422", }) def _is_retryable(error: str | None) -> bool: """Determine whether an extraction error warrants a retry.""" if error is None: return False return error not in _NON_RETRYABLE_ERRORS @dataclass class ExtractionAttempt: """Record of a single extraction attempt for audit.""" raw_output: str = "" validation: ValidationReport | None = None error: str | None = None duration_ms: int = 0 model: str = "" retryable: bool = True @dataclass class ExtractionResponse: """Full response from an extraction call, including all attempts.""" success: bool = False result: ExtractionResult | None = None attempts: list[ExtractionAttempt] = field(default_factory=list) prompt_metadata: dict[str, str] = field(default_factory=dict) model: str = "" total_duration_ms: int = 0 def _compute_backoff( attempt_num: int, base_delay: float, max_delay: float, multiplier: float, ) -> float: """Compute exponential backoff delay for a given attempt number.""" delay = base_delay * (multiplier ** attempt_num) return min(delay, max_delay) class OllamaClient: """Async client for Ollama structured extraction. Usage:: config = OllamaConfig(base_url="http://localhost:11434", model="llama3.1:8b") client = OllamaClient(config) response = await client.extract( document_text="Apple reported record earnings...", document_type="article", document_id="abc-123", ) if response.success: print(response.result) """ _config: OllamaConfig _max_retries: int _base_delay: float _max_delay: float _backoff_multiplier: float _owns_client: bool _http: httpx.AsyncClient def __init__( self, config: OllamaConfig, max_retries: int | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: self._config = config self._max_retries = max_retries if max_retries is not None else config.max_retries self._base_delay = config.retry_base_delay self._max_delay = config.retry_max_delay self._backoff_multiplier = config.retry_backoff_multiplier self._owns_client = http_client is None self._http = http_client or httpx.AsyncClient(timeout=config.timeout) async def close(self) -> None: """Close the underlying HTTP client if we own it.""" if self._owns_client: await self._http.aclose() async def extract( self, document_text: str, document_type: str = "article", document_id: str = "", known_tickers: list[str] | None = None, ) -> ExtractionResponse: """Send a document to Ollama for structured intelligence extraction. Retries up to ``max_retries`` times when the model returns invalid or incomplete JSON. Uses exponential backoff between retries. Non-retryable errors (e.g. HTTP 400) stop retries immediately. Each attempt and its validation result are preserved for audit. Args: document_text: Normalized text content of the document. document_type: One of article, filing, transcript, press_release. document_id: Optional document ID for traceability. known_tickers: Optional ticker hints for the model. Returns: An ``ExtractionResponse`` with the parsed result on success. """ prompts = build_extraction_prompt( document_text=document_text, document_type=document_type, document_id=document_id, known_tickers=known_tickers, ) json_schema = get_json_schema() prompt_meta = get_prompt_metadata() response = ExtractionResponse( prompt_metadata=prompt_meta, model=self._config.model, ) total_start = time.monotonic() for attempt_num in range(self._max_retries + 1): attempt = await self._call_ollama(prompts, json_schema, document_text) response.attempts.append(attempt) if attempt.error is None and attempt.validation and attempt.validation.valid: response.success = True response.result = attempt.validation.parsed break # Check if the error is non-retryable — stop immediately if not _is_retryable(attempt.error): attempt.retryable = False logger.warning( "Non-retryable error for doc %s: %s — stopping retries", document_id or "unknown", attempt.error, ) break if attempt_num < self._max_retries: delay = _compute_backoff( attempt_num, self._base_delay, self._max_delay, self._backoff_multiplier, ) logger.warning( "Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs", attempt_num + 1, self._max_retries + 1, document_id or "unknown", attempt.error or "validation failed", delay, ) await asyncio.sleep(delay) response.total_duration_ms = int((time.monotonic() - total_start) * 1000) return response async def _call_ollama( self, prompts: dict[str, str], json_schema: dict[str, object], document_text: str = "", ) -> ExtractionAttempt: """Make a single call to the Ollama /api/chat endpoint.""" attempt = ExtractionAttempt(model=self._config.model) start = time.monotonic() payload = { "model": self._config.model, "messages": [ {"role": "system", "content": prompts["system"]}, {"role": "user", "content": prompts["user"]}, ], "format": json_schema, "stream": False, } url = f"{self._config.base_url}/api/chat" logger.info( "Ollama POST %s model=%s input_chars=%d", url, self._config.model, len(prompts.get("user", "")), ) try: resp = await self._http.post(url, json=payload) _ = resp.raise_for_status() except httpx.TimeoutException: attempt.error = "timeout" attempt.duration_ms = int((time.monotonic() - start) * 1000) return attempt except httpx.HTTPStatusError as exc: attempt.error = f"http_{exc.response.status_code}" attempt.retryable = _is_retryable(attempt.error) attempt.duration_ms = int((time.monotonic() - start) * 1000) return attempt except httpx.HTTPError as exc: attempt.error = f"connection_error: {exc}" attempt.duration_ms = int((time.monotonic() - start) * 1000) return attempt attempt.duration_ms = int((time.monotonic() - start) * 1000) # Parse the Ollama response envelope try: body: dict[str, object] = resp.json() except json.JSONDecodeError: attempt.error = "invalid_response_json" attempt.raw_output = resp.text return attempt msg = body.get("message") content: str = msg.get("content", "") if isinstance(msg, dict) else "" attempt.raw_output = content if not content: attempt.error = "empty_model_response" return attempt # Validate against extraction schema attempt.validation = validate_extraction(content, document_text=document_text) if not attempt.validation.valid: attempt.error = "; ".join(attempt.validation.errors) return attempt