Files
stonks-oracle/services/shared/agent_config.py
T

161 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Agent configuration resolver with active-variant override and TTL cache.
Resolves runtime configuration for AI agents from the database, preferring
the active variant's values when one exists. All three agent services
(extractor, event classifier, thesis rewriter) share this module instead
of duplicating resolution logic.
Requirements: 4.3, 4.4, 9.1-9.5, 10.4-10.6
Design: Config Resolution Module
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
import asyncpg
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Resolved config dataclass
# ---------------------------------------------------------------------------
@dataclass(frozen=True, slots=True)
class ResolvedAgentConfig:
"""Runtime configuration resolved from DB agent + optional active variant."""
agent_id: str
variant_id: str | None
model_provider: str
model_name: str
system_prompt: str
user_prompt_template: str
prompt_version: str
temperature: float
max_tokens: int
context_window: int
input_token_limit: int
token_budget: int
timeout_seconds: int
max_retries: int
# ---------------------------------------------------------------------------
# SQL: resolve agent config, preferring active variant via COALESCE
#
# Takes one parameter ($1): the agent slug.  Only agents with
# ``active = TRUE`` match.  ``agent_variants`` is LEFT JOINed on
# ``is_active = TRUE``, so an agent without an active variant still yields a
# row, with ``variant_id`` NULL and every COALESCE falling through to its
# second argument.
#
# NOTE(review): context_window, input_token_limit and token_budget have no
# agent-level fallback; without an active variant (or when the variant leaves
# them NULL) they resolve to 0 -- confirm consumers treat 0 as "unset".
# NOTE(review): there is no ORDER BY / LIMIT, so this presumably relies on at
# most one active variant per agent -- verify a constraint enforces that.
# ---------------------------------------------------------------------------
_RESOLVE_SQL = """\
SELECT a.id AS agent_id,
v.id AS variant_id,
COALESCE(v.model_provider, a.model_provider) AS model_provider,
COALESCE(v.model_name, a.model_name) AS model_name,
COALESCE(v.system_prompt, a.system_prompt) AS system_prompt,
COALESCE(v.user_prompt_template,a.user_prompt_template) AS user_prompt_template,
COALESCE(v.prompt_version, a.prompt_version) AS prompt_version,
COALESCE(v.temperature, a.temperature) AS temperature,
COALESCE(v.max_tokens, a.max_tokens) AS max_tokens,
COALESCE(v.context_window, 0) AS context_window,
COALESCE(v.input_token_limit, 0) AS input_token_limit,
COALESCE(v.token_budget, 0) AS token_budget,
COALESCE(v.timeout_seconds, a.timeout_seconds) AS timeout_seconds,
COALESCE(v.max_retries, a.max_retries) AS max_retries
FROM ai_agents a
LEFT JOIN agent_variants v
ON v.agent_id = a.id AND v.is_active = TRUE
WHERE a.slug = $1
AND a.active = TRUE
"""
# ---------------------------------------------------------------------------
# Resolver with TTL-based in-memory cache
# ---------------------------------------------------------------------------
class AgentConfigResolver:
    """Resolve agent configuration from the DB with variant override and TTL cache.

    Successful lookups are cached in memory per agent slug for
    ``ttl_seconds``.  Failed or empty lookups are never cached, so they are
    retried on the next call.

    Usage::

        resolver = AgentConfigResolver(pool, ttl_seconds=60)
        config = await resolver.resolve("document-extractor")
        if config is None:
            # fall back to env-var OllamaConfig defaults
            ...
    """

    def __init__(self, pool: asyncpg.Pool, ttl_seconds: int = 60) -> None:
        """Create a resolver.

        Args:
            pool: Connection pool used to run the resolution query.
            ttl_seconds: How long a resolved config stays valid in the cache.
                With a value of zero or less cached entries are never
                served, so every call queries the database.
        """
        self._pool = pool
        self._ttl = ttl_seconds
        # slug -> (monotonic timestamp taken at lookup start, resolved config)
        self._cache: dict[str, tuple[float, ResolvedAgentConfig]] = {}

    async def resolve(self, agent_slug: str) -> ResolvedAgentConfig | None:
        """Resolve config for an agent slug, preferring active variant if present.

        Args:
            agent_slug: Slug of the agent to look up.

        Returns:
            The resolved configuration, or ``None`` (after logging a warning)
            when the slug matches no active agent, the database query fails,
            or the returned row cannot be converted (e.g. a required numeric
            column is NULL).  Callers should fall back to env-var defaults
            in that case.
        """
        # Monotonic clock: immune to wall-clock adjustments.  Taken before
        # the query so an entry's age is measured from when the lookup began.
        now = time.monotonic()

        cached = self._cache.get(agent_slug)
        if cached is not None:
            stored_at, cached_config = cached
            if (now - stored_at) < self._ttl:
                return cached_config
            # Expired: drop the stale entry so a failed re-query below does
            # not leave it behind.
            self._cache.pop(agent_slug, None)

        try:
            row = await self._pool.fetchrow(_RESOLVE_SQL, agent_slug)
        except Exception:
            logger.warning(
                "Failed to resolve agent config for %s from database",
                agent_slug,
                exc_info=True,
            )
            return None

        if row is None:
            logger.warning(
                "No active agent found for slug %r",
                agent_slug,
            )
            return None

        try:
            config = self._row_to_config(row)
        except (TypeError, ValueError):
            # A NULL or non-numeric value in a required column would
            # otherwise escape as an exception, contradicting the documented
            # "return None so callers can fall back" contract.
            logger.warning(
                "Invalid agent config row for slug %r",
                agent_slug,
                exc_info=True,
            )
            return None

        self._cache[agent_slug] = (now, config)
        return config

    @staticmethod
    def _row_to_config(row: asyncpg.Record) -> ResolvedAgentConfig:
        """Convert one resolution-query row into a ``ResolvedAgentConfig``.

        Raises:
            TypeError: If a numeric column is NULL.
            ValueError: If a numeric column holds a non-convertible value.
        """
        variant_id = row["variant_id"]
        return ResolvedAgentConfig(
            agent_id=str(row["agent_id"]),
            # Explicit None check: a falsy but valid id (e.g. 0) must not be
            # mistaken for "no active variant".
            variant_id=None if variant_id is None else str(variant_id),
            model_provider=row["model_provider"],
            model_name=row["model_name"],
            system_prompt=row["system_prompt"],
            user_prompt_template=row["user_prompt_template"],
            prompt_version=row["prompt_version"],
            temperature=float(row["temperature"]),
            max_tokens=int(row["max_tokens"]),
            context_window=int(row["context_window"]),
            input_token_limit=int(row["input_token_limit"]),
            token_budget=int(row["token_budget"]),
            timeout_seconds=int(row["timeout_seconds"]),
            max_retries=int(row["max_retries"]),
        )

    def invalidate(self, agent_slug: str | None = None) -> None:
        """Drop cached entries.

        Args:
            agent_slug: If given, only that slug's entry is removed (a slug
                that is not cached is ignored).  If ``None``, the entire
                cache is cleared.
        """
        if agent_slug is None:
            self._cache.clear()
        else:
            self._cache.pop(agent_slug, None)