"""Extractor worker entrypoint - polls Redis for extraction jobs.""" from __future__ import annotations import asyncio import json import logging import asyncpg import redis.asyncio as aioredis from minio import Minio from services.aggregation.interpolation import ( build_default_profile, compute_macro_impact_with_sector, filter_low_confidence_events, persist_macro_impact_records, ) from services.extractor.event_classifier import classify_global_event from services.extractor.llm_factory import build_config_from_resolved, build_llm_client from services.extractor.vllm_client import check_vllm_health from services.extractor.worker import persist_extraction from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig from services.shared.config import OllamaConfig, load_config from services.shared.llm_protocol import LLMClient from services.shared.logging import inject_trace_context, setup_logging from services.shared.redis_keys import ( QUEUE_AGGREGATION, QUEUE_EXTRACTION, QUEUE_MACRO_CLASSIFICATION, is_pipeline_enabled, queue_key, ) logger = logging.getLogger("extractor_main") def _get_provider(resolved: ResolvedAgentConfig | None) -> str: """Return the normalised provider string for a resolved config.""" if resolved is None: return "ollama" return (resolved.model_provider or "").strip().lower() or "ollama" def _build_ollama_config_from_resolved( resolved: ResolvedAgentConfig, base_config: OllamaConfig, ) -> OllamaConfig: """Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings. Kept for backward compatibility — the factory's ``build_config_from_resolved`` is now the primary path. """ return OllamaConfig( base_url=base_config.base_url, model=resolved.model_name, timeout=resolved.timeout_seconds, max_retries=resolved.max_retries, retry_base_delay=base_config.retry_base_delay, retry_max_delay=base_config.retry_max_delay, retry_backoff_multiplier=base_config.retry_backoff_multiplier, max_tokens=resolved.max_tokens, stall_timeout=base_config.stall_timeout, loop_window=base_config.loop_window, loop_threshold=base_config.loop_threshold, context_window=resolved.context_window, ) async def _check_token_budget( pool: asyncpg.Pool, variant_id: str, token_budget: int, ) -> bool: """Check if a variant has exceeded its hourly token budget. Returns True if the budget is exceeded and invocation should be skipped. """ row = await pool.fetchrow( """SELECT COALESCE(SUM(input_tokens + output_tokens), 0) AS total_tokens FROM agent_performance_log WHERE variant_id = $1 AND recorded_at >= NOW() - INTERVAL '1 hour'""", variant_id, ) used = int(row["total_tokens"]) if row else 0 if used >= token_budget: logger.warning( "Token budget exceeded for variant %s: used %d / budget %d — skipping invocation", variant_id, used, token_budget, ) return True return False async def _log_agent_performance( pool: asyncpg.Pool, *, agent_id: str, variant_id: str | None = None, document_id: str = "", ticker: str = "", success: bool = False, duration_ms: int = 0, confidence: float = 0.0, retry_count: int = 0, input_tokens: int = 0, output_tokens: int = 0, error_message: str | None = None, ) -> None: """Insert a row into agent_performance_log with optional variant attribution.""" try: await pool.execute( """INSERT INTO agent_performance_log (agent_id, variant_id, document_id, ticker, success, duration_ms, confidence, retry_count, input_tokens, output_tokens, error_message) VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8, $9, $10, $11)""", agent_id, variant_id, document_id if document_id else None, ticker, success, duration_ms, confidence, retry_count, input_tokens, output_tokens, error_message, ) except Exception: logger.warning("Failed to log agent performance", exc_info=True) async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]: """Build a ticker -> company_id mapping from the companies table.""" rows = await pool.fetch("SELECT id, ticker FROM companies WHERE active = TRUE") return {row["ticker"]: str(row["id"]) for row in rows} async def _fetch_document_type(pool: asyncpg.Pool, document_id: str) -> str | None: """Fetch the document_type for a document.""" row = await pool.fetchrow( "SELECT document_type FROM documents WHERE id = $1::uuid", document_id, ) return row["document_type"] if row else None async def _fetch_company_info(pool: asyncpg.Pool) -> list[dict]: """Fetch company info needed for exposure profile loading and interpolation.""" rows = await pool.fetch( """SELECT id, ticker, sector, industry, market_cap_bucket FROM companies WHERE active = TRUE""" ) return [dict(r) for r in rows] async def _load_exposure_profile(pool: asyncpg.Pool, company_id: str, sector: str, industry: str, market_cap_bucket: str): """Load exposure profile for a company: manual > inferred > default. Requirements: 4.1 """ from services.shared.schemas import ExposureProfileSchema, MarketPositionTier # Try manual or inferred profile from DB row = await pool.fetchrow( """SELECT company_id, geographic_revenue_mix, supply_chain_regions, key_input_commodities, regulatory_jurisdictions, market_position_tier, export_dependency_pct, source, confidence, version FROM exposure_profiles WHERE company_id = $1 AND active = TRUE ORDER BY version DESC LIMIT 1""", company_id, ) if row: geo_mix = row["geographic_revenue_mix"] if isinstance(geo_mix, str): geo_mix = json.loads(geo_mix) tier_val = row["market_position_tier"] try: tier = MarketPositionTier(tier_val) except ValueError: tier = MarketPositionTier.REGIONAL return ExposureProfileSchema( company_id=str(row["company_id"]), geographic_revenue_mix=geo_mix or {}, supply_chain_regions=list(row["supply_chain_regions"] or []), key_input_commodities=list(row["key_input_commodities"] or []), regulatory_jurisdictions=list(row["regulatory_jurisdictions"] or []), market_position_tier=tier, export_dependency_pct=float(row["export_dependency_pct"] or 0.0), source=row["source"] or "manual", confidence=float(row["confidence"] or 1.0), version=row["version"] or 1, ) # Fall back to default profile profile = build_default_profile(sector or "", industry or "", market_cap_bucket or "small_cap") profile.company_id = str(company_id) return profile async def _compute_and_persist_macro_impacts( pool: asyncpg.Pool, event, companies: list[dict], confidence_threshold: float = 0.4, ) -> list[str]: """Compute MacroImpactRecords for all tracked companies and persist non-zero ones. Requirements: 4.1, 4.5 """ # Filter low-confidence events filtered = filter_low_confidence_events([event], confidence_threshold) if not filtered: logger.info("Event %s excluded: confidence %.3f below threshold %.3f", event.event_id, event.confidence, confidence_threshold) return [] records = [] for company in companies: company_id = str(company["id"]) ticker = company["ticker"] sector = company.get("sector") or "" industry = company.get("industry") or "" market_cap_bucket = company.get("market_cap_bucket") or "small_cap" profile = await _load_exposure_profile(pool, company_id, sector, industry, market_cap_bucket) record = compute_macro_impact_with_sector(event, profile, company_sector=sector) record.ticker = ticker record.company_id = company_id if record.macro_impact_score > 0.0: records.append(record) if records: ids = await persist_macro_impact_records(pool, records) logger.info( "Persisted %d macro impact records for event %s", len(ids), event.event_id, ) return [r.ticker for r in records] return [] # Track consecutive macro classification failures for alerting (Requirement 10.4) _macro_consecutive_failures = 0 _MACRO_FAILURE_ALERT_THRESHOLD = 3 async def _process_macro_classification( *, pool: asyncpg.Pool, minio_client: Minio, ollama: LLMClient, redis_client: aioredis.Redis, document_id: str, text: str, company_id_map: dict[str, str], confidence_threshold: float = 0.4, ) -> None: """Route a macro_event document to event classification, compute interpolation, and trigger aggregation for affected tickers. Requirements: 2.1, 2.2, 2.3, 4.1, 4.5, 10.4 """ global _macro_consecutive_failures agg_queue = queue_key(QUEUE_AGGREGATION) try: event = await classify_global_event( normalized_text=text, document_id=document_id, client=ollama, pool=pool, minio_client=minio_client, ) logger.info( "Classified macro event %s for doc %s: severity=%s types=%s", event.event_id, document_id, event.severity, event.event_types, ) # Reset failure counter on success _macro_consecutive_failures = 0 # Load all tracked companies and compute macro impacts companies = await _fetch_company_info(pool) affected_tickers = await _compute_and_persist_macro_impacts( pool, event, companies, confidence_threshold, ) # Trigger aggregation for affected tickers (those with non-zero impact) enqueued_tickers = set() for ticker in affected_tickers: if ticker not in enqueued_tickers: await redis_client.rpush( agg_queue, json.dumps(inject_trace_context({ "ticker": ticker, "macro_event_id": event.event_id, })), ) enqueued_tickers.add(ticker) logger.info( "Enqueued aggregation jobs for %d affected tickers after macro event %s", len(enqueued_tickers), event.event_id, ) except ValueError as e: _macro_consecutive_failures += 1 logger.error("Macro event classification failed for doc %s: %s", document_id, e) if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD: logger.critical( "ALERT: Sustained macro classification failures (%d consecutive). " "Continuing with company-only signals. Operator action required.", _macro_consecutive_failures, ) except Exception: _macro_consecutive_failures += 1 logger.exception("Unexpected error classifying macro event for doc %s", document_id) if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD: logger.critical( "ALERT: Sustained macro classification failures (%d consecutive). " "Continuing with company-only signals. Operator action required.", _macro_consecutive_failures, ) async def main() -> None: config = load_config() setup_logging("extractor", level=config.log_level, json_output=config.json_logs) pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8) minio_client = Minio( config.minio.endpoint, access_key=config.minio.access_key, secret_key=config.minio.secret_key, secure=config.minio.secure, ) # Resolve extractor config from DB (active variant override + TTL cache) resolver = AgentConfigResolver(pool, ttl_seconds=60) resolved_config: ResolvedAgentConfig | None = None extractor_provider = "ollama" try: resolved_config = await resolver.resolve("document-extractor") if resolved_config is not None: extractor_provider = _get_provider(resolved_config) logger.info( "Extractor using resolved config: model=%s variant=%s provider=%s", resolved_config.model_name, resolved_config.variant_id, extractor_provider, ) else: logger.info("No DB config for document-extractor — using env defaults") except Exception: logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True) # vLLM health check at startup when provider is vllm (Requirement 7.1–7.3) if extractor_provider == "vllm": healthy = await check_vllm_health(config.vllm.base_url) if not healthy: logger.warning( "vLLM health check failed at startup — falling back to Ollama for extractor", ) extractor_provider = "ollama" # Override resolved config provider so factory builds OllamaClient resolved_config = None extractor_client: LLMClient = build_llm_client( resolved_config, config.ollama, config.vllm, ) # Resolve event classifier config separately (may use different model) classifier_resolved: ResolvedAgentConfig | None = None classifier_provider = "ollama" try: classifier_resolved = await resolver.resolve("event-classifier") if classifier_resolved is not None: classifier_provider = _get_provider(classifier_resolved) logger.info( "Event classifier using resolved config: model=%s variant=%s provider=%s", classifier_resolved.model_name, classifier_resolved.variant_id, classifier_provider, ) else: logger.info("No DB config for event-classifier — using extractor config") except Exception: logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True) # vLLM health check for classifier if it uses vllm and extractor didn't already check if classifier_provider == "vllm" and extractor_provider != "vllm": healthy = await check_vllm_health(config.vllm.base_url) if not healthy: logger.warning( "vLLM health check failed at startup — falling back to Ollama for classifier", ) classifier_provider = "ollama" classifier_resolved = None # Build classifier client — share with extractor when configs match classifier_client: LLMClient if classifier_resolved is not None or classifier_provider != extractor_provider: classifier_client = build_llm_client( classifier_resolved, config.ollama, config.vllm, ) else: classifier_client = extractor_client redis_client = aioredis.from_url(config.redis.url) queue = queue_key(QUEUE_EXTRACTION) macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION) agg_queue = queue_key(QUEUE_AGGREGATION) confidence_threshold = config.macro.macro_confidence_threshold logger.info("Extractor worker started, polling %s and %s", queue, macro_queue) # Pre-load company ID map (refreshed periodically) company_id_map = await _build_company_id_map(pool) refresh_counter = 0 # Alternate between queues to prevent starvation: process 1 macro then 2 extractions macro_turn_counter = 0 try: while True: if not await is_pipeline_enabled(redis_client): await asyncio.sleep(1) continue # Alternate: every 3rd job from macro queue, rest from extraction # This prevents macro events from starving regular extractions raw = None is_macro_job = False if macro_turn_counter % 3 == 0: # Try macro first raw = await redis_client.lpop(macro_queue) is_macro_job = raw is not None if raw is None: raw = await redis_client.lpop(queue) else: # Try extraction first raw = await redis_client.lpop(queue) if raw is None: raw = await redis_client.lpop(macro_queue) is_macro_job = raw is not None macro_turn_counter += 1 if raw is None: await asyncio.sleep(1) continue job = json.loads(raw) document_id = job.get("document_id", "") ticker = job.get("ticker", "") text = job.get("text", "") or job.get("normalized_text", "") # If no text in job, try to fetch from MinIO via the document's normalized_storage_ref if not text: ref_row = await pool.fetchrow( "SELECT normalized_storage_ref FROM documents WHERE id = $1::uuid", document_id, ) if ref_row and ref_row["normalized_storage_ref"]: try: ref = ref_row["normalized_storage_ref"] # ref format: s3://bucket/path parts = ref.replace("s3://", "").split("/", 1) if len(parts) == 2: obj = minio_client.get_object(parts[0], parts[1]) text = obj.read().decode("utf-8") obj.close() obj.release_conn() except Exception as e: logger.warning("Could not fetch normalized text for doc %s: %s", document_id, e) # Refresh company map every 100 jobs refresh_counter += 1 if refresh_counter % 100 == 0: company_id_map = await _build_company_id_map(pool) # Re-resolve extractor config (picks up active variant swaps) try: new_resolved = await resolver.resolve("document-extractor") if new_resolved is not None: new_provider = _get_provider(new_resolved) new_cfg = build_config_from_resolved( new_resolved, config.ollama, config.vllm, ) old_provider = extractor_provider provider_changed = new_provider != extractor_provider model_changed = new_cfg.model != extractor_client._config.model if provider_changed or model_changed: # Guard: don't switch to ollama if base_url is empty if new_provider == "ollama" and not config.ollama.base_url: logger.warning( "DB resolved provider=ollama but OLLAMA_BASE_URL is empty — " "keeping current %s client. Fix the agent config in the UI.", extractor_provider, ) else: logger.info( "Extractor provider switch: old_provider=%s new_provider=%s " "model=%s variant=%s", old_provider, new_provider, new_resolved.model_name, new_resolved.variant_id, ) await extractor_client.close() extractor_client = build_llm_client( new_resolved, config.ollama, config.vllm, ) extractor_provider = new_provider else: # Same provider and model — just update config in-place extractor_client._config = new_cfg # type: ignore[assignment] resolved_config = new_resolved except Exception: logger.warning("Failed to refresh extractor config", exc_info=True) # Re-resolve event classifier config try: new_cls_resolved = await resolver.resolve("event-classifier") if new_cls_resolved is not None: new_cls_provider = _get_provider(new_cls_resolved) new_cls_cfg = build_config_from_resolved( new_cls_resolved, config.ollama, config.vllm, ) old_cls_provider = classifier_provider cls_provider_changed = new_cls_provider != classifier_provider cls_model_changed = new_cls_cfg.model != classifier_client._config.model if cls_provider_changed or cls_model_changed: # Guard: don't switch to ollama if base_url is empty if new_cls_provider == "ollama" and not config.ollama.base_url: logger.warning( "DB resolved classifier provider=ollama but OLLAMA_BASE_URL is empty — " "keeping current %s client. Fix the agent config in the UI.", classifier_provider, ) else: logger.info( "Classifier provider switch: old_provider=%s new_provider=%s " "model=%s variant=%s", old_cls_provider, new_cls_provider, new_cls_resolved.model_name, new_cls_resolved.variant_id, ) if classifier_client is not extractor_client: await classifier_client.close() classifier_client = build_llm_client( new_cls_resolved, config.ollama, config.vllm, ) classifier_provider = new_cls_provider elif classifier_client is extractor_client and new_cls_cfg.model != extractor_client._config.model: classifier_client = build_llm_client( new_cls_resolved, config.ollama, config.vllm, ) classifier_provider = new_cls_provider classifier_resolved = new_cls_resolved except Exception: logger.warning("Failed to refresh event-classifier config", exc_info=True) # Route macro_event documents to event classification (Requirement 2.1) doc_type = None if is_macro_job: doc_type = "macro_event" else: doc_type = await _fetch_document_type(pool, document_id) if doc_type == "macro_event": logger.info("Routing macro_event doc %s to event classifier", document_id) await _process_macro_classification( pool=pool, minio_client=minio_client, ollama=classifier_client, redis_client=redis_client, document_id=document_id, text=text, company_id_map=company_id_map, confidence_threshold=confidence_threshold, ) continue # Standard extraction pipeline for non-macro documents logger.info("Processing extraction job for doc %s / %s", document_id, ticker) try: # Token budget enforcement (Requirement 10.6) if ( resolved_config is not None and resolved_config.token_budget > 0 and resolved_config.variant_id is not None ): budget_exceeded = await _check_token_budget( pool, resolved_config.variant_id, resolved_config.token_budget, ) if budget_exceeded: continue # Input token limit truncation (Requirement 10.5) extraction_text = text if resolved_config is not None and resolved_config.input_token_limit > 0: # Rough estimate: ~4 chars per token max_chars = resolved_config.input_token_limit * 4 if len(extraction_text) > max_chars: extraction_text = extraction_text[:max_chars] logger.info( "Truncated input for doc %s from %d to %d chars (token limit %d)", document_id, len(text), max_chars, resolved_config.input_token_limit, ) # Pass all tracked tickers so the model can identify any mentioned companies all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None) extraction_response = await extractor_client.extract( extraction_text, document_id=document_id, known_tickers=all_tickers, ) result = await persist_extraction( pool=pool, minio_client=minio_client, document_id=document_id, ticker=ticker, extraction_response=extraction_response, company_id_map=company_id_map, document_text_length=len(extraction_text), ) # Log to agent_performance_log with variant attribution if resolved_config is not None: output_tokens = 0 if extraction_response.attempts: final = extraction_response.attempts[-1] output_tokens = len(final.raw_output) // 4 if final.raw_output else 0 await _log_agent_performance( pool, agent_id=resolved_config.agent_id, variant_id=resolved_config.variant_id, document_id=document_id, ticker=ticker, success=extraction_response.success, duration_ms=extraction_response.total_duration_ms, confidence=extraction_response.result.confidence if extraction_response.result else 0.0, retry_count=max(0, len(extraction_response.attempts) - 1), input_tokens=len(extraction_text) // 4, output_tokens=output_tokens, error_message=( extraction_response.attempts[-1].error if extraction_response.attempts and extraction_response.attempts[-1].error else None ), ) # Enqueue aggregation job for the ticker on success if result.success and ticker: await redis_client.rpush( agg_queue, json.dumps(inject_trace_context({"ticker": ticker})), ) except Exception: logger.exception("Extraction failed for doc %s", document_id) finally: await pool.close() await redis_client.close() if __name__ == "__main__": asyncio.run(main())