fix: add extract() method to VLLMClient for extraction pipeline compatibility

This commit is contained in:
Celes Renata
2026-04-23 19:32:33 +00:00
parent 4bee7a7874
commit f7ae34ef3b
+75
View File
@@ -9,6 +9,7 @@ Requirements: 2.12.10, 7.17.4
"""
from __future__ import annotations
import asyncio
import logging
import time
@@ -16,10 +17,17 @@ import httpx
from services.extractor.client import (
ExtractionAttempt,
ExtractionResponse,
_compute_backoff,
_is_retryable,
_repair_json,
_strip_markdown_fences,
)
from services.extractor.prompts import (
build_extraction_prompt,
get_json_schema,
get_prompt_metadata,
)
from services.extractor.schemas import validate_extraction
from services.shared.config import VLLMConfig
@@ -146,6 +154,73 @@ class VLLMClient:
return attempt
async def extract(
    self,
    document_text: str,
    document_type: str = "article",
    document_id: str = "",
    known_tickers: list[str] | None = None,
) -> ExtractionResponse:
    """Run structured intelligence extraction on one document via vLLM.

    Mirrors ``OllamaClient.extract()``: the LLM is called up to
    ``max_retries + 1`` times, sleeping for a backoff delay between
    failed attempts, and every attempt is kept on the response for
    audit purposes.

    Args:
        document_text: Raw text of the document to analyse.
        document_type: Document category forwarded to the prompt builder.
        document_id: Identifier forwarded to the prompt builder and used
            in log messages (``"unknown"`` is logged when it is empty).
        known_tickers: Optional tickers forwarded to the prompt builder.

    Returns:
        An ``ExtractionResponse`` holding all attempts and the total
        wall-clock duration in milliseconds.  ``success`` and ``result``
        are set only when an attempt has no error and passes validation.
    """
    prompt_bundle = build_extraction_prompt(
        document_text=document_text,
        document_type=document_type,
        document_id=document_id,
        known_tickers=known_tickers,
    )
    schema = get_json_schema()
    meta = get_prompt_metadata()

    outcome = ExtractionResponse(
        prompt_metadata=meta,
        model=self._config.model,
    )

    retry_limit = self._config.max_retries
    started_at = time.monotonic()

    # Derived values that stay constant across iterations.
    total_attempts = retry_limit + 1
    doc_label = document_id or "unknown"

    for idx in range(total_attempts):
        current = await self.call_llm(prompt_bundle, schema, document_text)
        outcome.attempts.append(current)

        # Success requires both a clean call and a valid parsed payload.
        succeeded = (
            current.error is None
            and current.validation
            and current.validation.valid
        )
        if succeeded:
            outcome.success = True
            outcome.result = current.validation.parsed
            break

        if not _is_retryable(current.error):
            current.retryable = False
            logger.warning(
                "Non-retryable error for doc %s: %s — stopping retries",
                doc_label,
                current.error,
            )
            break

        if idx >= retry_limit:
            # Final attempt already used: nothing left to wait for.
            break

        cfg = self._config
        pause = _compute_backoff(
            idx,
            cfg.retry_base_delay,
            cfg.retry_max_delay,
            cfg.retry_backoff_multiplier,
        )
        logger.warning(
            "Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
            idx + 1,
            total_attempts,
            doc_label,
            current.error or "validation failed",
            pause,
        )
        await asyncio.sleep(pause)

    elapsed_seconds = time.monotonic() - started_at
    outcome.total_duration_ms = int(elapsed_seconds * 1000)
    return outcome
async def close(self) -> None:
"""Release the underlying ``httpx.AsyncClient`` if we own it."""
if self._owns_client: