From f7ae34ef3b251b60a4ee36d2e5f342c5eafe78d1 Mon Sep 17 00:00:00 2001
From: Celes Renata
Date: Thu, 23 Apr 2026 19:32:33 +0000
Subject: [PATCH] fix: add extract() method to VLLMClient for extraction
 pipeline compatibility

Add VLLMClient.extract(), mirroring OllamaClient.extract(). The method
builds the extraction prompt, calls call_llm() up to max_retries + 1
times with a backoff delay between attempts, stops early on a valid
result or on a non-retryable error, and records every attempt plus the
total duration on the returned ExtractionResponse.
---
 services/extractor/vllm_client.py | 75 +++++++++++++++++++++++++++++++
 1 file changed, 75 insertions(+)

diff --git a/services/extractor/vllm_client.py b/services/extractor/vllm_client.py
index 7cba5ae..f68f182 100644
--- a/services/extractor/vllm_client.py
+++ b/services/extractor/vllm_client.py
@@ -9,6 +9,7 @@ Requirements: 2.1–2.10, 7.1–7.4
 """
 from __future__ import annotations
 
+import asyncio
 import logging
 import time
 
@@ -16,10 +17,17 @@ import httpx
 
 from services.extractor.client import (
     ExtractionAttempt,
+    ExtractionResponse,
+    _compute_backoff,
     _is_retryable,
     _repair_json,
     _strip_markdown_fences,
 )
+from services.extractor.prompts import (
+    build_extraction_prompt,
+    get_json_schema,
+    get_prompt_metadata,
+)
 from services.extractor.schemas import validate_extraction
 from services.shared.config import VLLMConfig
 
@@ -146,6 +154,73 @@ class VLLMClient:
 
         return attempt
 
+    async def extract(
+        self,
+        document_text: str,
+        document_type: str = "article",
+        document_id: str = "",
+        known_tickers: list[str] | None = None,
+    ) -> ExtractionResponse:
+        """Send a document to vLLM for structured intelligence extraction.
+
+        Mirrors ``OllamaClient.extract()`` — retries up to ``max_retries``
+        times with exponential backoff, preserving each attempt for audit.
+        """
+        prompts = build_extraction_prompt(
+            document_text=document_text,
+            document_type=document_type,
+            document_id=document_id,
+            known_tickers=known_tickers,
+        )
+        json_schema = get_json_schema()
+        prompt_meta = get_prompt_metadata()
+
+        response = ExtractionResponse(
+            prompt_metadata=prompt_meta,
+            model=self._config.model,
+        )
+
+        max_retries = self._config.max_retries
+        total_start = time.monotonic()
+
+        for attempt_num in range(max_retries + 1):
+            attempt = await self.call_llm(prompts, json_schema, document_text)
+            response.attempts.append(attempt)
+
+            if attempt.error is None and attempt.validation and attempt.validation.valid:
+                response.success = True
+                response.result = attempt.validation.parsed
+                break
+
+            if not _is_retryable(attempt.error):
+                attempt.retryable = False
+                logger.warning(
+                    "Non-retryable error for doc %s: %s — stopping retries",
+                    document_id or "unknown",
+                    attempt.error,
+                )
+                break
+
+            if attempt_num < max_retries:
+                delay = _compute_backoff(
+                    attempt_num,
+                    self._config.retry_base_delay,
+                    self._config.retry_max_delay,
+                    self._config.retry_backoff_multiplier,
+                )
+                logger.warning(
+                    "Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
+                    attempt_num + 1,
+                    max_retries + 1,
+                    document_id or "unknown",
+                    attempt.error or "validation failed",
+                    delay,
+                )
+                await asyncio.sleep(delay)
+
+        response.total_duration_ms = int((time.monotonic() - total_start) * 1000)
+        return response
+
     async def close(self) -> None:
         """Release the underlying ``httpx.AsyncClient`` if we own it."""
         if self._owns_client: