feat: wire all 3 agents to DB config resolver

- Recommendation worker now resolves thesis-rewriter config from DB
  and passes ollama_config to generate_recommendation. Thesis rewriting
  is now active when the thesis-rewriter agent exists in ai_agents.
  Refreshes config every 50 jobs.

- Event classifier now resolves its own config separately from the
  document extractor via 'event-classifier' slug. Uses a separate
  OllamaClient when the model differs from the extractor. Refreshes
  alongside the extractor every 100 jobs.

- Document extractor was already wired (existing code).

- Added 8 unit tests for AgentConfigResolver covering: DB resolution,
  variant override, not-found, DB errors, TTL caching, cache refresh,
  and invalidation.
This commit is contained in:
Celes Renata
2026-04-17 02:59:40 +00:00
parent c501ccea40
commit 6179382d1e
3 changed files with 436 additions and 6 deletions
+224 -5
View File
@@ -18,7 +18,8 @@ from services.aggregation.interpolation import (
from services.extractor.client import OllamaClient
from services.extractor.event_classifier import classify_global_event
from services.extractor.worker import persist_extraction
from services.shared.config import load_config
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
from services.shared.config import OllamaConfig, load_config
from services.shared.logging import inject_trace_context, setup_logging
from services.shared.redis_keys import (
QUEUE_AGGREGATION,
@@ -30,6 +31,91 @@ from services.shared.redis_keys import (
logger = logging.getLogger("extractor_main")
def _build_ollama_config_from_resolved(
resolved: ResolvedAgentConfig,
base_config: OllamaConfig,
) -> OllamaConfig:
"""Build an OllamaConfig from a ResolvedAgentConfig, preserving base retry settings."""
return OllamaConfig(
base_url=base_config.base_url,
model=resolved.model_name,
timeout=resolved.timeout_seconds,
max_retries=resolved.max_retries,
retry_base_delay=base_config.retry_base_delay,
retry_max_delay=base_config.retry_max_delay,
retry_backoff_multiplier=base_config.retry_backoff_multiplier,
max_tokens=resolved.max_tokens,
stall_timeout=base_config.stall_timeout,
loop_window=base_config.loop_window,
loop_threshold=base_config.loop_threshold,
context_window=resolved.context_window,
)
async def _check_token_budget(
pool: asyncpg.Pool,
variant_id: str,
token_budget: int,
) -> bool:
"""Check if a variant has exceeded its hourly token budget.
Returns True if the budget is exceeded and invocation should be skipped.
"""
row = await pool.fetchrow(
"""SELECT COALESCE(SUM(input_tokens + output_tokens), 0) AS total_tokens
FROM agent_performance_log
WHERE variant_id = $1
AND recorded_at >= NOW() - INTERVAL '1 hour'""",
variant_id,
)
used = int(row["total_tokens"]) if row else 0
if used >= token_budget:
logger.warning(
"Token budget exceeded for variant %s: used %d / budget %d — skipping invocation",
variant_id, used, token_budget,
)
return True
return False
async def _log_agent_performance(
pool: asyncpg.Pool,
*,
agent_id: str,
variant_id: str | None = None,
document_id: str = "",
ticker: str = "",
success: bool = False,
duration_ms: int = 0,
confidence: float = 0.0,
retry_count: int = 0,
input_tokens: int = 0,
output_tokens: int = 0,
error_message: str | None = None,
) -> None:
"""Insert a row into agent_performance_log with optional variant attribution."""
try:
await pool.execute(
"""INSERT INTO agent_performance_log
(agent_id, variant_id, document_id, ticker, success, duration_ms,
confidence, retry_count, input_tokens, output_tokens, error_message)
VALUES ($1::uuid, $2::uuid, $3::uuid, $4, $5, $6, $7, $8, $9, $10, $11)""",
agent_id,
variant_id,
document_id if document_id else None,
ticker,
success,
duration_ms,
confidence,
retry_count,
input_tokens,
output_tokens,
error_message,
)
except Exception:
logger.warning("Failed to log agent performance", exc_info=True)
async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]:
"""Build a ticker -> company_id mapping from the companies table."""
rows = await pool.fetch("SELECT id, ticker FROM companies WHERE active = TRUE")
@@ -239,7 +325,53 @@ async def main() -> None:
secret_key=config.minio.secret_key,
secure=config.minio.secure,
)
ollama = OllamaClient(config.ollama)
# Resolve extractor config from DB (active variant override + TTL cache)
resolver = AgentConfigResolver(pool, ttl_seconds=60)
resolved_config: ResolvedAgentConfig | None = None
extractor_ollama_config = config.ollama
try:
resolved_config = await resolver.resolve("document-extractor")
if resolved_config is not None:
extractor_ollama_config = _build_ollama_config_from_resolved(
resolved_config, config.ollama,
)
logger.info(
"Extractor using resolved config: model=%s variant=%s",
resolved_config.model_name, resolved_config.variant_id,
)
else:
logger.info("No DB config for document-extractor — using env defaults")
except Exception:
logger.warning("Failed to resolve extractor config — using env defaults", exc_info=True)
ollama = OllamaClient(extractor_ollama_config)
# Resolve event classifier config separately (may use different model)
classifier_resolved: ResolvedAgentConfig | None = None
classifier_ollama_config = config.ollama
try:
classifier_resolved = await resolver.resolve("event-classifier")
if classifier_resolved is not None:
classifier_ollama_config = _build_ollama_config_from_resolved(
classifier_resolved, config.ollama,
)
logger.info(
"Event classifier using resolved config: model=%s variant=%s",
classifier_resolved.model_name, classifier_resolved.variant_id,
)
else:
logger.info("No DB config for event-classifier — using extractor config")
except Exception:
logger.warning("Failed to resolve event-classifier config — using extractor config", exc_info=True)
# Use a separate OllamaClient for the classifier if it has a different model
classifier_ollama: OllamaClient
if classifier_ollama_config.model != extractor_ollama_config.model:
classifier_ollama = OllamaClient(classifier_ollama_config)
else:
classifier_ollama = ollama
redis_client = aioredis.from_url(config.redis.url)
queue = queue_key(QUEUE_EXTRACTION)
macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
@@ -307,6 +439,44 @@ async def main() -> None:
refresh_counter += 1
if refresh_counter % 100 == 0:
company_id_map = await _build_company_id_map(pool)
# Re-resolve extractor config (picks up active variant swaps)
try:
resolved_config = await resolver.resolve("document-extractor")
if resolved_config is not None:
new_ollama_cfg = _build_ollama_config_from_resolved(
resolved_config, config.ollama,
)
if new_ollama_cfg.model != ollama._config.model:
logger.info(
"Extractor config changed: model=%s variant=%s",
resolved_config.model_name, resolved_config.variant_id,
)
await ollama.close()
ollama = OllamaClient(new_ollama_cfg)
else:
ollama._config = new_ollama_cfg
except Exception:
logger.warning("Failed to refresh extractor config", exc_info=True)
# Re-resolve event classifier config
try:
classifier_resolved = await resolver.resolve("event-classifier")
if classifier_resolved is not None:
new_cls_cfg = _build_ollama_config_from_resolved(
classifier_resolved, config.ollama,
)
if new_cls_cfg.model != classifier_ollama._config.model:
logger.info(
"Event classifier config changed: model=%s variant=%s",
classifier_resolved.model_name, classifier_resolved.variant_id,
)
if classifier_ollama is not ollama:
await classifier_ollama.close()
classifier_ollama = OllamaClient(new_cls_cfg)
elif classifier_ollama is ollama and new_cls_cfg.model != ollama._config.model:
classifier_ollama = OllamaClient(new_cls_cfg)
except Exception:
logger.warning("Failed to refresh event-classifier config", exc_info=True)
# Route macro_event documents to event classification (Requirement 2.1)
doc_type = None
@@ -320,7 +490,7 @@ async def main() -> None:
await _process_macro_classification(
pool=pool,
minio_client=minio_client,
ollama=ollama,
ollama=classifier_ollama,
redis_client=redis_client,
document_id=document_id,
text=text,
@@ -333,10 +503,34 @@ async def main() -> None:
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
try:
# Token budget enforcement (Requirement 10.6)
if (
resolved_config is not None
and resolved_config.token_budget > 0
and resolved_config.variant_id is not None
):
budget_exceeded = await _check_token_budget(
pool, resolved_config.variant_id, resolved_config.token_budget,
)
if budget_exceeded:
continue
# Input token limit truncation (Requirement 10.5)
extraction_text = text
if resolved_config is not None and resolved_config.input_token_limit > 0:
# Rough estimate: ~4 chars per token
max_chars = resolved_config.input_token_limit * 4
if len(extraction_text) > max_chars:
extraction_text = extraction_text[:max_chars]
logger.info(
"Truncated input for doc %s from %d to %d chars (token limit %d)",
document_id, len(text), max_chars, resolved_config.input_token_limit,
)
# Pass all tracked tickers so the model can identify any mentioned companies
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
extraction_response = await ollama.extract(
text,
extraction_text,
document_id=document_id,
known_tickers=all_tickers,
)
@@ -347,9 +541,34 @@ async def main() -> None:
ticker=ticker,
extraction_response=extraction_response,
company_id_map=company_id_map,
document_text_length=len(text),
document_text_length=len(extraction_text),
)
# Log to agent_performance_log with variant attribution
if resolved_config is not None:
output_tokens = 0
if extraction_response.attempts:
final = extraction_response.attempts[-1]
output_tokens = len(final.raw_output) // 4 if final.raw_output else 0
await _log_agent_performance(
pool,
agent_id=resolved_config.agent_id,
variant_id=resolved_config.variant_id,
document_id=document_id,
ticker=ticker,
success=extraction_response.success,
duration_ms=extraction_response.total_duration_ms,
confidence=extraction_response.result.confidence if extraction_response.result else 0.0,
retry_count=max(0, len(extraction_response.attempts) - 1),
input_tokens=len(extraction_text) // 4,
output_tokens=output_tokens,
error_message=(
extraction_response.attempts[-1].error
if extraction_response.attempts and extraction_response.attempts[-1].error
else None
),
)
# Enqueue aggregation job for the ticker on success
if result.success and ticker:
await redis_client.rpush(