"""Extractor worker entrypoint - polls Redis for extraction jobs.""" from __future__ import annotations import asyncio import json import logging import asyncpg import redis.asyncio as aioredis from minio import Minio from services.aggregation.interpolation import ( build_default_profile, compute_macro_impact_with_sector, filter_low_confidence_events, persist_macro_impact_records, ) from services.extractor.client import OllamaClient from services.extractor.event_classifier import classify_global_event from services.extractor.worker import persist_extraction from services.shared.config import load_config from services.shared.logging import inject_trace_context, setup_logging from services.shared.redis_keys import ( QUEUE_AGGREGATION, QUEUE_EXTRACTION, QUEUE_MACRO_CLASSIFICATION, queue_key, ) logger = logging.getLogger("extractor_main") async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]: """Build a ticker -> company_id mapping from the companies table.""" rows = await pool.fetch("SELECT id, ticker FROM companies WHERE active = TRUE") return {row["ticker"]: str(row["id"]) for row in rows} async def _fetch_document_type(pool: asyncpg.Pool, document_id: str) -> str | None: """Fetch the document_type for a document.""" row = await pool.fetchrow( "SELECT document_type FROM documents WHERE id = $1::uuid", document_id, ) return row["document_type"] if row else None async def _fetch_company_info(pool: asyncpg.Pool) -> list[dict]: """Fetch company info needed for exposure profile loading and interpolation.""" rows = await pool.fetch( """SELECT id, ticker, sector, industry, market_cap_bucket FROM companies WHERE active = TRUE""" ) return [dict(r) for r in rows] async def _load_exposure_profile(pool: asyncpg.Pool, company_id: str, sector: str, industry: str, market_cap_bucket: str): """Load exposure profile for a company: manual > inferred > default. Requirements: 4.1 """ from services.shared.schemas import ExposureProfileSchema, MarketPositionTier # Try manual or inferred profile from DB row = await pool.fetchrow( """SELECT company_id, geographic_revenue_mix, supply_chain_regions, key_input_commodities, regulatory_jurisdictions, market_position_tier, export_dependency_pct, source, confidence, version FROM exposure_profiles WHERE company_id = $1 AND active = TRUE ORDER BY version DESC LIMIT 1""", company_id, ) if row: geo_mix = row["geographic_revenue_mix"] if isinstance(geo_mix, str): geo_mix = json.loads(geo_mix) tier_val = row["market_position_tier"] try: tier = MarketPositionTier(tier_val) except ValueError: tier = MarketPositionTier.REGIONAL return ExposureProfileSchema( company_id=str(row["company_id"]), geographic_revenue_mix=geo_mix or {}, supply_chain_regions=list(row["supply_chain_regions"] or []), key_input_commodities=list(row["key_input_commodities"] or []), regulatory_jurisdictions=list(row["regulatory_jurisdictions"] or []), market_position_tier=tier, export_dependency_pct=float(row["export_dependency_pct"] or 0.0), source=row["source"] or "manual", confidence=float(row["confidence"] or 1.0), version=row["version"] or 1, ) # Fall back to default profile profile = build_default_profile(sector or "", industry or "", market_cap_bucket or "small_cap") profile.company_id = str(company_id) return profile async def _compute_and_persist_macro_impacts( pool: asyncpg.Pool, event, companies: list[dict], confidence_threshold: float = 0.4, ) -> list[str]: """Compute MacroImpactRecords for all tracked companies and persist non-zero ones. Requirements: 4.1, 4.5 """ # Filter low-confidence events filtered = filter_low_confidence_events([event], confidence_threshold) if not filtered: logger.info("Event %s excluded: confidence %.3f below threshold %.3f", event.event_id, event.confidence, confidence_threshold) return [] records = [] for company in companies: company_id = str(company["id"]) ticker = company["ticker"] sector = company.get("sector") or "" industry = company.get("industry") or "" market_cap_bucket = company.get("market_cap_bucket") or "small_cap" profile = await _load_exposure_profile(pool, company_id, sector, industry, market_cap_bucket) record = compute_macro_impact_with_sector(event, profile, company_sector=sector) record.ticker = ticker record.company_id = company_id if record.macro_impact_score > 0.0: records.append(record) if records: ids = await persist_macro_impact_records(pool, records) logger.info( "Persisted %d macro impact records for event %s", len(ids), event.event_id, ) return [r.ticker for r in records] return [] # Track consecutive macro classification failures for alerting (Requirement 10.4) _macro_consecutive_failures = 0 _MACRO_FAILURE_ALERT_THRESHOLD = 3 async def _process_macro_classification( *, pool: asyncpg.Pool, minio_client: Minio, ollama: OllamaClient, redis_client: aioredis.Redis, document_id: str, text: str, company_id_map: dict[str, str], confidence_threshold: float = 0.4, ) -> None: """Route a macro_event document to event classification, compute interpolation, and trigger aggregation for affected tickers. Requirements: 2.1, 2.2, 2.3, 4.1, 4.5, 10.4 """ global _macro_consecutive_failures agg_queue = queue_key(QUEUE_AGGREGATION) try: event = await classify_global_event( normalized_text=text, document_id=document_id, ollama_client=ollama, pool=pool, minio_client=minio_client, ) logger.info( "Classified macro event %s for doc %s: severity=%s types=%s", event.event_id, document_id, event.severity, event.event_types, ) # Reset failure counter on success _macro_consecutive_failures = 0 # Load all tracked companies and compute macro impacts companies = await _fetch_company_info(pool) affected_tickers = await _compute_and_persist_macro_impacts( pool, event, companies, confidence_threshold, ) # Trigger aggregation for affected tickers (those with non-zero impact) enqueued_tickers = set() for ticker in affected_tickers: if ticker not in enqueued_tickers: await redis_client.rpush( agg_queue, json.dumps(inject_trace_context({ "ticker": ticker, "macro_event_id": event.event_id, })), ) enqueued_tickers.add(ticker) logger.info( "Enqueued aggregation jobs for %d affected tickers after macro event %s", len(enqueued_tickers), event.event_id, ) except ValueError as e: _macro_consecutive_failures += 1 logger.error("Macro event classification failed for doc %s: %s", document_id, e) if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD: logger.critical( "ALERT: Sustained macro classification failures (%d consecutive). " "Continuing with company-only signals. Operator action required.", _macro_consecutive_failures, ) except Exception: _macro_consecutive_failures += 1 logger.exception("Unexpected error classifying macro event for doc %s", document_id) if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD: logger.critical( "ALERT: Sustained macro classification failures (%d consecutive). " "Continuing with company-only signals. Operator action required.", _macro_consecutive_failures, ) async def main() -> None: config = load_config() setup_logging("extractor", level=config.log_level, json_output=config.json_logs) pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8) minio_client = Minio( config.minio.endpoint, access_key=config.minio.access_key, secret_key=config.minio.secret_key, secure=config.minio.secure, ) ollama = OllamaClient(config.ollama) redis_client = aioredis.from_url(config.redis.url) queue = queue_key(QUEUE_EXTRACTION) macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION) agg_queue = queue_key(QUEUE_AGGREGATION) confidence_threshold = config.macro.macro_confidence_threshold logger.info("Extractor worker started, polling %s and %s", queue, macro_queue) # Pre-load company ID map (refreshed periodically) company_id_map = await _build_company_id_map(pool) refresh_counter = 0 try: while True: # Check macro classification queue first (priority) raw = await redis_client.lpop(macro_queue) is_macro_job = raw is not None if raw is None: raw = await redis_client.lpop(queue) if raw is None: await asyncio.sleep(1) continue job = json.loads(raw) document_id = job.get("document_id", "") ticker = job.get("ticker", "") text = job.get("text", "") or job.get("normalized_text", "") # If no text in job, try to fetch from MinIO via the document's normalized_storage_ref if not text: ref_row = await pool.fetchrow( "SELECT normalized_storage_ref FROM documents WHERE id = $1::uuid", document_id, ) if ref_row and ref_row["normalized_storage_ref"]: try: ref = ref_row["normalized_storage_ref"] # ref format: s3://bucket/path parts = ref.replace("s3://", "").split("/", 1) if len(parts) == 2: obj = minio_client.get_object(parts[0], parts[1]) text = obj.read().decode("utf-8") obj.close() obj.release_conn() except Exception as e: logger.warning("Could not fetch normalized text for doc %s: %s", document_id, e) # Refresh company map every 100 jobs refresh_counter += 1 if refresh_counter % 100 == 0: company_id_map = await _build_company_id_map(pool) # Route macro_event documents to event classification (Requirement 2.1) doc_type = None if is_macro_job: doc_type = "macro_event" else: doc_type = await _fetch_document_type(pool, document_id) if doc_type == "macro_event": logger.info("Routing macro_event doc %s to event classifier", document_id) await _process_macro_classification( pool=pool, minio_client=minio_client, ollama=ollama, redis_client=redis_client, document_id=document_id, text=text, company_id_map=company_id_map, confidence_threshold=confidence_threshold, ) continue # Standard extraction pipeline for non-macro documents logger.info("Processing extraction job for doc %s / %s", document_id, ticker) try: # Pass all tracked tickers so the model can identify any mentioned companies all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None) extraction_response = await ollama.extract( text, document_id=document_id, known_tickers=all_tickers, ) result = await persist_extraction( pool=pool, minio_client=minio_client, document_id=document_id, ticker=ticker, extraction_response=extraction_response, company_id_map=company_id_map, document_text_length=len(text), ) # Enqueue aggregation job for the ticker on success if result.success and ticker: await redis_client.rpush( agg_queue, json.dumps(inject_trace_context({"ticker": ticker})), ) except Exception: logger.exception("Extraction failed for doc %s", document_id) finally: await pool.close() await redis_client.close() if __name__ == "__main__": asyncio.run(main())