diff --git a/services/extractor/main.py b/services/extractor/main.py index 1d23839..2faa4cf 100644 --- a/services/extractor/main.py +++ b/services/extractor/main.py @@ -18,7 +18,8 @@ from services.aggregation.interpolation import ( from services.extractor.client import OllamaClient from services.extractor.event_classifier import classify_global_event from services.extractor.worker import persist_extraction -from services.shared.config import load_config +from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig +from services.shared.config import OllamaConfig, load_config from services.shared.logging import inject_trace_context, setup_logging from services.shared.redis_keys import ( QUEUE_AGGREGATION, @@ -30,6 +31,91 @@ from services.shared.redis_keys import ( logger = logging.getLogger("extractor_main") +def _build_ollama_config_from_resolved( + resolved: ResolvedAgentConfig, + base_config: OllamaConfig, +) -> OllamaConfig: + """Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings.""" + return OllamaConfig( + base_url=base_config.base_url, + model=resolved.model_name, + timeout=resolved.timeout_seconds, + max_retries=resolved.max_retries, + retry_base_delay=base_config.retry_base_delay, + retry_max_delay=base_config.retry_max_delay, + retry_backoff_multiplier=base_config.retry_backoff_multiplier, + max_tokens=resolved.max_tokens, + stall_timeout=base_config.stall_timeout, + loop_window=base_config.loop_window, + loop_threshold=base_config.loop_threshold, + context_window=resolved.context_window, + ) + + +async def _check_token_budget( + pool: asyncpg.Pool, + variant_id: str, + token_budget: int, +) -> bool: + """Check if a variant has exceeded its hourly token budget. + + Returns True if the budget is exceeded and invocation should be skipped. + """ + row = await pool.fetchrow( + """SELECT COALESCE(SUM(input_tokens + output_tokens), 0) AS total_tokens + FROM agent_performance_log + WHERE variant_id = $1 + AND recorded_at >= NOW() - INTERVAL '1 hour'""", + variant_id, + ) + used = int(row["total_tokens"]) if row else 0 + if used >= token_budget: + logger.warning( + "Token budget exceeded for variant %s: used %d / budget %d — skipping invocation", + variant_id, used, token_budget, + ) + return True + return False + + +async def _log_agent_performance( + pool: asyncpg.Pool, + *, + agent_id: str, + variant_id: str | None = None, + document_id: str = "", + ticker: str = "", + success: bool = False, + duration_ms: int = 0, + confidence: float = 0.0, + retry_count: int = 0, + input_tokens: int = 0, + output_tokens: int = 0, + error_message: str | None = None, +) -> None: + """Insert a row into agent_performance_log with optional variant attribution.""" + try: + await pool.execute( + """INSERT INTO agent_performance_log + (agent_id, variant_id, document_id, ticker, success, duration_ms, + confidence, retry_count, input_tokens, output_tokens, error_message) + VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8, $9, $10, $11)""", + agent_id, + variant_id, + document_id if document_id else None, + ticker, + success, + duration_ms, + confidence, + retry_count, + input_tokens, + output_tokens, + error_message, + ) + except Exception: + logger.warning("Failed to log agent performance", exc_info=True) + + async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]: """Build a ticker -> company_id mapping from the companies table.""" rows = await pool.fetch("SELECT id, ticker FROM companies WHERE active = TRUE") @@ -239,7 +325,53 @@ async def main() -> None: secret_key=config.minio.secret_key, secure=config.minio.secure, ) - ollama = OllamaClient(config.ollama) + + # Resolve extractor config from DB (active variant override + TTL cache) + resolver = AgentConfigResolver(pool, ttl_seconds=60) + resolved_config: ResolvedAgentConfig | None = None + extractor_ollama_config = config.ollama + try: + resolved_config = await resolver.resolve("document-extractor") + if resolved_config is not None: + extractor_ollama_config = _build_ollama_config_from_resolved( + resolved_config, config.ollama, + ) + logger.info( + "Extractor using resolved config: model=%s variant=%s", + resolved_config.model_name, resolved_config.variant_id, + ) + else: + logger.info("No DB config for document-extractor — using env defaults") + except Exception: + logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True) + + ollama = OllamaClient(extractor_ollama_config) + + # Resolve event classifier config separately (may use different model) + classifier_resolved: ResolvedAgentConfig | None = None + classifier_ollama_config = config.ollama + try: + classifier_resolved = await resolver.resolve("event-classifier") + if classifier_resolved is not None: + classifier_ollama_config = _build_ollama_config_from_resolved( + classifier_resolved, config.ollama, + ) + logger.info( + "Event classifier using resolved config: model=%s variant=%s", + classifier_resolved.model_name, classifier_resolved.variant_id, + ) + else: + logger.info("No DB config for event-classifier — using extractor config") + except Exception: + logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True) + + # Use a separate OllamaClient for the classifier if it has a different model + classifier_ollama: OllamaClient + if classifier_ollama_config.model != extractor_ollama_config.model: + classifier_ollama = OllamaClient(classifier_ollama_config) + else: + classifier_ollama = ollama + redis_client = aioredis.from_url(config.redis.url) queue = queue_key(QUEUE_EXTRACTION) macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION) @@ -307,6 +439,44 @@ async def main() -> None: refresh_counter += 1 if refresh_counter % 100 == 0: company_id_map = await _build_company_id_map(pool) + # Re-resolve extractor config (picks up active variant swaps) + try: + resolved_config = await resolver.resolve("document-extractor") + if resolved_config is not None: + new_ollama_cfg = _build_ollama_config_from_resolved( + resolved_config, config.ollama, + ) + if new_ollama_cfg.model != ollama._config.model: + logger.info( + "Extractor config changed: model=%s variant=%s", + resolved_config.model_name, resolved_config.variant_id, + ) + await ollama.close() + ollama = OllamaClient(new_ollama_cfg) + else: + ollama._config = new_ollama_cfg + except Exception: + logger.warning("Failed to refresh extractor config", exc_info=True) + + # Re-resolve event classifier config + try: + classifier_resolved = await resolver.resolve("event-classifier") + if classifier_resolved is not None: + new_cls_cfg = _build_ollama_config_from_resolved( + classifier_resolved, config.ollama, + ) + if new_cls_cfg.model != classifier_ollama._config.model: + logger.info( + "Event classifier config changed: model=%s variant=%s", + classifier_resolved.model_name, classifier_resolved.variant_id, + ) + if classifier_ollama is not ollama: + await classifier_ollama.close() + classifier_ollama = OllamaClient(new_cls_cfg) + elif classifier_ollama is ollama and new_cls_cfg.model != ollama._config.model: + classifier_ollama = OllamaClient(new_cls_cfg) + except Exception: + logger.warning("Failed to refresh event-classifier config", exc_info=True) # Route macro_event documents to event classification (Requirement 2.1) doc_type = None @@ -320,7 +490,7 @@ async def main() -> None: await _process_macro_classification( pool=pool, minio_client=minio_client, - ollama=ollama, + ollama=classifier_ollama, redis_client=redis_client, document_id=document_id, text=text, @@ -333,10 +503,34 @@ async def main() -> None: logger.info("Processing extraction job for doc %s / %s", document_id, ticker) try: + # Token budget enforcement (Requirement 10.6) + if ( + resolved_config is not None + and resolved_config.token_budget > 0 + and resolved_config.variant_id is not None + ): + budget_exceeded = await _check_token_budget( + pool, resolved_config.variant_id, resolved_config.token_budget, + ) + if budget_exceeded: + continue + + # Input token limit truncation (Requirement 10.5) + extraction_text = text + if resolved_config is not None and resolved_config.input_token_limit > 0: + # Rough estimate: ~4 chars per token + max_chars = resolved_config.input_token_limit * 4 + if len(extraction_text) > max_chars: + extraction_text = extraction_text[:max_chars] + logger.info( + "Truncated input for doc %s from %d to %d chars (token limit %d)", + document_id, len(text), max_chars, resolved_config.input_token_limit, + ) + # Pass all tracked tickers so the model can identify any mentioned companies all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None) extraction_response = await ollama.extract( - text, + extraction_text, document_id=document_id, known_tickers=all_tickers, ) @@ -347,9 +541,34 @@ async def main() -> None: ticker=ticker, extraction_response=extraction_response, company_id_map=company_id_map, - document_text_length=len(text), + document_text_length=len(extraction_text), ) + # Log to agent_performance_log with variant attribution + if resolved_config is not None: + output_tokens = 0 + if extraction_response.attempts: + final = extraction_response.attempts[-1] + output_tokens = len(final.raw_output) // 4 if final.raw_output else 0 + await _log_agent_performance( + pool, + agent_id=resolved_config.agent_id, + variant_id=resolved_config.variant_id, + document_id=document_id, + ticker=ticker, + success=extraction_response.success, + duration_ms=extraction_response.total_duration_ms, + confidence=extraction_response.result.confidence if extraction_response.result else 0.0, + retry_count=max(0, len(extraction_response.attempts) - 1), + input_tokens=len(extraction_text) // 4, + output_tokens=output_tokens, + error_message=( + extraction_response.attempts[-1].error + if extraction_response.attempts and extraction_response.attempts[-1].error + else None + ), + ) + # Enqueue aggregation job for the ticker on success if result.success and ticker: await redis_client.rpush( diff --git a/services/recommendation/main.py b/services/recommendation/main.py index 84f8f9f..74c36c0 100644 --- a/services/recommendation/main.py +++ b/services/recommendation/main.py @@ -9,7 +9,8 @@ import asyncpg from minio import Minio from services.recommendation.worker import generate_recommendation -from services.shared.config import load_config +from services.shared.agent_config import AgentConfigResolver +from services.shared.config import OllamaConfig, load_config from services.shared.logging import setup_logging from services.shared.redis_keys import QUEUE_RECOMMENDATION, queue_key @@ -32,8 +33,33 @@ async def main() -> None: redis_client = aioredis.from_url(config.redis.url) queue = queue_key(QUEUE_RECOMMENDATION) + + # Resolve thesis rewriter config from DB + resolver = AgentConfigResolver(pool, ttl_seconds=60) + ollama_config: OllamaConfig | None = None + try: + resolved = await resolver.resolve("thesis-rewriter") + if resolved is not None: + ollama_config = OllamaConfig( + base_url=config.ollama.base_url, + model=resolved.model_name, + timeout=resolved.timeout_seconds, + max_retries=resolved.max_retries, + max_tokens=resolved.max_tokens, + ) + logger.info( + "Thesis rewriter enabled: model=%s variant=%s", + resolved.model_name, resolved.variant_id, + ) + else: + logger.info("No DB config for thesis-rewriter — thesis rewriting disabled") + except Exception: + logger.warning("Failed to resolve thesis-rewriter config — thesis rewriting disabled", exc_info=True) + logger.info("Recommendation worker started, polling %s", queue) + refresh_counter = 0 + try: while True: raw = await redis_client.lpop(queue) @@ -48,10 +74,30 @@ async def main() -> None: logger.info("Processing recommendation job for %s/%s", ticker, window) + # Refresh resolver every 50 jobs to pick up config changes + refresh_counter += 1 + if refresh_counter % 50 == 0: + try: + resolved = await resolver.resolve("thesis-rewriter") + if resolved is not None: + new_config = OllamaConfig( + base_url=config.ollama.base_url, + model=resolved.model_name, + timeout=resolved.timeout_seconds, + max_retries=resolved.max_retries, + max_tokens=resolved.max_tokens, + ) + if ollama_config is None or new_config.model != ollama_config.model: + logger.info("Thesis rewriter config updated: model=%s", resolved.model_name) + ollama_config = new_config + except Exception: + logger.warning("Failed to refresh thesis-rewriter config", exc_info=True) + try: rec = await generate_recommendation( pool, ticker, window, minio_client=minio_client, + ollama_config=ollama_config, ) if rec: logger.info( diff --git a/tests/test_agent_config_resolver.py b/tests/test_agent_config_resolver.py new file mode 100644 index 0000000..5bf7a60 --- /dev/null +++ b/tests/test_agent_config_resolver.py @@ -0,0 +1,165 @@ +"""Tests for AgentConfigResolver — validates config resolution logic.""" +from __future__ import annotations + +import time +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig + + +def _make_row( + agent_id: str = "agent-1", + variant_id: str | None = None, + model_name: str = "qwen3.5:9b", + system_prompt: str = "test prompt", + **kwargs, +) -> dict: + """Build a mock DB row for the resolver.""" + return { + "agent_id": agent_id, + "variant_id": variant_id, + "model_provider": kwargs.get("model_provider", "ollama"), + "model_name": model_name, + "system_prompt": system_prompt, + "user_prompt_template": kwargs.get("user_prompt_template", ""), + "prompt_version": kwargs.get("prompt_version", "v1"), + "temperature": kwargs.get("temperature", 0.0), + "max_tokens": kwargs.get("max_tokens", 32768), + "context_window": kwargs.get("context_window", 0), + "input_token_limit": kwargs.get("input_token_limit", 0), + "token_budget": kwargs.get("token_budget", 0), + "timeout_seconds": kwargs.get("timeout_seconds", 120), + "max_retries": kwargs.get("max_retries", 2), + } + + +@pytest.mark.asyncio +async def test_resolve_returns_config_from_db(): + """Resolver returns a ResolvedAgentConfig when the DB has a matching agent.""" + pool = AsyncMock() + pool.fetchrow = AsyncMock(return_value=_make_row(model_name="qwen3.5:9b-fast")) + + resolver = AgentConfigResolver(pool, ttl_seconds=60) + config = await resolver.resolve("document-extractor") + + assert config is not None + assert config.model_name == "qwen3.5:9b-fast" + assert config.agent_id == "agent-1" + assert config.variant_id is None + + +@pytest.mark.asyncio +async def test_resolve_returns_variant_when_active(): + """Resolver returns variant config when an active variant exists.""" + pool = AsyncMock() + pool.fetchrow = AsyncMock(return_value=_make_row( + variant_id="variant-1", + model_name="llama3.1:8b", + system_prompt="variant prompt", + )) + + resolver = AgentConfigResolver(pool, ttl_seconds=60) + config = await resolver.resolve("document-extractor") + + assert config is not None + assert config.variant_id == "variant-1" + assert config.model_name == "llama3.1:8b" + assert config.system_prompt == "variant prompt" + + +@pytest.mark.asyncio +async def test_resolve_returns_none_when_not_found(): + """Resolver returns None when no agent matches the slug.""" + pool = AsyncMock() + pool.fetchrow = AsyncMock(return_value=None) + + resolver = AgentConfigResolver(pool, ttl_seconds=60) + config = await resolver.resolve("nonexistent-agent") + + assert config is None + + +@pytest.mark.asyncio +async def test_resolve_returns_none_on_db_error(): + """Resolver returns None and doesn't crash when DB query fails.""" + pool = AsyncMock() + pool.fetchrow = AsyncMock(side_effect=Exception("connection refused")) + + resolver = AgentConfigResolver(pool, ttl_seconds=60) + config = await resolver.resolve("document-extractor") + + assert config is None + + +@pytest.mark.asyncio +async def test_resolve_uses_cache_within_ttl(): + """Resolver caches results and doesn't re-query within TTL.""" + pool = AsyncMock() + pool.fetchrow = AsyncMock(return_value=_make_row(model_name="cached-model")) + + resolver = AgentConfigResolver(pool, ttl_seconds=300) + + config1 = await resolver.resolve("document-extractor") + config2 = await resolver.resolve("document-extractor") + + assert config1 is not None + assert config2 is not None + assert config1.model_name == "cached-model" + assert config2.model_name == "cached-model" + # Should only query DB once + assert pool.fetchrow.call_count == 1 + + +@pytest.mark.asyncio +async def test_resolve_refreshes_after_ttl(): + """Resolver re-queries DB after TTL expires.""" + pool = AsyncMock() + pool.fetchrow = AsyncMock(return_value=_make_row(model_name="model-v1")) + + resolver = AgentConfigResolver(pool, ttl_seconds=0) # 0 TTL = always expired + + config1 = await resolver.resolve("document-extractor") + pool.fetchrow = AsyncMock(return_value=_make_row(model_name="model-v2")) + config2 = await resolver.resolve("document-extractor") + + assert config1 is not None + assert config1.model_name == "model-v1" + assert config2 is not None + assert config2.model_name == "model-v2" + + +@pytest.mark.asyncio +async def test_invalidate_clears_cache(): + """invalidate() forces the next resolve to re-query.""" + pool = AsyncMock() + pool.fetchrow = AsyncMock(return_value=_make_row(model_name="original")) + + resolver = AgentConfigResolver(pool, ttl_seconds=300) + + config1 = await resolver.resolve("document-extractor") + assert config1 is not None + + resolver.invalidate("document-extractor") + pool.fetchrow = AsyncMock(return_value=_make_row(model_name="updated")) + + config2 = await resolver.resolve("document-extractor") + assert config2 is not None + assert config2.model_name == "updated" + + +@pytest.mark.asyncio +async def test_invalidate_all_clears_entire_cache(): + """invalidate(None) clears all cached entries.""" + pool = AsyncMock() + pool.fetchrow = AsyncMock(return_value=_make_row()) + + resolver = AgentConfigResolver(pool, ttl_seconds=300) + + await resolver.resolve("agent-a") + await resolver.resolve("agent-b") + assert len(resolver._cache) == 2 + + resolver.invalidate() + assert len(resolver._cache) == 0