feat: wire all 3 agents to DB config resolver
- Recommendation worker now resolves the thesis-rewriter config from the DB and passes ollama_config to generate_recommendation. Thesis rewriting is active whenever the thesis-rewriter agent exists in ai_agents. The config is refreshed every 50 jobs.
- Event classifier now resolves its own config via the 'event-classifier' slug, separately from the document extractor. It uses a separate OllamaClient when its model differs from the extractor's. Its config is refreshed alongside the extractor's every 100 jobs.
- Document extractor was already wired (existing code).
- Added 8 unit tests for AgentConfigResolver covering DB resolution, variant override, not-found, DB errors, TTL caching, cache refresh, and invalidation.
This commit is contained in:
+224
-5
@@ -18,7 +18,8 @@ from services.aggregation.interpolation import (
|
||||
from services.extractor.client import OllamaClient
|
||||
from services.extractor.event_classifier import classify_global_event
|
||||
from services.extractor.worker import persist_extraction
|
||||
from services.shared.config import load_config
|
||||
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
||||
from services.shared.config import OllamaConfig, load_config
|
||||
from services.shared.logging import inject_trace_context, setup_logging
|
||||
from services.shared.redis_keys import (
|
||||
QUEUE_AGGREGATION,
|
||||
@@ -30,6 +31,91 @@ from services.shared.redis_keys import (
|
||||
logger = logging.getLogger("extractor_main")
|
||||
|
||||
|
||||
def _build_ollama_config_from_resolved(
    resolved: ResolvedAgentConfig,
    base_config: OllamaConfig,
) -> OllamaConfig:
    """Combine a resolved agent config with the base Ollama config.

    The model name, timeout, retry count, max tokens and context window come
    from ``resolved``. The base URL, the retry delay/backoff fields, and the
    ``stall_timeout`` / ``loop_window`` / ``loop_threshold`` fields are copied
    unchanged from ``base_config``.
    """
    # Settings carried over from the base config.
    inherited = {
        "base_url": base_config.base_url,
        "retry_base_delay": base_config.retry_base_delay,
        "retry_max_delay": base_config.retry_max_delay,
        "retry_backoff_multiplier": base_config.retry_backoff_multiplier,
        "stall_timeout": base_config.stall_timeout,
        "loop_window": base_config.loop_window,
        "loop_threshold": base_config.loop_threshold,
    }
    # Settings supplied by the resolved agent config.
    per_agent = {
        "model": resolved.model_name,
        "timeout": resolved.timeout_seconds,
        "max_retries": resolved.max_retries,
        "max_tokens": resolved.max_tokens,
        "context_window": resolved.context_window,
    }
    return OllamaConfig(**inherited, **per_agent)
|
||||
|
||||
|
||||
async def _check_token_budget(
    pool: asyncpg.Pool,
    variant_id: str,
    token_budget: int,
) -> bool:
    """Report whether a variant has used up its token budget for the last hour.

    Sums ``input_tokens + output_tokens`` over the variant's
    ``agent_performance_log`` rows recorded within the past hour.

    Returns:
        True when usage is at or above ``token_budget`` (the caller should
        skip the invocation); False otherwise.
    """
    usage_row = await pool.fetchrow(
        """SELECT COALESCE(SUM(input_tokens + output_tokens), 0) AS total_tokens
           FROM agent_performance_log
           WHERE variant_id = $1
             AND recorded_at >= NOW() - INTERVAL '1 hour'""",
        variant_id,
    )

    tokens_used = 0
    if usage_row:
        tokens_used = int(usage_row["total_tokens"])

    # Still under budget: nothing to report.
    if tokens_used < token_budget:
        return False

    logger.warning(
        "Token budget exceeded for variant %s: used %d / budget %d — skipping invocation",
        variant_id, tokens_used, token_budget,
    )
    return True
|
||||
|
||||
|
||||
async def _log_agent_performance(
    pool: asyncpg.Pool,
    *,
    agent_id: str,
    variant_id: str | None = None,
    document_id: str = "",
    ticker: str = "",
    success: bool = False,
    duration_ms: int = 0,
    confidence: float = 0.0,
    retry_count: int = 0,
    input_tokens: int = 0,
    output_tokens: int = 0,
    error_message: str | None = None,
) -> None:
    """Record one agent invocation in ``agent_performance_log``.

    ``variant_id`` is optional, and an empty ``document_id`` is stored as
    NULL. Any exception raised by the insert is logged as a warning and
    swallowed, so a logging failure never propagates to the caller.
    """
    # Positional parameters, in the same order as the $1..$11 placeholders.
    params = (
        agent_id,
        variant_id,
        document_id or None,
        ticker,
        success,
        duration_ms,
        confidence,
        retry_count,
        input_tokens,
        output_tokens,
        error_message,
    )
    try:
        await pool.execute(
            """INSERT INTO agent_performance_log
               (agent_id, variant_id, document_id, ticker, success, duration_ms,
                confidence, retry_count, input_tokens, output_tokens, error_message)
               VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8, $9, $10, $11)""",
            *params,
        )
    except Exception:
        logger.warning("Failed to log agent performance", exc_info=True)
|
||||
|
||||
|
||||
async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]:
|
||||
"""Build a ticker -> company_id mapping from the companies table."""
|
||||
rows = await pool.fetch("SELECT id, ticker FROM companies WHERE active = TRUE")
|
||||
@@ -239,7 +325,53 @@ async def main() -> None:
|
||||
secret_key=config.minio.secret_key,
|
||||
secure=config.minio.secure,
|
||||
)
|
||||
ollama = OllamaClient(config.ollama)
|
||||
|
||||
# Resolve extractor config from DB (active variant override + TTL cache)
|
||||
resolver = AgentConfigResolver(pool, ttl_seconds=60)
|
||||
resolved_config: ResolvedAgentConfig | None = None
|
||||
extractor_ollama_config = config.ollama
|
||||
try:
|
||||
resolved_config = await resolver.resolve("document-extractor")
|
||||
if resolved_config is not None:
|
||||
extractor_ollama_config = _build_ollama_config_from_resolved(
|
||||
resolved_config, config.ollama,
|
||||
)
|
||||
logger.info(
|
||||
"Extractor using resolved config: model=%s variant=%s",
|
||||
resolved_config.model_name, resolved_config.variant_id,
|
||||
)
|
||||
else:
|
||||
logger.info("No DB config for document-extractor — using env defaults")
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True)
|
||||
|
||||
ollama = OllamaClient(extractor_ollama_config)
|
||||
|
||||
# Resolve event classifier config separately (may use different model)
|
||||
classifier_resolved: ResolvedAgentConfig | None = None
|
||||
classifier_ollama_config = config.ollama
|
||||
try:
|
||||
classifier_resolved = await resolver.resolve("event-classifier")
|
||||
if classifier_resolved is not None:
|
||||
classifier_ollama_config = _build_ollama_config_from_resolved(
|
||||
classifier_resolved, config.ollama,
|
||||
)
|
||||
logger.info(
|
||||
"Event classifier using resolved config: model=%s variant=%s",
|
||||
classifier_resolved.model_name, classifier_resolved.variant_id,
|
||||
)
|
||||
else:
|
||||
logger.info("No DB config for event-classifier — using extractor config")
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True)
|
||||
|
||||
# Use a separate OllamaClient for the classifier if it has a different model
|
||||
classifier_ollama: OllamaClient
|
||||
if classifier_ollama_config.model != extractor_ollama_config.model:
|
||||
classifier_ollama = OllamaClient(classifier_ollama_config)
|
||||
else:
|
||||
classifier_ollama = ollama
|
||||
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_EXTRACTION)
|
||||
macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
|
||||
@@ -307,6 +439,44 @@ async def main() -> None:
|
||||
refresh_counter += 1
|
||||
if refresh_counter % 100 == 0:
|
||||
company_id_map = await _build_company_id_map(pool)
|
||||
# Re-resolve extractor config (picks up active variant swaps)
|
||||
try:
|
||||
resolved_config = await resolver.resolve("document-extractor")
|
||||
if resolved_config is not None:
|
||||
new_ollama_cfg = _build_ollama_config_from_resolved(
|
||||
resolved_config, config.ollama,
|
||||
)
|
||||
if new_ollama_cfg.model != ollama._config.model:
|
||||
logger.info(
|
||||
"Extractor config changed: model=%s variant=%s",
|
||||
resolved_config.model_name, resolved_config.variant_id,
|
||||
)
|
||||
await ollama.close()
|
||||
ollama = OllamaClient(new_ollama_cfg)
|
||||
else:
|
||||
ollama._config = new_ollama_cfg
|
||||
except Exception:
|
||||
logger.warning("Failed to refresh extractor config", exc_info=True)
|
||||
|
||||
# Re-resolve event classifier config
|
||||
try:
|
||||
classifier_resolved = await resolver.resolve("event-classifier")
|
||||
if classifier_resolved is not None:
|
||||
new_cls_cfg = _build_ollama_config_from_resolved(
|
||||
classifier_resolved, config.ollama,
|
||||
)
|
||||
if new_cls_cfg.model != classifier_ollama._config.model:
|
||||
logger.info(
|
||||
"Event classifier config changed: model=%s variant=%s",
|
||||
classifier_resolved.model_name, classifier_resolved.variant_id,
|
||||
)
|
||||
if classifier_ollama is not ollama:
|
||||
await classifier_ollama.close()
|
||||
classifier_ollama = OllamaClient(new_cls_cfg)
|
||||
elif classifier_ollama is ollama and new_cls_cfg.model != ollama._config.model:
|
||||
classifier_ollama = OllamaClient(new_cls_cfg)
|
||||
except Exception:
|
||||
logger.warning("Failed to refresh event-classifier config", exc_info=True)
|
||||
|
||||
# Route macro_event documents to event classification (Requirement 2.1)
|
||||
doc_type = None
|
||||
@@ -320,7 +490,7 @@ async def main() -> None:
|
||||
await _process_macro_classification(
|
||||
pool=pool,
|
||||
minio_client=minio_client,
|
||||
ollama=ollama,
|
||||
ollama=classifier_ollama,
|
||||
redis_client=redis_client,
|
||||
document_id=document_id,
|
||||
text=text,
|
||||
@@ -333,10 +503,34 @@ async def main() -> None:
|
||||
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
|
||||
|
||||
try:
|
||||
# Token budget enforcement (Requirement 10.6)
|
||||
if (
|
||||
resolved_config is not None
|
||||
and resolved_config.token_budget > 0
|
||||
and resolved_config.variant_id is not None
|
||||
):
|
||||
budget_exceeded = await _check_token_budget(
|
||||
pool, resolved_config.variant_id, resolved_config.token_budget,
|
||||
)
|
||||
if budget_exceeded:
|
||||
continue
|
||||
|
||||
# Input token limit truncation (Requirement 10.5)
|
||||
extraction_text = text
|
||||
if resolved_config is not None and resolved_config.input_token_limit > 0:
|
||||
# Rough estimate: ~4 chars per token
|
||||
max_chars = resolved_config.input_token_limit * 4
|
||||
if len(extraction_text) > max_chars:
|
||||
extraction_text = extraction_text[:max_chars]
|
||||
logger.info(
|
||||
"Truncated input for doc %s from %d to %d chars (token limit %d)",
|
||||
document_id, len(text), max_chars, resolved_config.input_token_limit,
|
||||
)
|
||||
|
||||
# Pass all tracked tickers so the model can identify any mentioned companies
|
||||
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
|
||||
extraction_response = await ollama.extract(
|
||||
text,
|
||||
extraction_text,
|
||||
document_id=document_id,
|
||||
known_tickers=all_tickers,
|
||||
)
|
||||
@@ -347,9 +541,34 @@ async def main() -> None:
|
||||
ticker=ticker,
|
||||
extraction_response=extraction_response,
|
||||
company_id_map=company_id_map,
|
||||
document_text_length=len(text),
|
||||
document_text_length=len(extraction_text),
|
||||
)
|
||||
|
||||
# Log to agent_performance_log with variant attribution
|
||||
if resolved_config is not None:
|
||||
output_tokens = 0
|
||||
if extraction_response.attempts:
|
||||
final = extraction_response.attempts[-1]
|
||||
output_tokens = len(final.raw_output) // 4 if final.raw_output else 0
|
||||
await _log_agent_performance(
|
||||
pool,
|
||||
agent_id=resolved_config.agent_id,
|
||||
variant_id=resolved_config.variant_id,
|
||||
document_id=document_id,
|
||||
ticker=ticker,
|
||||
success=extraction_response.success,
|
||||
duration_ms=extraction_response.total_duration_ms,
|
||||
confidence=extraction_response.result.confidence if extraction_response.result else 0.0,
|
||||
retry_count=max(0, len(extraction_response.attempts) - 1),
|
||||
input_tokens=len(extraction_text) // 4,
|
||||
output_tokens=output_tokens,
|
||||
error_message=(
|
||||
extraction_response.attempts[-1].error
|
||||
if extraction_response.attempts and extraction_response.attempts[-1].error
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
# Enqueue aggregation job for the ticker on success
|
||||
if result.success and ticker:
|
||||
await redis_client.rpush(
|
||||
|
||||
@@ -9,7 +9,8 @@ import asyncpg
|
||||
from minio import Minio
|
||||
|
||||
from services.recommendation.worker import generate_recommendation
|
||||
from services.shared.config import load_config
|
||||
from services.shared.agent_config import AgentConfigResolver
|
||||
from services.shared.config import OllamaConfig, load_config
|
||||
from services.shared.logging import setup_logging
|
||||
from services.shared.redis_keys import QUEUE_RECOMMENDATION, queue_key
|
||||
|
||||
@@ -32,8 +33,33 @@ async def main() -> None:
|
||||
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_RECOMMENDATION)
|
||||
|
||||
# Resolve thesis rewriter config from DB
|
||||
resolver = AgentConfigResolver(pool, ttl_seconds=60)
|
||||
ollama_config: OllamaConfig | None = None
|
||||
try:
|
||||
resolved = await resolver.resolve("thesis-rewriter")
|
||||
if resolved is not None:
|
||||
ollama_config = OllamaConfig(
|
||||
base_url=config.ollama.base_url,
|
||||
model=resolved.model_name,
|
||||
timeout=resolved.timeout_seconds,
|
||||
max_retries=resolved.max_retries,
|
||||
max_tokens=resolved.max_tokens,
|
||||
)
|
||||
logger.info(
|
||||
"Thesis rewriter enabled: model=%s variant=%s",
|
||||
resolved.model_name, resolved.variant_id,
|
||||
)
|
||||
else:
|
||||
logger.info("No DB config for thesis-rewriter — thesis rewriting disabled")
|
||||
except Exception:
|
||||
logger.warning("Failed to resolve thesis-rewriter config — thesis rewriting disabled", exc_info=True)
|
||||
|
||||
logger.info("Recommendation worker started, polling %s", queue)
|
||||
|
||||
refresh_counter = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw = await redis_client.lpop(queue)
|
||||
@@ -48,10 +74,30 @@ async def main() -> None:
|
||||
|
||||
logger.info("Processing recommendation job for %s/%s", ticker, window)
|
||||
|
||||
# Refresh resolver every 50 jobs to pick up config changes
|
||||
refresh_counter += 1
|
||||
if refresh_counter % 50 == 0:
|
||||
try:
|
||||
resolved = await resolver.resolve("thesis-rewriter")
|
||||
if resolved is not None:
|
||||
new_config = OllamaConfig(
|
||||
base_url=config.ollama.base_url,
|
||||
model=resolved.model_name,
|
||||
timeout=resolved.timeout_seconds,
|
||||
max_retries=resolved.max_retries,
|
||||
max_tokens=resolved.max_tokens,
|
||||
)
|
||||
if ollama_config is None or new_config.model != ollama_config.model:
|
||||
logger.info("Thesis rewriter config updated: model=%s", resolved.model_name)
|
||||
ollama_config = new_config
|
||||
except Exception:
|
||||
logger.warning("Failed to refresh thesis-rewriter config", exc_info=True)
|
||||
|
||||
try:
|
||||
rec = await generate_recommendation(
|
||||
pool, ticker, window,
|
||||
minio_client=minio_client,
|
||||
ollama_config=ollama_config,
|
||||
)
|
||||
if rec:
|
||||
logger.info(
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Tests for AgentConfigResolver — validates config resolution logic."""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
||||
|
||||
|
||||
def _make_row(
|
||||
agent_id: str = "agent-1",
|
||||
variant_id: str | None = None,
|
||||
model_name: str = "qwen3.5:9b",
|
||||
system_prompt: str = "test prompt",
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Build a mock DB row for the resolver."""
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"variant_id": variant_id,
|
||||
"model_provider": kwargs.get("model_provider", "ollama"),
|
||||
"model_name": model_name,
|
||||
"system_prompt": system_prompt,
|
||||
"user_prompt_template": kwargs.get("user_prompt_template", ""),
|
||||
"prompt_version": kwargs.get("prompt_version", "v1"),
|
||||
"temperature": kwargs.get("temperature", 0.0),
|
||||
"max_tokens": kwargs.get("max_tokens", 32768),
|
||||
"context_window": kwargs.get("context_window", 0),
|
||||
"input_token_limit": kwargs.get("input_token_limit", 0),
|
||||
"token_budget": kwargs.get("token_budget", 0),
|
||||
"timeout_seconds": kwargs.get("timeout_seconds", 120),
|
||||
"max_retries": kwargs.get("max_retries", 2),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_returns_config_from_db():
    """A matching agent row yields a ResolvedAgentConfig built from that row."""
    db_row = _make_row(model_name="qwen3.5:9b-fast")
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(return_value=db_row)

    resolved = await AgentConfigResolver(pool, ttl_seconds=60).resolve("document-extractor")

    assert resolved is not None
    assert resolved.model_name == "qwen3.5:9b-fast"
    assert resolved.agent_id == "agent-1"
    # The default row carries no variant.
    assert resolved.variant_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_returns_variant_when_active():
    """A row carrying a variant id surfaces the variant's model and prompt."""
    variant_row = _make_row(
        variant_id="variant-1",
        model_name="llama3.1:8b",
        system_prompt="variant prompt",
    )
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(return_value=variant_row)

    resolved = await AgentConfigResolver(pool, ttl_seconds=60).resolve("document-extractor")

    assert resolved is not None
    assert resolved.variant_id == "variant-1"
    assert resolved.model_name == "llama3.1:8b"
    assert resolved.system_prompt == "variant prompt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_returns_none_when_not_found():
    """An unknown slug (the query returns no row) resolves to None."""
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(return_value=None)

    resolved = await AgentConfigResolver(pool, ttl_seconds=60).resolve("nonexistent-agent")

    assert resolved is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_returns_none_on_db_error():
    """A failing DB query is absorbed: resolve() returns None instead of raising."""
    db_failure = Exception("connection refused")
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(side_effect=db_failure)

    resolved = await AgentConfigResolver(pool, ttl_seconds=60).resolve("document-extractor")

    assert resolved is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_uses_cache_within_ttl():
    """Two resolves inside the TTL window hit the DB only once."""
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(return_value=_make_row(model_name="cached-model"))
    resolver = AgentConfigResolver(pool, ttl_seconds=300)

    first = await resolver.resolve("document-extractor")
    second = await resolver.resolve("document-extractor")

    for resolved in (first, second):
        assert resolved is not None
        assert resolved.model_name == "cached-model"
    # The second resolve must have been served from the cache.
    assert pool.fetchrow.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_refreshes_after_ttl():
    """With a zero TTL every resolve goes back to the DB for fresh data."""
    pool = AsyncMock()
    # A TTL of 0 means a cached entry is always considered expired.
    resolver = AgentConfigResolver(pool, ttl_seconds=0)

    pool.fetchrow = AsyncMock(return_value=_make_row(model_name="model-v1"))
    before = await resolver.resolve("document-extractor")

    # Swap the row the DB returns; the next resolve should pick it up.
    pool.fetchrow = AsyncMock(return_value=_make_row(model_name="model-v2"))
    after = await resolver.resolve("document-extractor")

    assert before is not None
    assert before.model_name == "model-v1"
    assert after is not None
    assert after.model_name == "model-v2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalidate_clears_cache():
    """After invalidate(slug), the next resolve re-queries despite a long TTL."""
    slug = "document-extractor"
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(return_value=_make_row(model_name="original"))
    resolver = AgentConfigResolver(pool, ttl_seconds=300)

    initial = await resolver.resolve(slug)
    assert initial is not None

    # Drop the cached entry, then change what the DB would return.
    resolver.invalidate(slug)
    pool.fetchrow = AsyncMock(return_value=_make_row(model_name="updated"))

    refreshed = await resolver.resolve(slug)
    assert refreshed is not None
    assert refreshed.model_name == "updated"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalidate_all_clears_entire_cache():
    """Calling invalidate() with no argument empties the whole cache."""
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(return_value=_make_row())
    resolver = AgentConfigResolver(pool, ttl_seconds=300)

    # Populate the cache with two distinct slugs.
    for slug in ("agent-a", "agent-b"):
        await resolver.resolve(slug)
    assert len(resolver._cache) == 2

    resolver.invalidate()
    assert len(resolver._cache) == 0
|
||||
Reference in New Issue
Block a user