Files
stonks-oracle/services/extractor/client.py
T

336 lines
11 KiB
Python

"""Ollama client wrapper using structured output format.
Sends documents to a local Ollama instance via the /api/chat endpoint
with the ``format`` parameter set to the extraction JSON schema, ensuring
the model returns schema-compliant JSON.
Includes retry logic for invalid or incomplete model responses with
exponential backoff, error classification, and full audit preservation.
Requirements: 5.1, 5.2, 5.4
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
import httpx
from services.extractor.prompts import (
build_extraction_prompt,
get_json_schema,
get_prompt_metadata,
)
from services.extractor.schemas import ExtractionResult, ValidationReport, validate_extraction
from services.shared.config import OllamaConfig
logger = logging.getLogger("ollama_client")

# HTTP failures meaning the request itself is bad: sending the identical
# request again cannot succeed, so the retry loop stops as soon as one of
# these error codes is recorded.
_NON_RETRYABLE_ERRORS = frozenset(
    f"http_{status}" for status in (400, 401, 403, 404, 422)
)
def _is_retryable(error: str | None) -> bool:
    """Report whether a failed extraction attempt should be tried again.

    ``None`` means the attempt recorded no error, so there is nothing to
    retry. Any other error string is treated as transient unless it is one
    of the codes in ``_NON_RETRYABLE_ERRORS``.
    """
    return error is not None and error not in _NON_RETRYABLE_ERRORS
@dataclass
class ExtractionAttempt:
    """Audit entry for a single model call: its output, timing, and failure cause."""

    # Text streamed back by the model (may be partial if the stream was aborted).
    raw_output: str = field(default="")
    # Schema validation report; stays None when the call never reached validation.
    validation: ValidationReport | None = field(default=None)
    # Description of what went wrong (HTTP status, timeout, guardrail abort,
    # validation errors); None when the attempt succeeded.
    error: str | None = field(default=None)
    # Wall-clock duration of the call, in milliseconds.
    duration_ms: int = field(default=0)
    # Name of the model the call was sent to.
    model: str = field(default="")
    # Cleared when the failure is classified as not worth retrying.
    retryable: bool = field(default=True)
@dataclass
class ExtractionResponse:
    """Overall outcome of an extraction request plus its per-attempt audit trail."""

    # True once an attempt produced output that passed schema validation.
    success: bool = field(default=False)
    # Parsed extraction from the successful attempt; None otherwise.
    result: ExtractionResult | None = field(default=None)
    # Every attempt made, in order (each instance gets its own list).
    attempts: list[ExtractionAttempt] = field(default_factory=list)
    # Metadata describing the prompt that was sent.
    prompt_metadata: dict[str, str] = field(default_factory=dict)
    # Name of the model the request targeted.
    model: str = field(default="")
    # Wall-clock time across all attempts, in milliseconds.
    total_duration_ms: int = field(default=0)
def _compute_backoff(
    attempt_num: int,
    base_delay: float,
    max_delay: float,
    multiplier: float,
) -> float:
    """Return the backoff delay, in seconds, to sleep before the next retry.

    The delay grows as ``base_delay * multiplier ** attempt_num`` and is
    capped at ``max_delay``, so ``attempt_num=0`` yields ``base_delay``.

    Args:
        attempt_num: Zero-based index of the attempt that just failed.
        base_delay: Delay used for the first retry.
        max_delay: Upper bound on the returned delay.
        multiplier: Growth factor applied per attempt.

    Returns:
        The capped delay.
    """
    try:
        delay = base_delay * (multiplier ** attempt_num)
    except OverflowError:
        # Float exponentiation (or converting a huge int power to float)
        # overflows for very large attempt numbers. The uncapped delay would
        # be astronomically large, so the cap applies (assumes a positive
        # base_delay, as a backoff base normally is).
        return max_delay
    return min(delay, max_delay)
class OllamaClient:
    """Async client for Ollama structured extraction.

    Each call streams ``/api/chat`` with the extraction JSON schema as the
    ``format`` constraint, enforces early-abort guardrails while streaming,
    validates the assembled output, and retries retryable failures with
    exponential backoff.

    Usage::

        config = OllamaConfig(base_url="http://localhost:11434", model="llama3.1:8b")
        async with OllamaClient(config) as client:
            response = await client.extract(
                document_text="Apple reported record earnings...",
                document_type="article",
                document_id="abc-123",
            )
        if response.success:
            print(response.result)

    Calling :meth:`close` explicitly instead of using ``async with`` is
    equivalent.
    """

    _config: OllamaConfig
    _max_retries: int
    _base_delay: float
    _max_delay: float
    _backoff_multiplier: float
    _owns_client: bool
    _http: httpx.AsyncClient

    def __init__(
        self,
        config: OllamaConfig,
        max_retries: int | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a client.

        Args:
            config: Connection, model, retry, and guardrail settings.
            max_retries: Overrides ``config.max_retries`` when given. The
                total number of attempts per document is ``max_retries + 1``.
            http_client: Optional shared ``httpx.AsyncClient``. When omitted,
                a private client is created and closed by :meth:`close`.
        """
        self._config = config
        self._max_retries = max_retries if max_retries is not None else config.max_retries
        self._base_delay = config.retry_base_delay
        self._max_delay = config.retry_max_delay
        self._backoff_multiplier = config.retry_backoff_multiplier
        self._owns_client = http_client is None
        self._http = http_client or httpx.AsyncClient(
            timeout=httpx.Timeout(config.timeout, read=config.timeout),
        )

    async def __aenter__(self) -> OllamaClient:
        """Enter an ``async with`` block; returns this client."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Leave an ``async with`` block by calling :meth:`close`."""
        await self.close()

    async def close(self) -> None:
        """Close the underlying HTTP client if this instance created it."""
        if self._owns_client:
            await self._http.aclose()

    async def extract(
        self,
        document_text: str,
        document_type: str = "article",
        document_id: str = "",
        known_tickers: list[str] | None = None,
    ) -> ExtractionResponse:
        """Send a document to Ollama for structured intelligence extraction.

        Retries up to ``max_retries`` times when the model returns invalid
        or incomplete JSON. Uses exponential backoff between retries.
        Non-retryable errors (e.g. HTTP 400) stop retries immediately.
        Each attempt and its validation result are preserved for audit.

        Args:
            document_text: Normalized text content of the document.
            document_type: One of article, filing, transcript, press_release.
            document_id: Optional document ID for traceability.
            known_tickers: Optional ticker hints for the model.

        Returns:
            An ``ExtractionResponse`` with the parsed result on success.
        """
        prompts = build_extraction_prompt(
            document_text=document_text,
            document_type=document_type,
            document_id=document_id,
            known_tickers=known_tickers,
        )
        json_schema = get_json_schema()
        prompt_meta = get_prompt_metadata()
        response = ExtractionResponse(
            prompt_metadata=prompt_meta,
            model=self._config.model,
        )
        doc_label = document_id or "unknown"
        total_start = time.monotonic()

        for attempt_num in range(self._max_retries + 1):
            attempt = await self._call_ollama(prompts, json_schema, document_text)
            response.attempts.append(attempt)

            if attempt.error is None and attempt.validation and attempt.validation.valid:
                response.success = True
                response.result = attempt.validation.parsed
                break

            # Check if the error is non-retryable — stop immediately
            if not _is_retryable(attempt.error):
                attempt.retryable = False
                logger.warning(
                    "Non-retryable error for doc %s: %s — stopping retries",
                    doc_label,
                    attempt.error,
                )
                break

            # No sleep after the final attempt: nothing follows it.
            if attempt_num < self._max_retries:
                delay = _compute_backoff(
                    attempt_num,
                    self._base_delay,
                    self._max_delay,
                    self._backoff_multiplier,
                )
                logger.warning(
                    "Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
                    attempt_num + 1,
                    self._max_retries + 1,
                    doc_label,
                    attempt.error or "validation failed",
                    delay,
                )
                await asyncio.sleep(delay)
        else:
            # for/else: reached only when the loop was never broken, i.e. no
            # attempt succeeded and none was non-retryable — retries exhausted.
            last_error = response.attempts[-1].error if response.attempts else None
            logger.warning(
                "Extraction failed for doc %s after %d attempt(s): %s",
                doc_label,
                len(response.attempts),
                last_error or "validation failed",
            )

        response.total_duration_ms = self._ms_since(total_start)
        return response

    @staticmethod
    def _ms_since(start: float) -> int:
        """Return whole milliseconds elapsed since a ``time.monotonic()`` reading."""
        return int((time.monotonic() - start) * 1000)

    async def _call_ollama(
        self,
        prompts: dict[str, str],
        json_schema: dict[str, object],
        document_text: str = "",
    ) -> ExtractionAttempt:
        """Make one streaming call to Ollama and validate the assembled output.

        The stream is aborted early (see ``_consume_stream``) if:

        - more than ``max_tokens`` content chunks have been generated,
        - no line arrives from the server within ``stall_timeout`` seconds,
        - a repetition loop is detected in the last ``loop_window`` chunks,
        - Ollama sends an in-stream ``error`` frame, or the connection fails.

        httpx transport errors and HTTP error statuses are recorded on the
        returned attempt's ``error`` field rather than raised.
        """
        attempt = ExtractionAttempt(model=self._config.model)
        start = time.monotonic()

        payload = {
            "model": self._config.model,
            "messages": [
                {"role": "system", "content": prompts["system"]},
                {"role": "user", "content": prompts["user"]},
            ],
            "format": json_schema,
            "stream": True,
            # NOTE: Do NOT set "think": False here. Ollama has a known bug
            # (issues #14645, #15260) where think=false silently disables
            # the format constraint for qwen3.5 and gemma4 models, causing
            # the model to output plain text instead of valid JSON.
            # Omitting "think" lets the model use thinking mode (slightly
            # slower but structured output actually works).
        }
        url = f"{self._config.base_url}/api/chat"
        logger.info(
            "Ollama POST %s model=%s input_chars=%d (streaming)",
            url, self._config.model, len(prompts.get("user", "")),
        )

        try:
            req = self._http.build_request("POST", url, json=payload)
            resp = await self._http.send(req, stream=True)
            resp.raise_for_status()
        except httpx.TimeoutException:
            attempt.error = "timeout"
            attempt.duration_ms = self._ms_since(start)
            return attempt
        except httpx.HTTPStatusError as exc:
            # A streamed response keeps its connection checked out until it
            # is closed; the body is never read here, so release it now.
            await exc.response.aclose()
            attempt.error = f"http_{exc.response.status_code}"
            attempt.retryable = _is_retryable(attempt.error)
            attempt.duration_ms = self._ms_since(start)
            return attempt
        except httpx.HTTPError as exc:
            attempt.error = f"connection_error: {exc}"
            attempt.duration_ms = self._ms_since(start)
            return attempt

        try:
            chunks, abort_reason = await self._consume_stream(resp)
        finally:
            await resp.aclose()

        attempt.duration_ms = self._ms_since(start)
        content = "".join(chunks)
        attempt.raw_output = content

        if abort_reason:
            logger.warning(
                "Stream aborted after %d tokens: %s", len(chunks), abort_reason,
            )
            attempt.error = abort_reason
            return attempt

        if not content:
            attempt.error = "empty_model_response"
            return attempt

        # Validate against extraction schema
        attempt.validation = validate_extraction(content, document_text=document_text)
        if not attempt.validation.valid:
            attempt.error = "; ".join(attempt.validation.errors)
        return attempt

    async def _consume_stream(
        self,
        resp: httpx.Response,
    ) -> tuple[list[str], str | None]:
        """Accumulate streamed ``message.content`` chunks under the guardrails.

        Args:
            resp: An open streaming response whose body is newline-delimited
                JSON frames. The caller remains responsible for closing it.

        Returns:
            ``(chunks, abort_reason)``: the content chunks received so far,
            and ``None`` when the stream ended normally (``done`` frame or end
            of body), otherwise a short reason string.
        """
        chunks: list[str] = []
        stall_timeout = self._config.stall_timeout
        # wait_for(..., timeout=None) waits indefinitely, so a missing or
        # non-positive stall_timeout disables stall detection.
        line_timeout = stall_timeout if stall_timeout is not None and stall_timeout > 0 else None
        lines = resp.aiter_lines()
        try:
            while True:
                # aiter_lines() blocks until data arrives, so a stall can only
                # be caught by bounding the wait itself; measuring the gap
                # after a chunk has already arrived can never fire.
                try:
                    line = await asyncio.wait_for(lines.__anext__(), timeout=line_timeout)
                except StopAsyncIteration:
                    break
                except asyncio.TimeoutError:
                    return chunks, "stall_timeout"

                if not line:
                    continue
                try:
                    frame = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(frame, dict):
                    continue
                # Ollama reports failures mid-stream as an {"error": ...} frame.
                if frame.get("error"):
                    return chunks, f"ollama_error: {frame['error']}"
                if frame.get("done"):
                    break

                msg = frame.get("message", {})
                token = msg.get("content", "") if isinstance(msg, dict) else ""
                # Frames with empty content (presumably thinking output when
                # the model runs in thinking mode — confirm) still reset the
                # stall timer above but do not count toward the guards below.
                if not token:
                    continue
                chunks.append(token)

                abort_reason = self._check_guards(chunks)
                if abort_reason:
                    return chunks, abort_reason
        except httpx.TimeoutException:
            return chunks, "read_timeout"
        except httpx.HTTPError as exc:
            # Connection dropped or protocol error mid-stream: record it so
            # the retry loop can handle it instead of the exception escaping.
            return chunks, f"stream_error: {exc}"
        return chunks, None

    def _check_guards(self, chunks: list[str]) -> str | None:
        """Return an abort reason if a per-chunk guardrail trips, else ``None``.

        Each streamed content frame counts as one "token" for ``max_tokens``.
        """
        token_count = len(chunks)

        # Guard: max tokens
        if token_count > self._config.max_tokens:
            return f"max_tokens_exceeded ({token_count})"

        # Guard: repetition loop detection over the most recent chunks.
        # chunks[-0:] would be the whole list, so a non-positive window
        # disables the check instead of rescanning all output every token.
        loop_window = self._config.loop_window
        if loop_window > 0 and token_count >= loop_window:
            window = chunks[-loop_window:]
            unique_ratio = len(set(window)) / len(window)
            if unique_ratio < self._config.loop_threshold:
                return f"repetition_loop (unique_ratio={unique_ratio:.2f})"
        return None