"""Ollama client wrapper using structured output format. Sends documents to a local Ollama instance via the /api/chat endpoint with the ``format`` parameter set to the extraction JSON schema, ensuring the model returns schema-compliant JSON. Includes retry logic for invalid or incomplete model responses with exponential backoff, error classification, and full audit preservation. Requirements: 5.1, 5.2, 5.4 """ from __future__ import annotations import asyncio import json import logging import time from dataclasses import dataclass, field import httpx from services.extractor.prompts import ( build_extraction_prompt, get_json_schema, get_prompt_metadata, ) from services.extractor.schemas import ExtractionResult, ValidationReport, validate_extraction from services.shared.config import OllamaConfig logger = logging.getLogger("ollama_client") # Errors that should NOT be retried — the request itself is bad. _NON_RETRYABLE_ERRORS = frozenset({ "http_400", "http_401", "http_403", "http_404", "http_422", }) def _is_retryable(error: str | None) -> bool: """Determine whether an extraction error warrants a retry.""" if error is None: return False return error not in _NON_RETRYABLE_ERRORS @dataclass class ExtractionAttempt: """Record of a single extraction attempt for audit.""" raw_output: str = "" validation: ValidationReport | None = None error: str | None = None duration_ms: int = 0 model: str = "" retryable: bool = True @dataclass class ExtractionResponse: """Full response from an extraction call, including all attempts.""" success: bool = False result: ExtractionResult | None = None attempts: list[ExtractionAttempt] = field(default_factory=list) prompt_metadata: dict[str, str] = field(default_factory=dict) model: str = "" total_duration_ms: int = 0 def _compute_backoff( attempt_num: int, base_delay: float, max_delay: float, multiplier: float, ) -> float: """Compute exponential backoff delay for a given attempt number.""" delay = base_delay * (multiplier ** attempt_num) return min(delay, max_delay) class OllamaClient: """Async client for Ollama structured extraction. Usage:: config = OllamaConfig(base_url="http://localhost:11434", model="llama3.1:8b") client = OllamaClient(config) response = await client.extract( document_text="Apple reported record earnings...", document_type="article", document_id="abc-123", ) if response.success: print(response.result) """ _config: OllamaConfig _max_retries: int _base_delay: float _max_delay: float _backoff_multiplier: float _owns_client: bool _http: httpx.AsyncClient def __init__( self, config: OllamaConfig, max_retries: int | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: self._config = config self._max_retries = max_retries if max_retries is not None else config.max_retries self._base_delay = config.retry_base_delay self._max_delay = config.retry_max_delay self._backoff_multiplier = config.retry_backoff_multiplier self._owns_client = http_client is None self._http = http_client or httpx.AsyncClient( timeout=httpx.Timeout(config.timeout, read=config.timeout), ) async def close(self) -> None: """Close the underlying HTTP client if we own it.""" if self._owns_client: await self._http.aclose() async def extract( self, document_text: str, document_type: str = "article", document_id: str = "", known_tickers: list[str] | None = None, ) -> ExtractionResponse: """Send a document to Ollama for structured intelligence extraction. Retries up to ``max_retries`` times when the model returns invalid or incomplete JSON. Uses exponential backoff between retries. Non-retryable errors (e.g. HTTP 400) stop retries immediately. Each attempt and its validation result are preserved for audit. Args: document_text: Normalized text content of the document. document_type: One of article, filing, transcript, press_release. document_id: Optional document ID for traceability. known_tickers: Optional ticker hints for the model. Returns: An ``ExtractionResponse`` with the parsed result on success. """ prompts = build_extraction_prompt( document_text=document_text, document_type=document_type, document_id=document_id, known_tickers=known_tickers, ) json_schema = get_json_schema() prompt_meta = get_prompt_metadata() response = ExtractionResponse( prompt_metadata=prompt_meta, model=self._config.model, ) total_start = time.monotonic() for attempt_num in range(self._max_retries + 1): attempt = await self._call_ollama(prompts, json_schema, document_text) response.attempts.append(attempt) if attempt.error is None and attempt.validation and attempt.validation.valid: response.success = True response.result = attempt.validation.parsed break # Check if the error is non-retryable — stop immediately if not _is_retryable(attempt.error): attempt.retryable = False logger.warning( "Non-retryable error for doc %s: %s — stopping retries", document_id or "unknown", attempt.error, ) break if attempt_num < self._max_retries: delay = _compute_backoff( attempt_num, self._base_delay, self._max_delay, self._backoff_multiplier, ) logger.warning( "Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs", attempt_num + 1, self._max_retries + 1, document_id or "unknown", attempt.error or "validation failed", delay, ) await asyncio.sleep(delay) response.total_duration_ms = int((time.monotonic() - total_start) * 1000) return response async def _call_ollama( self, prompts: dict[str, str], json_schema: dict[str, object], document_text: str = "", ) -> ExtractionAttempt: """Make a streaming call to Ollama with early-termination guardrails. Aborts the stream if: - Total generated tokens exceed ``max_tokens`` - No new chunk arrives within ``stall_timeout`` seconds - Repetition loop detected in the last ``loop_window`` tokens """ attempt = ExtractionAttempt(model=self._config.model) start = time.monotonic() payload = { "model": self._config.model, "messages": [ {"role": "system", "content": prompts["system"]}, {"role": "user", "content": prompts["user"]}, ], "format": json_schema, "stream": True, # NOTE: Do NOT set "think": False here. Ollama has a known bug # (issues #14645, #15260) where think=false silently disables # the format constraint for qwen3.5 and gemma4 models, causing # the model to output plain text instead of valid JSON. # Omitting "think" lets the model use thinking mode (slightly # slower but structured output actually works). } url = f"{self._config.base_url}/api/chat" logger.info( "Ollama POST %s model=%s input_chars=%d (streaming)", url, self._config.model, len(prompts.get("user", "")), ) try: req = self._http.build_request("POST", url, json=payload) resp = await self._http.send(req, stream=True) resp.raise_for_status() except httpx.TimeoutException: attempt.error = "timeout" attempt.duration_ms = int((time.monotonic() - start) * 1000) return attempt except httpx.HTTPStatusError as exc: attempt.error = f"http_{exc.response.status_code}" attempt.retryable = _is_retryable(attempt.error) attempt.duration_ms = int((time.monotonic() - start) * 1000) return attempt except httpx.HTTPError as exc: attempt.error = f"connection_error: {exc}" attempt.duration_ms = int((time.monotonic() - start) * 1000) return attempt # Stream and accumulate with guardrails chunks: list[str] = [] token_count = 0 last_chunk_time = time.monotonic() abort_reason: str | None = None try: async for line in resp.aiter_lines(): if not line: continue try: frame = json.loads(line) except json.JSONDecodeError: continue if frame.get("done"): break msg = frame.get("message", {}) token = msg.get("content", "") if isinstance(msg, dict) else "" # During thinking mode, the model emits tokens in msg.thinking # before msg.content. We don't accumulate thinking tokens but # must update last_chunk_time so the stall guard doesn't fire. thinking_token = msg.get("thinking", "") if isinstance(msg, dict) else "" if thinking_token: last_chunk_time = time.monotonic() if not token: continue chunks.append(token) token_count += 1 last_chunk_time = time.monotonic() # Guard: max tokens if token_count > self._config.max_tokens: abort_reason = f"max_tokens_exceeded ({token_count})" break # Guard: repetition loop detection if token_count >= self._config.loop_window: window = chunks[-self._config.loop_window:] unique_ratio = len(set(window)) / len(window) if unique_ratio < self._config.loop_threshold: abort_reason = f"repetition_loop (unique_ratio={unique_ratio:.2f})" break # Guard: stall detection (check between chunks) elapsed_since_last = time.monotonic() - last_chunk_time if elapsed_since_last > self._config.stall_timeout: abort_reason = "stall_timeout" break except httpx.ReadTimeout: abort_reason = "read_timeout" finally: await resp.aclose() attempt.duration_ms = int((time.monotonic() - start) * 1000) if abort_reason: logger.warning( "Stream aborted after %d tokens: %s", token_count, abort_reason, ) attempt.error = abort_reason attempt.raw_output = "".join(chunks) return attempt content = "".join(chunks) attempt.raw_output = content if not content: attempt.error = "empty_model_response" return attempt # Validate against extraction schema attempt.validation = validate_extraction(content, document_text=document_text) if not attempt.validation.valid: attempt.error = "; ".join(attempt.validation.errors) return attempt