Files
stonks-oracle/services/extractor/main.py
T
Celes Renata 117b693b19 feat: add remote vLLM support with provider abstraction layer
- LLMClient Protocol for provider-agnostic inference
- VLLMClient for OpenAI-compatible /v1/chat/completions API
- LLM client factory with provider routing (ollama/vllm)
- VLLMConfig with VLLM_* environment variable loading
- Updated extractor worker with health check and provider switching
- Updated event classifier to use LLMClient protocol
- Helm values for vLLM configuration
- 18 unit tests + 6 property-based tests
- Full backward compatibility preserved
2026-04-23 08:17:23 +00:00

647 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Extractor worker entrypoint - polls Redis for extraction jobs."""
from __future__ import annotations
import asyncio
import json
import logging
import asyncpg
import redis.asyncio as aioredis
from minio import Minio
from services.aggregation.interpolation import (
build_default_profile,
compute_macro_impact_with_sector,
filter_low_confidence_events,
persist_macro_impact_records,
)
from services.extractor.event_classifier import classify_global_event
from services.extractor.llm_factory import build_config_from_resolved, build_llm_client
from services.extractor.vllm_client import check_vllm_health
from services.extractor.worker import persist_extraction
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
from services.shared.config import OllamaConfig, load_config
from services.shared.llm_protocol import LLMClient
from services.shared.logging import inject_trace_context, setup_logging
from services.shared.redis_keys import (
QUEUE_AGGREGATION,
QUEUE_EXTRACTION,
QUEUE_MACRO_CLASSIFICATION,
queue_key,
)
# Module logger for the extractor entrypoint; named explicitly rather than via __name__.
logger = logging.getLogger(name="extractor_main")
def _get_provider(resolved: ResolvedAgentConfig | None) -> str:
"""Return the normalised provider string for a resolved config."""
if resolved is None:
return "ollama"
return (resolved.model_provider or "").strip().lower() or "ollama"
def _build_ollama_config_from_resolved(
    resolved: ResolvedAgentConfig,
    base_config: OllamaConfig,
) -> OllamaConfig:
    """Combine a DB-resolved agent config with env-derived Ollama settings.

    Model, timeout, retry count, token limit and context window come from
    ``resolved``. The base URL, retry backoff tuning and stall/loop detection
    settings are inherited from ``base_config``.

    Kept for backward compatibility. The factory's ``build_config_from_resolved``
    is now the primary path.
    """
    # Knobs that a variant is allowed to override.
    per_variant = {
        "model": resolved.model_name,
        "timeout": resolved.timeout_seconds,
        "max_retries": resolved.max_retries,
        "max_tokens": resolved.max_tokens,
        "context_window": resolved.context_window,
    }
    # Transport and retry/backoff behaviour always come from the base config.
    inherited = {
        field_name: getattr(base_config, field_name)
        for field_name in (
            "base_url",
            "retry_base_delay",
            "retry_max_delay",
            "retry_backoff_multiplier",
            "stall_timeout",
            "loop_window",
            "loop_threshold",
        )
    }
    return OllamaConfig(**inherited, **per_variant)
async def _check_token_budget(
pool: asyncpg.Pool,
variant_id: str,
token_budget: int,
) -> bool:
"""Check if a variant has exceeded its hourly token budget.
Returns True if the budget is exceeded and invocation should be skipped.
"""
row = await pool.fetchrow(
"""SELECT COALESCE(SUM(input_tokens + output_tokens), 0) AS total_tokens
FROM agent_performance_log
WHERE variant_id = $1
AND recorded_at >= NOW() - INTERVAL '1 hour'""",
variant_id,
)
used = int(row["total_tokens"]) if row else 0
if used >= token_budget:
logger.warning(
"Token budget exceeded for variant %s: used %d / budget %d — skipping invocation",
variant_id, used, token_budget,
)
return True
return False
async def _log_agent_performance(
pool: asyncpg.Pool,
*,
agent_id: str,
variant_id: str | None = None,
document_id: str = "",
ticker: str = "",
success: bool = False,
duration_ms: int = 0,
confidence: float = 0.0,
retry_count: int = 0,
input_tokens: int = 0,
output_tokens: int = 0,
error_message: str | None = None,
) -> None:
"""Insert a row into agent_performance_log with optional variant attribution."""
try:
await pool.execute(
"""INSERT INTO agent_performance_log
(agent_id, variant_id, document_id, ticker, success, duration_ms,
confidence, retry_count, input_tokens, output_tokens, error_message)
VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8, $9, $10, $11)""",
agent_id,
variant_id,
document_id if document_id else None,
ticker,
success,
duration_ms,
confidence,
retry_count,
input_tokens,
output_tokens,
error_message,
)
except Exception:
logger.warning("Failed to log agent performance", exc_info=True)
async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]:
"""Build a ticker -> company_id mapping from the companies table."""
rows = await pool.fetch("SELECT id, ticker FROM companies WHERE active = TRUE")
return {row["ticker"]: str(row["id"]) for row in rows}
async def _fetch_document_type(pool: asyncpg.Pool, document_id: str) -> str | None:
"""Fetch the document_type for a document."""
row = await pool.fetchrow(
"SELECT document_type FROM documents WHERE id = $1::uuid",
document_id,
)
return row["document_type"] if row else None
async def _fetch_company_info(pool: asyncpg.Pool) -> list[dict]:
"""Fetch company info needed for exposure profile loading and interpolation."""
rows = await pool.fetch(
"""SELECT id, ticker, sector, industry, market_cap_bucket
FROM companies WHERE active = TRUE"""
)
return [dict(r) for r in rows]
async def _load_exposure_profile(pool: asyncpg.Pool, company_id: str, sector: str, industry: str, market_cap_bucket: str):
    """Load a company's exposure profile from the DB, else build a default one.

    Returns the newest active ``exposure_profiles`` row for ``company_id``
    (highest ``version``) as an ``ExposureProfileSchema``. When no active row
    exists, returns ``build_default_profile(sector, industry, market_cap_bucket)``
    with ``company_id`` filled in.

    Intended precedence is manual > inferred > default.
    NOTE(review): the query orders only by ``version`` and never looks at
    ``source``. Manual-over-inferred precedence therefore holds only if
    writers keep a single active row per company, or give manual rows the
    higher version. Confirm this with the profile writer.

    Requirements: 4.1

    Args:
        pool: asyncpg pool used for the lookup.
        company_id: Key matched against ``exposure_profiles.company_id``.
        sector: Sector for the default profile; falsy values become "".
        industry: Industry for the default profile; falsy values become "".
        market_cap_bucket: Bucket for the default profile; falsy values
            become "small_cap".

    Returns:
        An ``ExposureProfileSchema`` built from the DB row, or the default
        profile.
    """
    from services.shared.schemas import ExposureProfileSchema, MarketPositionTier

    # Try a stored (manual or inferred) profile first.
    row = await pool.fetchrow(
        """SELECT company_id, geographic_revenue_mix, supply_chain_regions,
                  key_input_commodities, regulatory_jurisdictions, market_position_tier,
                  export_dependency_pct, source, confidence, version
           FROM exposure_profiles
           WHERE company_id = $1 AND active = TRUE
           ORDER BY version DESC LIMIT 1""",
        company_id,
    )
    if row:
        # The revenue mix may arrive as a JSON-encoded string rather than a
        # decoded mapping, so decode it when needed.
        geo_mix = row["geographic_revenue_mix"]
        if isinstance(geo_mix, str):
            geo_mix = json.loads(geo_mix)
        try:
            tier = MarketPositionTier(row["market_position_tier"])
        except ValueError:
            # Unknown or NULL tier values degrade to REGIONAL instead of failing the load.
            tier = MarketPositionTier.REGIONAL
        raw_confidence = row["confidence"]
        return ExposureProfileSchema(
            company_id=str(row["company_id"]),
            geographic_revenue_mix=geo_mix or {},
            supply_chain_regions=list(row["supply_chain_regions"] or []),
            key_input_commodities=list(row["key_input_commodities"] or []),
            regulatory_jurisdictions=list(row["regulatory_jurisdictions"] or []),
            market_position_tier=tier,
            export_dependency_pct=float(row["export_dependency_pct"] or 0.0),
            source=row["source"] or "manual",
            # Only a NULL confidence defaults to 1.0. A stored 0.0 is a real value;
            # ``x or 1.0`` would silently promote it to full confidence.
            confidence=1.0 if raw_confidence is None else float(raw_confidence),
            version=row["version"] or 1,
        )
    # Fall back to a sector/industry/size-based default profile.
    profile = build_default_profile(sector or "", industry or "", market_cap_bucket or "small_cap")
    profile.company_id = str(company_id)
    return profile
async def _compute_and_persist_macro_impacts(
    pool: asyncpg.Pool,
    event,
    companies: list[dict],
    confidence_threshold: float = 0.4,
) -> list[str]:
    """Score one macro event against every company and persist the non-zero impacts.

    The event is dropped, and an empty list returned, when
    ``filter_low_confidence_events`` rejects it at ``confidence_threshold``.
    Otherwise each company's exposure profile is loaded and scored with
    ``compute_macro_impact_with_sector``. Records with
    ``macro_impact_score > 0.0`` are persisted in one batch.

    Requirements: 4.1, 4.5

    Returns:
        Tickers of the companies whose records were persisted, in ``companies``
        order. Returns an empty list when nothing was persisted.
    """
    if not filter_low_confidence_events([event], confidence_threshold):
        logger.info("Event %s excluded: confidence %.3f below threshold %.3f",
                    event.event_id, event.confidence, confidence_threshold)
        return []

    impacted = []
    for company in companies:
        company_id = str(company["id"])
        ticker = company["ticker"]
        sector = company.get("sector") or ""
        profile = await _load_exposure_profile(
            pool,
            company_id,
            sector,
            company.get("industry") or "",
            company.get("market_cap_bucket") or "small_cap",
        )
        impact = compute_macro_impact_with_sector(event, profile, company_sector=sector)
        impact.ticker = ticker
        impact.company_id = company_id
        if impact.macro_impact_score > 0.0:
            impacted.append(impact)

    if not impacted:
        return []
    persisted_ids = await persist_macro_impact_records(pool, impacted)
    logger.info(
        "Persisted %d macro impact records for event %s",
        len(persisted_ids), event.event_id,
    )
    return [impact.ticker for impact in impacted]
# Track consecutive macro classification failures for alerting (Requirement 10.4)
_macro_consecutive_failures = 0
_MACRO_FAILURE_ALERT_THRESHOLD = 3
def _record_macro_failure() -> None:
    """Count one more consecutive macro failure and alert once the threshold is reached.

    Every failure at or past ``_MACRO_FAILURE_ALERT_THRESHOLD`` logs a CRITICAL
    alert. A successful classification resets the counter.
    """
    global _macro_consecutive_failures
    _macro_consecutive_failures += 1
    if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD:
        logger.critical(
            "ALERT: Sustained macro classification failures (%d consecutive). "
            "Continuing with company-only signals. Operator action required.",
            _macro_consecutive_failures,
        )


async def _process_macro_classification(
    *,
    pool: asyncpg.Pool,
    minio_client: Minio,
    ollama: LLMClient,
    redis_client: aioredis.Redis,
    document_id: str,
    text: str,
    company_id_map: dict[str, str],
    confidence_threshold: float = 0.4,
) -> None:
    """Classify a macro_event document, persist per-company impacts, and fan out aggregation.

    Steps:
    1. Run ``classify_global_event`` with ``ollama``, which is any ``LLMClient``
       despite the name.
    2. Reset the consecutive-failure counter.
    3. Compute and persist impacts for all active companies.
    4. RPUSH one aggregation job per distinct affected ticker.

    Nothing is raised to the caller. ``ValueError`` and any other ``Exception``
    are logged and counted toward the consecutive-failure alert. Failures after
    classification (impact persistence, Redis) are counted too.

    Requirements: 2.1, 2.2, 2.3, 4.1, 4.5, 10.4

    Args:
        pool: asyncpg pool for classification and impact persistence.
        minio_client: Passed through to ``classify_global_event``.
        ollama: LLM client used for classification.
        redis_client: Redis client used to enqueue aggregation jobs.
        document_id: Id of the macro_event document.
        text: Normalized document text to classify.
        company_id_map: Accepted for interface compatibility; not used here.
        confidence_threshold: Minimum event confidence for impact computation.
    """
    global _macro_consecutive_failures
    agg_queue = queue_key(QUEUE_AGGREGATION)
    try:
        event = await classify_global_event(
            normalized_text=text,
            document_id=document_id,
            client=ollama,
            pool=pool,
            minio_client=minio_client,
        )
        logger.info(
            "Classified macro event %s for doc %s: severity=%s types=%s",
            event.event_id, document_id, event.severity, event.event_types,
        )
        # Reset failure counter on success
        _macro_consecutive_failures = 0

        # Load all tracked companies and compute macro impacts
        companies = await _fetch_company_info(pool)
        affected_tickers = await _compute_and_persist_macro_impacts(
            pool, event, companies, confidence_threshold,
        )

        # Trigger aggregation once per affected ticker, keeping first-seen order.
        unique_tickers = list(dict.fromkeys(affected_tickers))
        for ticker in unique_tickers:
            await redis_client.rpush(
                agg_queue,
                json.dumps(inject_trace_context({
                    "ticker": ticker,
                    "macro_event_id": event.event_id,
                })),
            )
        logger.info(
            "Enqueued aggregation jobs for %d affected tickers after macro event %s",
            len(unique_tickers), event.event_id,
        )
    except ValueError as e:
        logger.error("Macro event classification failed for doc %s: %s", document_id, e)
        _record_macro_failure()
    except Exception:
        logger.exception("Unexpected error classifying macro event for doc %s", document_id)
        _record_macro_failure()
async def _pop_next_job(
    redis_client: aioredis.Redis,
    extraction_queue: str,
    macro_queue: str,
    *,
    macro_first: bool,
) -> tuple[bytes | str | None, bool]:
    """Pop one raw job payload, trying the preferred queue first.

    Args:
        redis_client: Redis client used for ``LPOP``.
        extraction_queue: Key of the regular extraction queue.
        macro_queue: Key of the macro-classification queue.
        macro_first: When True, try ``macro_queue`` before ``extraction_queue``;
            otherwise try them in the opposite order.

    Returns:
        ``(raw, is_macro_job)``. ``raw`` is the popped payload, or None when
        both queues were empty. ``is_macro_job`` is True only when the payload
        came from ``macro_queue``.
    """
    if macro_first:
        raw = await redis_client.lpop(macro_queue)
        if raw is not None:
            return raw, True
        return await redis_client.lpop(extraction_queue), False
    raw = await redis_client.lpop(extraction_queue)
    if raw is not None:
        return raw, False
    raw = await redis_client.lpop(macro_queue)
    return raw, raw is not None


async def _fetch_normalized_text(
    pool: asyncpg.Pool,
    minio_client: Minio,
    document_id: str,
) -> str:
    """Load a document's normalized text from MinIO via ``documents.normalized_storage_ref``.

    The ref is expected in the form ``s3://bucket/path``. Returns "" when:
    - the document has no ref,
    - the ref has no ``/`` after the bucket, or
    - the MinIO read or UTF-8 decode fails (logged as a warning).

    Errors from the database lookup itself propagate to the caller.
    """
    ref_row = await pool.fetchrow(
        "SELECT normalized_storage_ref FROM documents WHERE id = $1::uuid",
        document_id,
    )
    if not ref_row or not ref_row["normalized_storage_ref"]:
        return ""
    try:
        ref = ref_row["normalized_storage_ref"]
        # ref format: s3://bucket/path
        parts = ref.replace("s3://", "").split("/", 1)
        if len(parts) != 2:
            return ""
        obj = minio_client.get_object(parts[0], parts[1])
        try:
            text = obj.read().decode("utf-8")
        finally:
            # Always close the response and hand its pooled connection back,
            # even when read/decode fails.
            obj.close()
            obj.release_conn()
        return text
    except Exception as e:
        logger.warning("Could not fetch normalized text for doc %s: %s", document_id, e)
        return ""


async def _close_quietly(resource, label: str) -> None:
    """Await ``resource.close()``, logging rather than raising on failure so shutdown continues."""
    try:
        await resource.close()
    except Exception:
        logger.warning("Failed to close %s", label, exc_info=True)


async def main() -> None:
    """Run the extractor worker until cancelled.

    Startup:
    - Load config and open Postgres, MinIO and Redis clients.
    - Resolve the ``document-extractor`` and ``event-classifier`` agent configs
      from the DB, falling back to env defaults.
    - Fall back to Ollama when a configured vLLM endpoint fails its health check.

    Loop:
    - Pop jobs from the extraction and macro-classification queues. Every third
      poll tries the macro queue first.
    - Route ``macro_event`` documents to the event classifier and everything
      else through extraction, enqueueing an aggregation job on success.
    - Every 100 jobs, refresh the company map and both agent configs so variant
      swaps take effect.

    A failure while handling a single job is logged and the job is dropped; it
    does not stop the worker. On exit the LLM clients, the Postgres pool and
    the Redis client are closed.
    """
    config = load_config()
    setup_logging("extractor", level=config.log_level, json_output=config.json_logs)
    pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
    minio_client = Minio(
        config.minio.endpoint,
        access_key=config.minio.access_key,
        secret_key=config.minio.secret_key,
        secure=config.minio.secure,
    )

    # Resolve extractor config from DB (active variant override + TTL cache)
    resolver = AgentConfigResolver(pool, ttl_seconds=60)
    resolved_config: ResolvedAgentConfig | None = None
    extractor_provider = "ollama"
    try:
        resolved_config = await resolver.resolve("document-extractor")
        if resolved_config is not None:
            extractor_provider = _get_provider(resolved_config)
            logger.info(
                "Extractor using resolved config: model=%s variant=%s provider=%s",
                resolved_config.model_name, resolved_config.variant_id, extractor_provider,
            )
        else:
            logger.info("No DB config for document-extractor — using env defaults")
    except Exception:
        logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True)

    # vLLM health check at startup when provider is vllm (Requirement 7.17.3)
    if extractor_provider == "vllm":
        healthy = await check_vllm_health(config.vllm.base_url)
        if not healthy:
            logger.warning(
                "vLLM health check failed at startup — falling back to Ollama for extractor",
            )
            extractor_provider = "ollama"
            # Dropping the resolved config makes the factory build an Ollama client from
            # env defaults. It also disables token budgeting and performance logging
            # until the next config refresh.
            resolved_config = None
    extractor_client: LLMClient = build_llm_client(
        resolved_config, config.ollama, config.vllm,
    )

    # Resolve event classifier config separately (may use different model)
    classifier_resolved: ResolvedAgentConfig | None = None
    classifier_provider = "ollama"
    try:
        classifier_resolved = await resolver.resolve("event-classifier")
        if classifier_resolved is not None:
            classifier_provider = _get_provider(classifier_resolved)
            logger.info(
                "Event classifier using resolved config: model=%s variant=%s provider=%s",
                classifier_resolved.model_name, classifier_resolved.variant_id, classifier_provider,
            )
        else:
            logger.info("No DB config for event-classifier — using extractor config")
    except Exception:
        logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True)

    # vLLM health check for classifier if it uses vllm and extractor didn't already check
    if classifier_provider == "vllm" and extractor_provider != "vllm":
        healthy = await check_vllm_health(config.vllm.base_url)
        if not healthy:
            logger.warning(
                "vLLM health check failed at startup — falling back to Ollama for classifier",
            )
            classifier_provider = "ollama"
            classifier_resolved = None

    # Build classifier client — share with extractor when configs match
    classifier_client: LLMClient
    if classifier_resolved is not None or classifier_provider != extractor_provider:
        classifier_client = build_llm_client(
            classifier_resolved, config.ollama, config.vllm,
        )
    else:
        classifier_client = extractor_client

    redis_client = aioredis.from_url(config.redis.url)
    queue = queue_key(QUEUE_EXTRACTION)
    macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
    agg_queue = queue_key(QUEUE_AGGREGATION)
    confidence_threshold = config.macro.macro_confidence_threshold
    logger.info("Extractor worker started, polling %s and %s", queue, macro_queue)

    refresh_counter = 0
    # Queue alternation to prevent starvation: every third poll tries the macro
    # queue first, the other two try the extraction queue first.
    macro_turn_counter = 0
    try:
        # Pre-load the company ID map (refreshed periodically). It is loaded inside
        # the try so the pool and Redis client are still closed if this query fails.
        company_id_map = await _build_company_id_map(pool)
        while True:
            raw, is_macro_job = await _pop_next_job(
                redis_client, queue, macro_queue,
                macro_first=macro_turn_counter % 3 == 0,
            )
            macro_turn_counter += 1
            if raw is None:
                await asyncio.sleep(1)
                continue

            try:
                job = json.loads(raw)
                document_id = job.get("document_id", "")
                ticker = job.get("ticker", "")
                text = job.get("text", "") or job.get("normalized_text", "")
                # If no text in job, try to fetch from MinIO via the document's normalized_storage_ref
                if not text:
                    text = await _fetch_normalized_text(pool, minio_client, document_id)

                # Refresh company map and agent configs every 100 jobs
                refresh_counter += 1
                if refresh_counter % 100 == 0:
                    try:
                        company_id_map = await _build_company_id_map(pool)
                    except Exception:
                        logger.warning(
                            "Failed to refresh company ID map — keeping previous map",
                            exc_info=True,
                        )

                    # Re-resolve extractor config (picks up active variant swaps)
                    try:
                        new_resolved = await resolver.resolve("document-extractor")
                        if new_resolved is not None:
                            new_provider = _get_provider(new_resolved)
                            new_cfg = build_config_from_resolved(
                                new_resolved, config.ollama, config.vllm,
                            )
                            provider_changed = new_provider != extractor_provider
                            model_changed = new_cfg.model != extractor_client._config.model
                            retired_client: LLMClient | None = None
                            if provider_changed or model_changed:
                                logger.info(
                                    "Extractor provider switch: old_provider=%s new_provider=%s "
                                    "model=%s variant=%s",
                                    extractor_provider, new_provider,
                                    new_resolved.model_name, new_resolved.variant_id,
                                )
                                # Build the replacement before retiring the old client, so a
                                # failed build leaves a working client in place.
                                retired_client = extractor_client
                                extractor_client = build_llm_client(
                                    new_resolved, config.ollama, config.vllm,
                                )
                                extractor_provider = new_provider
                                if classifier_client is retired_client:
                                    # The classifier was sharing the retired client. Move it to
                                    # the new one instead of leaving it on a closed client.
                                    classifier_client = extractor_client
                                    classifier_provider = new_provider
                            else:
                                # Same provider and model — just update config in-place
                                extractor_client._config = new_cfg  # type: ignore[assignment]
                            resolved_config = new_resolved
                            if retired_client is not None:
                                await retired_client.close()
                    except Exception:
                        logger.warning("Failed to refresh extractor config", exc_info=True)

                    # Re-resolve event classifier config
                    try:
                        new_cls_resolved = await resolver.resolve("event-classifier")
                        if new_cls_resolved is not None:
                            new_cls_provider = _get_provider(new_cls_resolved)
                            new_cls_cfg = build_config_from_resolved(
                                new_cls_resolved, config.ollama, config.vllm,
                            )
                            cls_provider_changed = new_cls_provider != classifier_provider
                            cls_model_changed = new_cls_cfg.model != classifier_client._config.model
                            retired_cls_client: LLMClient | None = None
                            if cls_provider_changed or cls_model_changed:
                                logger.info(
                                    "Classifier provider switch: old_provider=%s new_provider=%s "
                                    "model=%s variant=%s",
                                    classifier_provider, new_cls_provider,
                                    new_cls_resolved.model_name, new_cls_resolved.variant_id,
                                )
                                # Never close a client the extractor is still using.
                                if classifier_client is not extractor_client:
                                    retired_cls_client = classifier_client
                                classifier_client = build_llm_client(
                                    new_cls_resolved, config.ollama, config.vllm,
                                )
                                classifier_provider = new_cls_provider
                            classifier_resolved = new_cls_resolved
                            if retired_cls_client is not None:
                                await retired_cls_client.close()
                    except Exception:
                        logger.warning("Failed to refresh event-classifier config", exc_info=True)

                # Route macro_event documents to event classification (Requirement 2.1)
                if is_macro_job:
                    doc_type = "macro_event"
                else:
                    doc_type = await _fetch_document_type(pool, document_id)
                if doc_type == "macro_event":
                    logger.info("Routing macro_event doc %s to event classifier", document_id)
                    await _process_macro_classification(
                        pool=pool,
                        minio_client=minio_client,
                        ollama=classifier_client,
                        redis_client=redis_client,
                        document_id=document_id,
                        text=text,
                        company_id_map=company_id_map,
                        confidence_threshold=confidence_threshold,
                    )
                    continue

                # Standard extraction pipeline for non-macro documents
                logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
                try:
                    # Token budget enforcement (Requirement 10.6)
                    if (
                        resolved_config is not None
                        and resolved_config.token_budget > 0
                        and resolved_config.variant_id is not None
                    ):
                        budget_exceeded = await _check_token_budget(
                            pool, resolved_config.variant_id, resolved_config.token_budget,
                        )
                        if budget_exceeded:
                            # NOTE(review): the job has already been popped, so it is dropped
                            # here rather than requeued — confirm this is intended.
                            continue

                    # Input token limit truncation (Requirement 10.5)
                    extraction_text = text
                    if resolved_config is not None and resolved_config.input_token_limit > 0:
                        # Rough estimate: ~4 chars per token
                        max_chars = resolved_config.input_token_limit * 4
                        if len(extraction_text) > max_chars:
                            extraction_text = extraction_text[:max_chars]
                            logger.info(
                                "Truncated input for doc %s from %d to %d chars (token limit %d)",
                                document_id, len(text), max_chars, resolved_config.input_token_limit,
                            )

                    # Pass all tracked tickers so the model can identify any mentioned companies
                    all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
                    extraction_response = await extractor_client.extract(
                        extraction_text,
                        document_id=document_id,
                        known_tickers=all_tickers,
                    )
                    result = await persist_extraction(
                        pool=pool,
                        minio_client=minio_client,
                        document_id=document_id,
                        ticker=ticker,
                        extraction_response=extraction_response,
                        company_id_map=company_id_map,
                        document_text_length=len(extraction_text),
                    )

                    # Log to agent_performance_log with variant attribution
                    if resolved_config is not None:
                        output_tokens = 0
                        if extraction_response.attempts:
                            final = extraction_response.attempts[-1]
                            output_tokens = len(final.raw_output) // 4 if final.raw_output else 0
                        await _log_agent_performance(
                            pool,
                            agent_id=resolved_config.agent_id,
                            variant_id=resolved_config.variant_id,
                            document_id=document_id,
                            ticker=ticker,
                            success=extraction_response.success,
                            duration_ms=extraction_response.total_duration_ms,
                            confidence=extraction_response.result.confidence if extraction_response.result else 0.0,
                            retry_count=max(0, len(extraction_response.attempts) - 1),
                            input_tokens=len(extraction_text) // 4,
                            output_tokens=output_tokens,
                            error_message=(
                                extraction_response.attempts[-1].error
                                if extraction_response.attempts and extraction_response.attempts[-1].error
                                else None
                            ),
                        )

                    # Enqueue aggregation job for the ticker on success
                    if result.success and ticker:
                        await redis_client.rpush(
                            agg_queue,
                            json.dumps(inject_trace_context({"ticker": ticker})),
                        )
                except Exception:
                    logger.exception("Extraction failed for doc %s", document_id)
            except Exception:
                # A malformed payload or a failed lookup must not take down the whole
                # worker. The job has already been popped, so log it and move on.
                logger.exception("Failed to process job; payload dropped")
    finally:
        await _close_quietly(extractor_client, "extractor LLM client")
        if classifier_client is not extractor_client:
            await _close_quietly(classifier_client, "classifier LLM client")
        await pool.close()
        await redis_client.close()


if __name__ == "__main__":
    asyncio.run(main())