"""Ollama client wrapper for document intelligence extraction. Sends documents to a local Ollama instance via the /api/chat endpoint with think=false for speed. Uses json-repair to fix common JSON syntax issues in model output since the Ollama format constraint is broken with think=false on qwen3.5 models (Ollama bug #14645). Includes retry logic for invalid or incomplete model responses with exponential backoff, error classification, and full audit preservation. Requirements: 5.1, 5.2, 5.4 """ from __future__ import annotations import asyncio import json import logging import re import time from dataclasses import dataclass, field import httpx from json_repair import repair_json from services.extractor.prompts import ( build_extraction_prompt, get_json_schema, get_prompt_metadata, ) from services.extractor.schemas import ExtractionResult, ValidationReport, validate_extraction from services.shared.config import OllamaConfig logger = logging.getLogger("ollama_client") # Errors that should NOT be retried — the request itself is bad. _NON_RETRYABLE_ERRORS = frozenset({ "http_400", "http_401", "http_403", "http_404", "http_422", }) def _is_retryable(error: str | None) -> bool: """Determine whether an extraction error warrants a retry.""" if error is None: return False return error not in _NON_RETRYABLE_ERRORS _FENCE_RE = re.compile(r"^```(?:json)?\s*\n?(.*?)\n?\s*```\s*$", re.DOTALL) def _strip_markdown_fences(text: str) -> str: """Remove ```json ... ``` wrappers if present.""" m = _FENCE_RE.match(text.strip()) return m.group(1) if m else text def _repair_json(text: str) -> str: """Try json.loads first; if it fails, repair with json-repair.""" try: json.loads(text) return text # already valid except (json.JSONDecodeError, ValueError): pass try: repaired = repair_json(text, return_objects=False) logger.info("JSON repaired successfully (%d -> %d chars)", len(text), len(repaired)) return repaired except Exception: logger.warning("JSON repair failed, returning original text") return text @dataclass class ExtractionAttempt: """Record of a single extraction attempt for audit.""" raw_output: str = "" validation: ValidationReport | None = None error: str | None = None duration_ms: int = 0 model: str = "" retryable: bool = True @dataclass class ExtractionResponse: """Full response from an extraction call, including all attempts.""" success: bool = False result: ExtractionResult | None = None attempts: list[ExtractionAttempt] = field(default_factory=list) prompt_metadata: dict[str, str] = field(default_factory=dict) model: str = "" total_duration_ms: int = 0 def _compute_backoff( attempt_num: int, base_delay: float, max_delay: float, multiplier: float, ) -> float: """Compute exponential backoff delay for a given attempt number.""" delay = base_delay * (multiplier ** attempt_num) return min(delay, max_delay) class OllamaClient: """Async client for Ollama structured extraction. Usage:: config = OllamaConfig(base_url="http://localhost:11434", model="llama3.1:8b") client = OllamaClient(config) response = await client.extract( document_text="Apple reported record earnings...", document_type="article", document_id="abc-123", ) if response.success: print(response.result) """ _config: OllamaConfig _max_retries: int _base_delay: float _max_delay: float _backoff_multiplier: float _owns_client: bool _http: httpx.AsyncClient def __init__( self, config: OllamaConfig, max_retries: int | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: self._config = config self._max_retries = max_retries if max_retries is not None else config.max_retries self._base_delay = config.retry_base_delay self._max_delay = config.retry_max_delay self._backoff_multiplier = config.retry_backoff_multiplier self._owns_client = http_client is None self._http = http_client or httpx.AsyncClient( timeout=httpx.Timeout(config.timeout, read=config.timeout), ) async def close(self) -> None: """Close the underlying HTTP client if we own it.""" if self._owns_client: await self._http.aclose() async def extract( self, document_text: str, document_type: str = "article", document_id: str = "", known_tickers: list[str] | None = None, ) -> ExtractionResponse: """Send a document to Ollama for structured intelligence extraction. Retries up to ``max_retries`` times when the model returns invalid or incomplete JSON. Uses exponential backoff between retries. Non-retryable errors (e.g. HTTP 400) stop retries immediately. Each attempt and its validation result are preserved for audit. Args: document_text: Normalized text content of the document. document_type: One of article, filing, transcript, press_release. document_id: Optional document ID for traceability. known_tickers: Optional ticker hints for the model. Returns: An ``ExtractionResponse`` with the parsed result on success. """ prompts = build_extraction_prompt( document_text=document_text, document_type=document_type, document_id=document_id, known_tickers=known_tickers, ) json_schema = get_json_schema() prompt_meta = get_prompt_metadata() response = ExtractionResponse( prompt_metadata=prompt_meta, model=self._config.model, ) total_start = time.monotonic() for attempt_num in range(self._max_retries + 1): attempt = await self._call_ollama(prompts, json_schema, document_text) response.attempts.append(attempt) if attempt.error is None and attempt.validation and attempt.validation.valid: response.success = True response.result = attempt.validation.parsed break # Check if the error is non-retryable — stop immediately if not _is_retryable(attempt.error): attempt.retryable = False logger.warning( "Non-retryable error for doc %s: %s — stopping retries", document_id or "unknown", attempt.error, ) break if attempt_num < self._max_retries: delay = _compute_backoff( attempt_num, self._base_delay, self._max_delay, self._backoff_multiplier, ) logger.warning( "Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs", attempt_num + 1, self._max_retries + 1, document_id or "unknown", attempt.error or "validation failed", delay, ) await asyncio.sleep(delay) response.total_duration_ms = int((time.monotonic() - total_start) * 1000) return response async def _call_ollama( self, prompts: dict[str, str], json_schema: dict[str, object], document_text: str = "", ) -> ExtractionAttempt: """Call Ollama with think=false for speed, then repair any malformed JSON. Uses think=false to avoid the 2-4 minute thinking overhead. Does NOT use the format parameter (Ollama bug #14645 silently ignores format when think=false on qwen3.5 models). Instead, relies on the prompt to produce JSON and repairs common syntax issues with json-repair. """ attempt = ExtractionAttempt(model=self._config.model) start = time.monotonic() payload = { "model": self._config.model, "messages": [ {"role": "system", "content": prompts["system"]}, {"role": "user", "content": prompts["user"]}, ], "stream": False, "think": False, "options": { "num_predict": 16384, }, } url = f"{self._config.base_url}/api/chat" logger.info( "Ollama POST %s model=%s input_chars=%d", url, self._config.model, len(prompts.get("user", "")), ) try: resp = await self._http.post(url, json=payload) resp.raise_for_status() except httpx.TimeoutException: attempt.error = "timeout" attempt.duration_ms = int((time.monotonic() - start) * 1000) return attempt except httpx.HTTPStatusError as exc: attempt.error = f"http_{exc.response.status_code}" attempt.retryable = _is_retryable(attempt.error) attempt.duration_ms = int((time.monotonic() - start) * 1000) return attempt except httpx.HTTPError as exc: attempt.error = f"connection_error: {exc}" attempt.duration_ms = int((time.monotonic() - start) * 1000) return attempt attempt.duration_ms = int((time.monotonic() - start) * 1000) try: data = resp.json() except Exception: attempt.error = "invalid_response_json" attempt.raw_output = resp.text[:2000] return attempt content = data.get("message", {}).get("content", "") attempt.raw_output = content if not content: attempt.error = "empty_model_response" return attempt # Strip markdown fences if present (model sometimes wraps in ```json ... ```) content = _strip_markdown_fences(content) # Try json.loads first; if it fails, attempt repair content = _repair_json(content) # Validate against extraction schema attempt.validation = validate_extraction(content, document_text=document_text) if not attempt.validation.valid: attempt.error = "; ".join(attempt.validation.errors) return attempt