b38fb24f14
ci/woodpecker/push/test Pipeline was successful
ci/woodpecker/push/build-3 Pipeline was successful
ci/woodpecker/push/build-2 Pipeline was successful
ci/woodpecker/push/build-1 Pipeline was successful
ci/woodpecker/push/finalize Pipeline was successful
Build and Push / lint-and-test (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.adapters.broker_adapter name:broker-adapter]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.aggregation.worker name:aggregation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.extractor.worker name:extractor]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.ingestion.worker name:ingestion]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.lake_publisher.worker name:lake-publisher]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.parser.worker name:parser]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.recommendation.worker name:recommendation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.scheduler.app name:scheduler]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.api.app:app --host 0.0.0.0 --port 8000 name:query-api]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.risk.app:app --host 0.0.0.0 --port 8000 name:risk]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.symbol_registry.app:app --host 0.0.0.0 --port 8000 name:symbol-registry]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.trading.app:app --host 0.0.0.0 --port 8000 name:trading-engine]) (push) Has been cancelled
Build and Push / build-dashboard (push) Has been cancelled
Build and Push / build-superset (push) Has been cancelled
Build and Push / integration-test (push) Has been cancelled
Build and Push / beta-gate (push) Has been cancelled
- Migration 026: update seed defaults from ollama to vllm/AxionML - Migration 031: fix existing rows still on old ollama defaults - Helm values: set OLLAMA_BASE_URL to cluster ollama endpoint (was empty) - Extractor: guard against switching to ollama when base_url is empty - OllamaClient: validate base_url on construction to fail fast
663 lines
28 KiB
Python
663 lines
28 KiB
Python
"""Extractor worker entrypoint - polls Redis for extraction jobs."""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
|
||
import asyncpg
|
||
import redis.asyncio as aioredis
|
||
from minio import Minio
|
||
|
||
from services.aggregation.interpolation import (
|
||
build_default_profile,
|
||
compute_macro_impact_with_sector,
|
||
filter_low_confidence_events,
|
||
persist_macro_impact_records,
|
||
)
|
||
from services.extractor.event_classifier import classify_global_event
|
||
from services.extractor.llm_factory import build_config_from_resolved, build_llm_client
|
||
from services.extractor.vllm_client import check_vllm_health
|
||
from services.extractor.worker import persist_extraction
|
||
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
||
from services.shared.config import OllamaConfig, load_config
|
||
from services.shared.llm_protocol import LLMClient
|
||
from services.shared.logging import inject_trace_context, setup_logging
|
||
from services.shared.redis_keys import (
|
||
QUEUE_AGGREGATION,
|
||
QUEUE_EXTRACTION,
|
||
QUEUE_MACRO_CLASSIFICATION,
|
||
queue_key,
|
||
)
|
||
|
||
# Module logger; the fixed name "extractor_main" (rather than __name__) is what
# log configuration and dashboards see for this entrypoint.
logger = logging.getLogger(name="extractor_main")
|
||
|
||
|
||
def _get_provider(resolved: ResolvedAgentConfig | None) -> str:
|
||
"""Return the normalised provider string for a resolved config."""
|
||
if resolved is None:
|
||
return "ollama"
|
||
return (resolved.model_provider or "").strip().lower() or "ollama"
|
||
|
||
|
||
def _build_ollama_config_from_resolved(
    resolved: ResolvedAgentConfig,
    base_config: OllamaConfig,
) -> OllamaConfig:
    """Assemble an ``OllamaConfig`` from per-agent and env-level settings.

    Taken from *resolved* (the DB-resolved agent config): model, timeout,
    max retries, max tokens and context window. Inherited from *base_config*:
    the endpoint URL plus retry back-off, stall-timeout and loop-detection
    tuning.

    Legacy helper kept for backward compatibility; ``build_config_from_resolved``
    from the LLM factory is the primary path now.
    """
    # Settings that are deployment-wide and never overridden per agent variant.
    inherited = {
        field_name: getattr(base_config, field_name)
        for field_name in (
            "base_url",
            "retry_base_delay",
            "retry_max_delay",
            "retry_backoff_multiplier",
            "stall_timeout",
            "loop_window",
            "loop_threshold",
        )
    }
    return OllamaConfig(
        model=resolved.model_name,
        timeout=resolved.timeout_seconds,
        max_retries=resolved.max_retries,
        max_tokens=resolved.max_tokens,
        context_window=resolved.context_window,
        **inherited,
    )
|
||
|
||
|
||
async def _check_token_budget(
|
||
pool: asyncpg.Pool,
|
||
variant_id: str,
|
||
token_budget: int,
|
||
) -> bool:
|
||
"""Check if a variant has exceeded its hourly token budget.
|
||
|
||
Returns True if the budget is exceeded and invocation should be skipped.
|
||
"""
|
||
row = await pool.fetchrow(
|
||
"""SELECT COALESCE(SUM(input_tokens + output_tokens), 0) AS total_tokens
|
||
FROM agent_performance_log
|
||
WHERE variant_id = $1
|
||
AND recorded_at >= NOW() - INTERVAL '1 hour'""",
|
||
variant_id,
|
||
)
|
||
used = int(row["total_tokens"]) if row else 0
|
||
if used >= token_budget:
|
||
logger.warning(
|
||
"Token budget exceeded for variant %s: used %d / budget %d — skipping invocation",
|
||
variant_id, used, token_budget,
|
||
)
|
||
return True
|
||
return False
|
||
|
||
|
||
async def _log_agent_performance(
|
||
pool: asyncpg.Pool,
|
||
*,
|
||
agent_id: str,
|
||
variant_id: str | None = None,
|
||
document_id: str = "",
|
||
ticker: str = "",
|
||
success: bool = False,
|
||
duration_ms: int = 0,
|
||
confidence: float = 0.0,
|
||
retry_count: int = 0,
|
||
input_tokens: int = 0,
|
||
output_tokens: int = 0,
|
||
error_message: str | None = None,
|
||
) -> None:
|
||
"""Insert a row into agent_performance_log with optional variant attribution."""
|
||
try:
|
||
await pool.execute(
|
||
"""INSERT INTO agent_performance_log
|
||
(agent_id, variant_id, document_id, ticker, success, duration_ms,
|
||
confidence, retry_count, input_tokens, output_tokens, error_message)
|
||
VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8, $9, $10, $11)""",
|
||
agent_id,
|
||
variant_id,
|
||
document_id if document_id else None,
|
||
ticker,
|
||
success,
|
||
duration_ms,
|
||
confidence,
|
||
retry_count,
|
||
input_tokens,
|
||
output_tokens,
|
||
error_message,
|
||
)
|
||
except Exception:
|
||
logger.warning("Failed to log agent performance", exc_info=True)
|
||
|
||
|
||
async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]:
|
||
"""Build a ticker -> company_id mapping from the companies table."""
|
||
rows = await pool.fetch("SELECT id, ticker FROM companies WHERE active = TRUE")
|
||
return {row["ticker"]: str(row["id"]) for row in rows}
|
||
|
||
|
||
async def _fetch_document_type(pool: asyncpg.Pool, document_id: str) -> str | None:
|
||
"""Fetch the document_type for a document."""
|
||
row = await pool.fetchrow(
|
||
"SELECT document_type FROM documents WHERE id = $1::uuid",
|
||
document_id,
|
||
)
|
||
return row["document_type"] if row else None
|
||
|
||
|
||
async def _fetch_company_info(pool: asyncpg.Pool) -> list[dict]:
|
||
"""Fetch company info needed for exposure profile loading and interpolation."""
|
||
rows = await pool.fetch(
|
||
"""SELECT id, ticker, sector, industry, market_cap_bucket
|
||
FROM companies WHERE active = TRUE"""
|
||
)
|
||
return [dict(r) for r in rows]
|
||
|
||
|
||
async def _load_exposure_profile(pool: asyncpg.Pool, company_id: str, sector: str, industry: str, market_cap_bucket: str):
    """Load the exposure profile used to score macro impact for one company.

    Uses the highest-``version`` active row in ``exposure_profiles`` for
    *company_id* when one exists; otherwise falls back to
    ``build_default_profile(sector, industry, market_cap_bucket)`` with the
    company id filled in.

    NOTE(review): the intended precedence is "manual > inferred > default",
    but the query orders only by ``version`` — a newer inferred row would be
    chosen over an older manual one. Confirm whether ``source`` should rank.

    Requirements: 4.1
    """
    from services.shared.schemas import ExposureProfileSchema, MarketPositionTier

    # Try manual or inferred profile from DB
    row = await pool.fetchrow(
        """SELECT company_id, geographic_revenue_mix, supply_chain_regions,
                  key_input_commodities, regulatory_jurisdictions, market_position_tier,
                  export_dependency_pct, source, confidence, version
           FROM exposure_profiles
           WHERE company_id = $1 AND active = TRUE
           ORDER BY version DESC LIMIT 1""",
        company_id,
    )
    if row:
        geo_mix = row["geographic_revenue_mix"]
        # The revenue mix may arrive as a JSON string (e.g. a json/jsonb column
        # without a registered codec); decode it so the schema gets a mapping.
        if isinstance(geo_mix, str):
            geo_mix = json.loads(geo_mix)

        tier_val = row["market_position_tier"]
        try:
            tier = MarketPositionTier(tier_val)
        except ValueError:
            # Unknown or NULL tier: treat as regional rather than failing the event.
            tier = MarketPositionTier.REGIONAL

        # Default confidence only when the column is NULL — a stored 0.0 is a
        # meaningful value and must not be promoted to full confidence.
        stored_confidence = row["confidence"]
        confidence = 1.0 if stored_confidence is None else float(stored_confidence)

        return ExposureProfileSchema(
            company_id=str(row["company_id"]),
            geographic_revenue_mix=geo_mix or {},
            # NOTE(review): these assume array-typed columns; a JSON-string value
            # would be split into characters by list(). Confirm column types.
            supply_chain_regions=list(row["supply_chain_regions"] or []),
            key_input_commodities=list(row["key_input_commodities"] or []),
            regulatory_jurisdictions=list(row["regulatory_jurisdictions"] or []),
            market_position_tier=tier,
            export_dependency_pct=float(row["export_dependency_pct"] or 0.0),
            source=row["source"] or "manual",
            confidence=confidence,
            version=row["version"] or 1,
        )

    # Fall back to default profile
    profile = build_default_profile(sector or "", industry or "", market_cap_bucket or "small_cap")
    profile.company_id = str(company_id)
    return profile
|
||
|
||
|
||
async def _compute_and_persist_macro_impacts(
    pool: asyncpg.Pool,
    event,
    companies: list[dict],
    confidence_threshold: float = 0.4,
) -> list[str]:
    """Score *event* against every company and persist the positive impacts.

    Events rejected by ``filter_low_confidence_events`` at
    *confidence_threshold* produce no records. For each company the exposure
    profile is loaded and ``compute_macro_impact_with_sector`` is applied; only
    records with ``macro_impact_score > 0.0`` are persisted.

    Returns the tickers of the persisted records (empty list if none).

    Requirements: 4.1, 4.5
    """
    if not filter_low_confidence_events([event], confidence_threshold):
        logger.info("Event %s excluded: confidence %.3f below threshold %.3f",
                    event.event_id, event.confidence, confidence_threshold)
        return []

    impactful = []
    for company in companies:
        cid = str(company["id"])
        symbol = company["ticker"]
        company_sector = company.get("sector") or ""
        profile = await _load_exposure_profile(
            pool,
            cid,
            company_sector,
            company.get("industry") or "",
            company.get("market_cap_bucket") or "small_cap",
        )

        impact = compute_macro_impact_with_sector(event, profile, company_sector=company_sector)
        impact.ticker = symbol
        impact.company_id = cid
        if impact.macro_impact_score > 0.0:
            impactful.append(impact)

    if not impactful:
        return []

    persisted_ids = await persist_macro_impact_records(pool, impactful)
    logger.info(
        "Persisted %d macro impact records for event %s",
        len(persisted_ids), event.event_id,
    )
    return [impact.ticker for impact in impactful]
|
||
|
||
|
||
# Track consecutive macro classification failures for alerting (Requirement 10.4)
|
||
_macro_consecutive_failures = 0
|
||
_MACRO_FAILURE_ALERT_THRESHOLD = 3
|
||
|
||
|
||
def _record_macro_failure() -> None:
    """Count one more consecutive macro failure and alert once the threshold is reached.

    Increments the module-level ``_macro_consecutive_failures`` counter and
    logs a CRITICAL alert on every failure at or above
    ``_MACRO_FAILURE_ALERT_THRESHOLD``.
    """
    global _macro_consecutive_failures
    _macro_consecutive_failures += 1
    if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD:
        logger.critical(
            "ALERT: Sustained macro classification failures (%d consecutive). "
            "Continuing with company-only signals. Operator action required.",
            _macro_consecutive_failures,
        )


async def _process_macro_classification(
    *,
    pool: asyncpg.Pool,
    minio_client: Minio,
    ollama: LLMClient,
    redis_client: aioredis.Redis,
    document_id: str,
    text: str,
    company_id_map: dict[str, str],
    confidence_threshold: float = 0.4,
) -> None:
    """Classify a macro_event document, persist company impacts, and trigger aggregation.

    Steps: ``classify_global_event`` on *text*; on success reset the
    consecutive-failure counter, compute/persist macro impacts for all active
    companies, then push one aggregation job (with ``macro_event_id``) per
    distinct affected ticker onto the aggregation queue.

    Never raises for ``Exception`` subclasses: any failure — in classification
    or in the downstream impact/enqueue steps — is logged and counted toward
    the sustained-failure alert.

    Args:
        ollama: the LLM client used for classification (any ``LLMClient``,
            not necessarily Ollama; the name is kept for caller compatibility).
        company_id_map: currently unused; kept for caller compatibility.

    Requirements: 2.1, 2.2, 2.3, 4.1, 4.5, 10.4
    """
    global _macro_consecutive_failures
    agg_queue = queue_key(QUEUE_AGGREGATION)

    try:
        event = await classify_global_event(
            normalized_text=text,
            document_id=document_id,
            client=ollama,
            pool=pool,
            minio_client=minio_client,
        )
        logger.info(
            "Classified macro event %s for doc %s: severity=%s types=%s",
            event.event_id, document_id, event.severity, event.event_types,
        )

        # Reset failure counter on success
        _macro_consecutive_failures = 0

        # Load all tracked companies and compute macro impacts
        companies = await _fetch_company_info(pool)
        affected_tickers = await _compute_and_persist_macro_impacts(
            pool, event, companies, confidence_threshold,
        )

        # Trigger aggregation once per affected ticker; dict.fromkeys removes
        # duplicates while keeping first-seen order.
        unique_tickers = list(dict.fromkeys(affected_tickers))
        for ticker in unique_tickers:
            await redis_client.rpush(
                agg_queue,
                json.dumps(inject_trace_context({
                    "ticker": ticker,
                    "macro_event_id": event.event_id,
                })),
            )

        logger.info(
            "Enqueued aggregation jobs for %d affected tickers after macro event %s",
            len(unique_tickers), event.event_id,
        )

    except ValueError as e:
        logger.error("Macro event classification failed for doc %s: %s", document_id, e)
        _record_macro_failure()
    except Exception:
        logger.exception("Unexpected error classifying macro event for doc %s", document_id)
        _record_macro_failure()
|
||
|
||
|
||
def _decode_job(raw: bytes | str) -> dict | None:
    """Decode a raw queue payload into a job dict; return ``None`` if it is malformed or not a JSON object."""
    try:
        job = json.loads(raw)
    except (ValueError, TypeError):  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        return None
    return job if isinstance(job, dict) else None


async def _load_normalized_text(
    pool: asyncpg.Pool,
    minio_client: Minio,
    document_id: str,
) -> str:
    """Fetch a document's normalized text from MinIO via ``documents.normalized_storage_ref``.

    Best-effort: returns ``""`` when the document has no stored reference, the
    reference is not ``s3://bucket/key``-shaped, or any lookup/read fails
    (failures are logged as warnings and never raised).
    """
    try:
        ref_row = await pool.fetchrow(
            "SELECT normalized_storage_ref FROM documents WHERE id = $1::uuid",
            document_id,
        )
        if not ref_row or not ref_row["normalized_storage_ref"]:
            return ""
        # ref format: s3://bucket/path
        parts = ref_row["normalized_storage_ref"].replace("s3://", "").split("/", 1)
        if len(parts) != 2:
            return ""
        # NOTE: the MinIO SDK is synchronous, so this read blocks the event loop.
        obj = minio_client.get_object(parts[0], parts[1])
        try:
            return obj.read().decode("utf-8")
        finally:
            # Always return the pooled HTTP connection, even if read/decode fails.
            obj.close()
            obj.release_conn()
    except Exception as e:
        logger.warning("Could not fetch normalized text for doc %s: %s", document_id, e)
        return ""


async def main() -> None:
    """Run the extractor worker until the process is stopped.

    Startup: loads env config, opens the Postgres pool and MinIO client,
    resolves the ``document-extractor`` and ``event-classifier`` agent configs
    from the DB (falling back to env defaults), health-checks vLLM when it is
    selected (falling back to Ollama if unhealthy), and builds the LLM clients.
    The classifier shares the extractor's client unless it has its own config
    or a different provider.

    Loop: pops jobs from the macro-classification and extraction queues (macro
    gets first pick on every third turn), routes ``macro_event`` documents to
    ``_process_macro_classification`` and everything else through extraction,
    persistence, performance logging and an aggregation enqueue. Every 100 jobs
    it refreshes the company map and re-resolves both agent configs, swapping
    clients when the provider or model changed. Malformed jobs and per-job
    lookup failures are logged and skipped instead of stopping the worker.
    """
    config = load_config()
    setup_logging("extractor", level=config.log_level, json_output=config.json_logs)

    pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
    minio_client = Minio(
        config.minio.endpoint,
        access_key=config.minio.access_key,
        secret_key=config.minio.secret_key,
        secure=config.minio.secure,
    )

    # Resolve extractor config from DB (active variant override + TTL cache)
    resolver = AgentConfigResolver(pool, ttl_seconds=60)
    resolved_config: ResolvedAgentConfig | None = None
    extractor_provider = "ollama"
    try:
        resolved_config = await resolver.resolve("document-extractor")
        if resolved_config is not None:
            extractor_provider = _get_provider(resolved_config)
            logger.info(
                "Extractor using resolved config: model=%s variant=%s provider=%s",
                resolved_config.model_name, resolved_config.variant_id, extractor_provider,
            )
        else:
            logger.info("No DB config for document-extractor — using env defaults")
    except Exception:
        logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True)

    # vLLM health check at startup when provider is vllm (Requirement 7.1–7.3)
    if extractor_provider == "vllm":
        healthy = await check_vllm_health(config.vllm.base_url)
        if not healthy:
            logger.warning(
                "vLLM health check failed at startup — falling back to Ollama for extractor",
            )
            extractor_provider = "ollama"
            # Override resolved config provider so factory builds OllamaClient
            resolved_config = None

    extractor_client: LLMClient = build_llm_client(
        resolved_config, config.ollama, config.vllm,
    )

    # Resolve event classifier config separately (may use different model).
    # The provider defaults to the extractor's so that "no classifier config"
    # really means "share the extractor client", as the log lines state,
    # instead of building an env-default Ollama client behind the scenes.
    classifier_resolved: ResolvedAgentConfig | None = None
    classifier_provider = extractor_provider
    try:
        classifier_resolved = await resolver.resolve("event-classifier")
        if classifier_resolved is not None:
            classifier_provider = _get_provider(classifier_resolved)
            logger.info(
                "Event classifier using resolved config: model=%s variant=%s provider=%s",
                classifier_resolved.model_name, classifier_resolved.variant_id, classifier_provider,
            )
        else:
            logger.info("No DB config for event-classifier — using extractor config")
    except Exception:
        logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True)

    # vLLM health check for classifier if it uses vllm and extractor didn't already check
    if classifier_provider == "vllm" and extractor_provider != "vllm":
        healthy = await check_vllm_health(config.vllm.base_url)
        if not healthy:
            logger.warning(
                "vLLM health check failed at startup — falling back to Ollama for classifier",
            )
            classifier_provider = "ollama"
            classifier_resolved = None

    # Build classifier client — share with extractor when configs match
    classifier_client: LLMClient
    if classifier_resolved is not None or classifier_provider != extractor_provider:
        classifier_client = build_llm_client(
            classifier_resolved, config.ollama, config.vllm,
        )
    else:
        classifier_client = extractor_client

    redis_client = aioredis.from_url(config.redis.url)
    queue = queue_key(QUEUE_EXTRACTION)
    macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
    agg_queue = queue_key(QUEUE_AGGREGATION)
    confidence_threshold = config.macro.macro_confidence_threshold
    logger.info("Extractor worker started, polling %s and %s", queue, macro_queue)

    # Pre-load company ID map (refreshed periodically)
    company_id_map = await _build_company_id_map(pool)
    refresh_counter = 0
    # Alternate between queues to prevent starvation: process 1 macro then 2 extractions
    macro_turn_counter = 0

    try:
        while True:
            # Alternate: every 3rd job from macro queue, rest from extraction.
            # This prevents macro events from starving regular extractions.
            raw = None
            is_macro_job = False

            if macro_turn_counter % 3 == 0:
                # Try macro first
                raw = await redis_client.lpop(macro_queue)
                is_macro_job = raw is not None
                if raw is None:
                    raw = await redis_client.lpop(queue)
            else:
                # Try extraction first
                raw = await redis_client.lpop(queue)
                if raw is None:
                    raw = await redis_client.lpop(macro_queue)
                    is_macro_job = raw is not None

            macro_turn_counter += 1

            if raw is None:
                await asyncio.sleep(1)
                continue

            job = _decode_job(raw)
            if job is None:
                # A single bad payload must not take the whole worker down.
                logger.error(
                    "Dropping malformed job payload from %s: %r",
                    macro_queue if is_macro_job else queue, raw[:200],
                )
                continue

            document_id = job.get("document_id", "")
            ticker = job.get("ticker", "")
            text = job.get("text", "") or job.get("normalized_text", "")

            # If no text in job, try to fetch from MinIO via the document's normalized_storage_ref
            if not text:
                text = await _load_normalized_text(pool, minio_client, document_id)

            # Refresh company map and agent configs every 100 jobs
            refresh_counter += 1
            if refresh_counter % 100 == 0:
                try:
                    company_id_map = await _build_company_id_map(pool)
                except Exception:
                    logger.warning("Failed to refresh company id map — keeping previous map", exc_info=True)

                # Re-resolve extractor config (picks up active variant swaps)
                try:
                    new_resolved = await resolver.resolve("document-extractor")
                    if new_resolved is not None:
                        new_provider = _get_provider(new_resolved)
                        new_cfg = build_config_from_resolved(
                            new_resolved, config.ollama, config.vllm,
                        )
                        old_provider = extractor_provider
                        provider_changed = new_provider != extractor_provider
                        model_changed = new_cfg.model != extractor_client._config.model

                        if provider_changed or model_changed:
                            # Guard: don't switch to ollama if base_url is empty
                            if new_provider == "ollama" and not config.ollama.base_url:
                                logger.warning(
                                    "DB resolved provider=ollama but OLLAMA_BASE_URL is empty — "
                                    "keeping current %s client. Fix the agent config in the UI.",
                                    extractor_provider,
                                )
                            else:
                                logger.info(
                                    "Extractor provider switch: old_provider=%s new_provider=%s "
                                    "model=%s variant=%s",
                                    old_provider, new_provider,
                                    new_resolved.model_name, new_resolved.variant_id,
                                )
                                old_client = extractor_client
                                # Build the replacement before closing the old client so a
                                # factory failure leaves a usable client in place.
                                extractor_client = build_llm_client(
                                    new_resolved, config.ollama, config.vllm,
                                )
                                extractor_provider = new_provider
                                if classifier_client is old_client:
                                    # The classifier shared the old client; move it along
                                    # instead of leaving it holding a closed client.
                                    classifier_client = extractor_client
                                    classifier_provider = new_provider
                                await old_client.close()
                        else:
                            # Same provider and model — just update config in-place
                            extractor_client._config = new_cfg  # type: ignore[assignment]
                        resolved_config = new_resolved
                except Exception:
                    logger.warning("Failed to refresh extractor config", exc_info=True)

                # Re-resolve event classifier config
                try:
                    new_cls_resolved = await resolver.resolve("event-classifier")
                    if new_cls_resolved is not None:
                        new_cls_provider = _get_provider(new_cls_resolved)
                        new_cls_cfg = build_config_from_resolved(
                            new_cls_resolved, config.ollama, config.vllm,
                        )
                        old_cls_provider = classifier_provider
                        cls_provider_changed = new_cls_provider != classifier_provider
                        # When shared, this compares against the extractor's model,
                        # so a diverging classifier model also triggers a rebuild here.
                        cls_model_changed = new_cls_cfg.model != classifier_client._config.model

                        if cls_provider_changed or cls_model_changed:
                            # Guard: don't switch to ollama if base_url is empty
                            if new_cls_provider == "ollama" and not config.ollama.base_url:
                                logger.warning(
                                    "DB resolved classifier provider=ollama but OLLAMA_BASE_URL is empty — "
                                    "keeping current %s client. Fix the agent config in the UI.",
                                    classifier_provider,
                                )
                            else:
                                logger.info(
                                    "Classifier provider switch: old_provider=%s new_provider=%s "
                                    "model=%s variant=%s",
                                    old_cls_provider, new_cls_provider,
                                    new_cls_resolved.model_name, new_cls_resolved.variant_id,
                                )
                                old_cls_client = classifier_client
                                classifier_client = build_llm_client(
                                    new_cls_resolved, config.ollama, config.vllm,
                                )
                                classifier_provider = new_cls_provider
                                # Never close the client the extractor is still using.
                                if old_cls_client is not extractor_client:
                                    await old_cls_client.close()
                        classifier_resolved = new_cls_resolved
                except Exception:
                    logger.warning("Failed to refresh event-classifier config", exc_info=True)

            # Route macro_event documents to event classification (Requirement 2.1)
            if is_macro_job:
                doc_type = "macro_event"
            else:
                try:
                    doc_type = await _fetch_document_type(pool, document_id)
                except Exception:
                    # e.g. missing/invalid document_id or a DB error: skip this job
                    # rather than letting the exception stop the worker.
                    logger.exception("Could not look up document_type for doc %s — skipping job", document_id)
                    continue

            if doc_type == "macro_event":
                logger.info("Routing macro_event doc %s to event classifier", document_id)
                await _process_macro_classification(
                    pool=pool,
                    minio_client=minio_client,
                    ollama=classifier_client,
                    redis_client=redis_client,
                    document_id=document_id,
                    text=text,
                    company_id_map=company_id_map,
                    confidence_threshold=confidence_threshold,
                )
                continue

            # Standard extraction pipeline for non-macro documents
            logger.info("Processing extraction job for doc %s / %s", document_id, ticker)

            try:
                # Token budget enforcement (Requirement 10.6)
                if (
                    resolved_config is not None
                    and resolved_config.token_budget > 0
                    and resolved_config.variant_id is not None
                ):
                    budget_exceeded = await _check_token_budget(
                        pool, resolved_config.variant_id, resolved_config.token_budget,
                    )
                    if budget_exceeded:
                        # NOTE(review): the job was already popped, so it is dropped
                        # here, not deferred. Confirm this is intended.
                        continue

                # Input token limit truncation (Requirement 10.5)
                extraction_text = text
                if resolved_config is not None and resolved_config.input_token_limit > 0:
                    # Rough estimate: ~4 chars per token
                    max_chars = resolved_config.input_token_limit * 4
                    if len(extraction_text) > max_chars:
                        extraction_text = extraction_text[:max_chars]
                        logger.info(
                            "Truncated input for doc %s from %d to %d chars (token limit %d)",
                            document_id, len(text), max_chars, resolved_config.input_token_limit,
                        )

                # Pass all tracked tickers so the model can identify any mentioned companies
                all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
                extraction_response = await extractor_client.extract(
                    extraction_text,
                    document_id=document_id,
                    known_tickers=all_tickers,
                )
                result = await persist_extraction(
                    pool=pool,
                    minio_client=minio_client,
                    document_id=document_id,
                    ticker=ticker,
                    extraction_response=extraction_response,
                    company_id_map=company_id_map,
                    document_text_length=len(extraction_text),
                )

                # Log to agent_performance_log with variant attribution
                if resolved_config is not None:
                    output_tokens = 0
                    if extraction_response.attempts:
                        final = extraction_response.attempts[-1]
                        # Same ~4 chars/token estimate as used for input truncation.
                        output_tokens = len(final.raw_output) // 4 if final.raw_output else 0
                    await _log_agent_performance(
                        pool,
                        agent_id=resolved_config.agent_id,
                        variant_id=resolved_config.variant_id,
                        document_id=document_id,
                        ticker=ticker,
                        success=extraction_response.success,
                        duration_ms=extraction_response.total_duration_ms,
                        confidence=extraction_response.result.confidence if extraction_response.result else 0.0,
                        retry_count=max(0, len(extraction_response.attempts) - 1),
                        input_tokens=len(extraction_text) // 4,
                        output_tokens=output_tokens,
                        error_message=(
                            extraction_response.attempts[-1].error
                            if extraction_response.attempts and extraction_response.attempts[-1].error
                            else None
                        ),
                    )

                # Enqueue aggregation job for the ticker on success
                if result.success and ticker:
                    await redis_client.rpush(
                        agg_queue,
                        json.dumps(inject_trace_context({"ticker": ticker})),
                    )
            except Exception:
                logger.exception("Extraction failed for doc %s", document_id)
    finally:
        # Close each distinct LLM client once; a close failure must not stop
        # the pool/redis cleanup below.
        llm_clients = [extractor_client]
        if classifier_client is not extractor_client:
            llm_clients.append(classifier_client)
        for llm_client in llm_clients:
            try:
                await llm_client.close()
            except Exception:
                logger.warning("Failed to close LLM client", exc_info=True)
        await pool.close()
        await redis_client.close()
|
||
|
||
|
||
# Script entry point: run the async worker (polls Redis until the process stops).
if __name__ == "__main__":
    asyncio.run(main())
|