phase 14-15: docker build validation and helm deployment
This commit is contained in:
@@ -0,0 +1,268 @@
|
||||
"""Ollama client wrapper using structured output format.
|
||||
|
||||
Sends documents to a local Ollama instance via the /api/chat endpoint
|
||||
with the ``format`` parameter set to the extraction JSON schema, ensuring
|
||||
the model returns schema-compliant JSON.
|
||||
|
||||
Includes retry logic for invalid or incomplete model responses with
|
||||
exponential backoff, error classification, and full audit preservation.
|
||||
|
||||
Requirements: 5.1, 5.2, 5.4
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
|
||||
from services.extractor.prompts import (
|
||||
build_extraction_prompt,
|
||||
get_json_schema,
|
||||
get_prompt_metadata,
|
||||
)
|
||||
from services.extractor.schemas import ExtractionResult, ValidationReport, validate_extraction
|
||||
from services.shared.config import OllamaConfig
|
||||
|
||||
logger = logging.getLogger("ollama_client")
|
||||
|
||||
# Errors that should NOT be retried — the request itself is bad.
|
||||
_NON_RETRYABLE_ERRORS = frozenset({
|
||||
"http_400",
|
||||
"http_401",
|
||||
"http_403",
|
||||
"http_404",
|
||||
"http_422",
|
||||
})
|
||||
|
||||
|
||||
def _is_retryable(error: str | None) -> bool:
|
||||
"""Determine whether an extraction error warrants a retry."""
|
||||
if error is None:
|
||||
return False
|
||||
return error not in _NON_RETRYABLE_ERRORS
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionAttempt:
|
||||
"""Record of a single extraction attempt for audit."""
|
||||
|
||||
raw_output: str = ""
|
||||
validation: ValidationReport | None = None
|
||||
error: str | None = None
|
||||
duration_ms: int = 0
|
||||
model: str = ""
|
||||
retryable: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionResponse:
|
||||
"""Full response from an extraction call, including all attempts."""
|
||||
|
||||
success: bool = False
|
||||
result: ExtractionResult | None = None
|
||||
attempts: list[ExtractionAttempt] = field(default_factory=list)
|
||||
prompt_metadata: dict[str, str] = field(default_factory=dict)
|
||||
model: str = ""
|
||||
total_duration_ms: int = 0
|
||||
|
||||
|
||||
def _compute_backoff(
|
||||
attempt_num: int,
|
||||
base_delay: float,
|
||||
max_delay: float,
|
||||
multiplier: float,
|
||||
) -> float:
|
||||
"""Compute exponential backoff delay for a given attempt number."""
|
||||
delay = base_delay * (multiplier ** attempt_num)
|
||||
return min(delay, max_delay)
|
||||
|
||||
|
||||
class OllamaClient:
|
||||
"""Async client for Ollama structured extraction.
|
||||
|
||||
Usage::
|
||||
|
||||
config = OllamaConfig(base_url="http://localhost:11434", model="llama3.1:8b")
|
||||
client = OllamaClient(config)
|
||||
response = await client.extract(
|
||||
document_text="Apple reported record earnings...",
|
||||
document_type="article",
|
||||
document_id="abc-123",
|
||||
)
|
||||
if response.success:
|
||||
print(response.result)
|
||||
"""
|
||||
|
||||
_config: OllamaConfig
|
||||
_max_retries: int
|
||||
_base_delay: float
|
||||
_max_delay: float
|
||||
_backoff_multiplier: float
|
||||
_owns_client: bool
|
||||
_http: httpx.AsyncClient
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OllamaConfig,
|
||||
max_retries: int | None = None,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
) -> None:
|
||||
self._config = config
|
||||
self._max_retries = max_retries if max_retries is not None else config.max_retries
|
||||
self._base_delay = config.retry_base_delay
|
||||
self._max_delay = config.retry_max_delay
|
||||
self._backoff_multiplier = config.retry_backoff_multiplier
|
||||
self._owns_client = http_client is None
|
||||
self._http = http_client or httpx.AsyncClient(timeout=config.timeout)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying HTTP client if we own it."""
|
||||
if self._owns_client:
|
||||
await self._http.aclose()
|
||||
|
||||
async def extract(
|
||||
self,
|
||||
document_text: str,
|
||||
document_type: str = "article",
|
||||
document_id: str = "",
|
||||
known_tickers: list[str] | None = None,
|
||||
) -> ExtractionResponse:
|
||||
"""Send a document to Ollama for structured intelligence extraction.
|
||||
|
||||
Retries up to ``max_retries`` times when the model returns invalid
|
||||
or incomplete JSON. Uses exponential backoff between retries.
|
||||
Non-retryable errors (e.g. HTTP 400) stop retries immediately.
|
||||
Each attempt and its validation result are preserved for audit.
|
||||
|
||||
Args:
|
||||
document_text: Normalized text content of the document.
|
||||
document_type: One of article, filing, transcript, press_release.
|
||||
document_id: Optional document ID for traceability.
|
||||
known_tickers: Optional ticker hints for the model.
|
||||
|
||||
Returns:
|
||||
An ``ExtractionResponse`` with the parsed result on success.
|
||||
"""
|
||||
prompts = build_extraction_prompt(
|
||||
document_text=document_text,
|
||||
document_type=document_type,
|
||||
document_id=document_id,
|
||||
known_tickers=known_tickers,
|
||||
)
|
||||
json_schema = get_json_schema()
|
||||
prompt_meta = get_prompt_metadata()
|
||||
|
||||
response = ExtractionResponse(
|
||||
prompt_metadata=prompt_meta,
|
||||
model=self._config.model,
|
||||
)
|
||||
|
||||
total_start = time.monotonic()
|
||||
|
||||
for attempt_num in range(self._max_retries + 1):
|
||||
attempt = await self._call_ollama(prompts, json_schema, document_text)
|
||||
response.attempts.append(attempt)
|
||||
|
||||
if attempt.error is None and attempt.validation and attempt.validation.valid:
|
||||
response.success = True
|
||||
response.result = attempt.validation.parsed
|
||||
break
|
||||
|
||||
# Check if the error is non-retryable — stop immediately
|
||||
if not _is_retryable(attempt.error):
|
||||
attempt.retryable = False
|
||||
logger.warning(
|
||||
"Non-retryable error for doc %s: %s — stopping retries",
|
||||
document_id or "unknown",
|
||||
attempt.error,
|
||||
)
|
||||
break
|
||||
|
||||
if attempt_num < self._max_retries:
|
||||
delay = _compute_backoff(
|
||||
attempt_num,
|
||||
self._base_delay,
|
||||
self._max_delay,
|
||||
self._backoff_multiplier,
|
||||
)
|
||||
logger.warning(
|
||||
"Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
|
||||
attempt_num + 1,
|
||||
self._max_retries + 1,
|
||||
document_id or "unknown",
|
||||
attempt.error or "validation failed",
|
||||
delay,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
response.total_duration_ms = int((time.monotonic() - total_start) * 1000)
|
||||
return response
|
||||
|
||||
async def _call_ollama(
|
||||
self,
|
||||
prompts: dict[str, str],
|
||||
json_schema: dict[str, object],
|
||||
document_text: str = "",
|
||||
) -> ExtractionAttempt:
|
||||
"""Make a single call to the Ollama /api/chat endpoint."""
|
||||
attempt = ExtractionAttempt(model=self._config.model)
|
||||
start = time.monotonic()
|
||||
|
||||
payload = {
|
||||
"model": self._config.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": prompts["system"]},
|
||||
{"role": "user", "content": prompts["user"]},
|
||||
],
|
||||
"format": json_schema,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = await self._http.post(
|
||||
f"{self._config.base_url}/api/chat",
|
||||
json=payload,
|
||||
)
|
||||
_ = resp.raise_for_status()
|
||||
except httpx.TimeoutException:
|
||||
attempt.error = "timeout"
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
except httpx.HTTPStatusError as exc:
|
||||
attempt.error = f"http_{exc.response.status_code}"
|
||||
attempt.retryable = _is_retryable(attempt.error)
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
except httpx.HTTPError as exc:
|
||||
attempt.error = f"connection_error: {exc}"
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
return attempt
|
||||
|
||||
attempt.duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
# Parse the Ollama response envelope
|
||||
try:
|
||||
body: dict[str, object] = resp.json()
|
||||
except json.JSONDecodeError:
|
||||
attempt.error = "invalid_response_json"
|
||||
attempt.raw_output = resp.text
|
||||
return attempt
|
||||
|
||||
msg = body.get("message")
|
||||
content: str = msg.get("content", "") if isinstance(msg, dict) else ""
|
||||
attempt.raw_output = content
|
||||
|
||||
if not content:
|
||||
attempt.error = "empty_model_response"
|
||||
return attempt
|
||||
|
||||
# Validate against extraction schema
|
||||
attempt.validation = validate_extraction(content, document_text=document_text)
|
||||
if not attempt.validation.valid:
|
||||
attempt.error = "; ".join(attempt.validation.errors)
|
||||
|
||||
return attempt
|
||||
Reference in New Issue
Block a user