"""Agent configuration resolver with active-variant override and TTL cache.

Resolves runtime configuration for AI agents from the database, preferring
the active variant's values when one exists. All three agent services
(extractor, event classifier, thesis rewriter) share this module instead
of duplicating resolution logic.

Requirements: 4.3, 4.4, 9.1–9.5, 10.4–10.6
Design: Config Resolution Module
"""
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import time
|
||
from dataclasses import dataclass
|
||
|
||
import asyncpg
|
||
|
||
# Module-level logger; the resolver reports failed or empty lookups through it.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Resolved config dataclass
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@dataclass(frozen=True, slots=True)
|
||
class ResolvedAgentConfig:
|
||
"""Runtime configuration resolved from DB agent + optional active variant."""
|
||
|
||
agent_id: str
|
||
variant_id: str | None
|
||
model_provider: str
|
||
model_name: str
|
||
system_prompt: str
|
||
user_prompt_template: str
|
||
prompt_version: str
|
||
temperature: float
|
||
max_tokens: int
|
||
context_window: int
|
||
input_token_limit: int
|
||
token_budget: int
|
||
timeout_seconds: int
|
||
max_retries: int
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# SQL: resolve agent config, preferring active variant via COALESCE
# ---------------------------------------------------------------------------
|
||
|
||
# One row per (active agent, active variant) match for the slug bound to $1.
# Most columns prefer the variant's value and fall back to the agent's.
# context_window, input_token_limit and token_budget have no agent-level
# fallback in this query: they come from the variant or default to 0.
# The LEFT JOIN keeps the agent row when it has no active variant, in which
# case variant_id is NULL.
_RESOLVE_SQL = """\
SELECT a.id AS agent_id,
       v.id AS variant_id,
       COALESCE(v.model_provider, a.model_provider) AS model_provider,
       COALESCE(v.model_name, a.model_name) AS model_name,
       COALESCE(v.system_prompt, a.system_prompt) AS system_prompt,
       COALESCE(v.user_prompt_template,a.user_prompt_template) AS user_prompt_template,
       COALESCE(v.prompt_version, a.prompt_version) AS prompt_version,
       COALESCE(v.temperature, a.temperature) AS temperature,
       COALESCE(v.max_tokens, a.max_tokens) AS max_tokens,
       COALESCE(v.context_window, 0) AS context_window,
       COALESCE(v.input_token_limit, 0) AS input_token_limit,
       COALESCE(v.token_budget, 0) AS token_budget,
       COALESCE(v.timeout_seconds, a.timeout_seconds) AS timeout_seconds,
       COALESCE(v.max_retries, a.max_retries) AS max_retries
  FROM ai_agents a
  LEFT JOIN agent_variants v
    ON v.agent_id = a.id AND v.is_active = TRUE
 WHERE a.slug = $1
   AND a.active = TRUE
"""
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Resolver with TTL-based in-memory cache
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class AgentConfigResolver:
    """Resolves agent configuration from DB with active variant override and TTL cache.

    Successful lookups are cached per slug for ``ttl_seconds``; failed
    lookups are never cached, so the next call queries the database again.

    Usage::

        resolver = AgentConfigResolver(pool, ttl_seconds=60)
        config = await resolver.resolve("document-extractor")
        if config is None:
            # fall back to env-var OllamaConfig defaults
            ...
    """

    def __init__(self, pool: asyncpg.Pool, ttl_seconds: int = 60) -> None:
        """Store the connection pool and cache lifetime.

        Args:
            pool: Pool whose ``fetchrow`` runs the resolve query.
            ttl_seconds: How long a resolved config is served from cache.
        """
        self._pool = pool
        self._ttl = ttl_seconds
        # slug -> (monotonic timestamp taken when the lookup started, config)
        self._cache: dict[str, tuple[float, ResolvedAgentConfig]] = {}

    async def resolve(self, agent_slug: str) -> ResolvedAgentConfig | None:
        """Resolve config for an agent slug, preferring active variant if present.

        Returns ``None`` and logs a warning when the agent slug is not found,
        the database query fails, or the returned row cannot be converted
        (for example a NULL in a numeric column). Callers should fall back
        to env-var defaults in that case.
        """
        # Taken before the query so a cached entry is never older than the
        # TTL measured from the moment its lookup began.
        now = time.monotonic()

        cached = self._cache.get(agent_slug)
        if cached is not None:
            stored_at, config = cached
            if (now - stored_at) < self._ttl:
                return config
            # Expired — remove stale entry before re-querying
            del self._cache[agent_slug]

        # Best-effort boundary: any failure here means "use defaults".
        try:
            row = await self._pool.fetchrow(_RESOLVE_SQL, agent_slug)
        except Exception:
            logger.warning(
                "Failed to resolve agent config for %s from database",
                agent_slug,
                exc_info=True,
            )
            return None

        if row is None:
            logger.warning(
                "No active agent found for slug %r",
                agent_slug,
            )
            return None

        # A row with NULL or non-numeric values in numeric columns must not
        # crash the caller; treat it like any other failed resolution.
        try:
            config = self._row_to_config(row)
        except (KeyError, TypeError, ValueError):
            logger.warning(
                "Malformed agent config row for slug %r",
                agent_slug,
                exc_info=True,
            )
            return None

        self._cache[agent_slug] = (now, config)
        return config

    @staticmethod
    def _row_to_config(row: asyncpg.Record) -> ResolvedAgentConfig:
        """Convert one resolve-query row into a ``ResolvedAgentConfig``."""
        variant_id = row["variant_id"]
        return ResolvedAgentConfig(
            agent_id=str(row["agent_id"]),
            # ``is not None`` rather than truthiness so a falsy but valid
            # id (e.g. integer 0) is still reported as a variant.
            variant_id=str(variant_id) if variant_id is not None else None,
            model_provider=row["model_provider"],
            model_name=row["model_name"],
            system_prompt=row["system_prompt"],
            user_prompt_template=row["user_prompt_template"],
            prompt_version=row["prompt_version"],
            temperature=float(row["temperature"]),
            max_tokens=int(row["max_tokens"]),
            context_window=int(row["context_window"]),
            input_token_limit=int(row["input_token_limit"]),
            token_budget=int(row["token_budget"]),
            timeout_seconds=int(row["timeout_seconds"]),
            max_retries=int(row["max_retries"]),
        )

    def invalidate(self, agent_slug: str | None = None) -> None:
        """Drop cached entries.

        If *agent_slug* is given, only that entry is removed.
        If ``None``, the entire cache is cleared.
        """
        if agent_slug is None:
            self._cache.clear()
        else:
            self._cache.pop(agent_slug, None)
|