fix: add extract() method to VLLMClient for extraction pipeline compatibility
This commit is contained in:
@@ -9,6 +9,7 @@ Requirements: 2.1–2.10, 7.1–7.4
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -16,10 +17,17 @@ import httpx
|
|||||||
|
|
||||||
from services.extractor.client import (
|
from services.extractor.client import (
|
||||||
ExtractionAttempt,
|
ExtractionAttempt,
|
||||||
|
ExtractionResponse,
|
||||||
|
_compute_backoff,
|
||||||
_is_retryable,
|
_is_retryable,
|
||||||
_repair_json,
|
_repair_json,
|
||||||
_strip_markdown_fences,
|
_strip_markdown_fences,
|
||||||
)
|
)
|
||||||
|
from services.extractor.prompts import (
|
||||||
|
build_extraction_prompt,
|
||||||
|
get_json_schema,
|
||||||
|
get_prompt_metadata,
|
||||||
|
)
|
||||||
from services.extractor.schemas import validate_extraction
|
from services.extractor.schemas import validate_extraction
|
||||||
from services.shared.config import VLLMConfig
|
from services.shared.config import VLLMConfig
|
||||||
|
|
||||||
@@ -146,6 +154,73 @@ class VLLMClient:
|
|||||||
|
|
||||||
return attempt
|
return attempt
|
||||||
|
|
||||||
|
async def extract(
    self,
    document_text: str,
    document_type: str = "article",
    document_id: str = "",
    known_tickers: list[str] | None = None,
) -> ExtractionResponse:
    """Run structured extraction for one document against vLLM, with retries.

    Mirrors ``OllamaClient.extract()``. The prompts are built once, then
    ``call_llm`` is invoked up to ``max_retries + 1`` times. Every attempt
    is appended to ``response.attempts`` so the full history is kept.

    Args:
        document_text: Raw text of the document to extract from.
        document_type: Document category forwarded to the prompt builder.
        document_id: Identifier forwarded to the prompt builder and used
            in log messages (logged as ``"unknown"`` when empty).
        known_tickers: Optional tickers forwarded to the prompt builder.

    Returns:
        An ``ExtractionResponse`` holding all attempts, the total elapsed
        time in milliseconds and, when an attempt validated, ``success``
        set to ``True`` with the parsed ``result``.
    """
    prompts = build_extraction_prompt(
        document_text=document_text,
        document_type=document_type,
        document_id=document_id,
        known_tickers=known_tickers,
    )
    schema = get_json_schema()
    metadata = get_prompt_metadata()

    cfg = self._config
    outcome = ExtractionResponse(prompt_metadata=metadata, model=cfg.model)

    retry_budget = cfg.max_retries
    doc_label = document_id or "unknown"
    started_at = time.monotonic()

    for attempt_index in range(retry_budget + 1):
        current = await self.call_llm(prompts, schema, document_text)
        outcome.attempts.append(current)

        # Success path: no transport/parse error and the payload validated.
        if current.error is None:
            verdict = current.validation
            if verdict and verdict.valid:
                outcome.success = True
                outcome.result = verdict.parsed
                break

        # Give up immediately on errors the shared helper deems permanent.
        if not _is_retryable(current.error):
            current.retryable = False
            logger.warning(
                "Non-retryable error for doc %s: %s — stopping retries",
                doc_label,
                current.error,
            )
            break

        # The final attempt has nothing to wait for; let the loop finish.
        if attempt_index >= retry_budget:
            continue

        pause = _compute_backoff(
            attempt_index,
            cfg.retry_base_delay,
            cfg.retry_max_delay,
            cfg.retry_backoff_multiplier,
        )
        logger.warning(
            "Extraction attempt %d/%d failed for doc %s: %s — retrying in %.1fs",
            attempt_index + 1,
            retry_budget + 1,
            doc_label,
            current.error or "validation failed",
            pause,
        )
        await asyncio.sleep(pause)

    elapsed_seconds = time.monotonic() - started_at
    outcome.total_duration_ms = int(elapsed_seconds * 1000)
    return outcome
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Release the underlying ``httpx.AsyncClient`` if we own it."""
|
"""Release the underlying ``httpx.AsyncClient`` if we own it."""
|
||||||
if self._owns_client:
|
if self._owns_client:
|
||||||
|
|||||||
Reference in New Issue
Block a user