6179382d1e
- Recommendation worker now resolves the thesis-rewriter config from the DB and passes ollama_config to generate_recommendation. Thesis rewriting is now active when the thesis-rewriter agent exists in ai_agents. The config is refreshed every 50 jobs.
- Event classifier now resolves its own config, separately from the document extractor, via the 'event-classifier' slug. It uses a separate OllamaClient when its model differs from the extractor's, and refreshes alongside the extractor every 100 jobs.
- Document extractor was already wired (existing code).
- Added 8 unit tests for AgentConfigResolver covering: DB resolution, variant override, not-found, DB errors, TTL caching, cache refresh, and invalidation.
166 lines
5.4 KiB
Python
166 lines
5.4 KiB
Python
"""Tests for AgentConfigResolver — validates config resolution logic."""
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from services.shared.agent_config import AgentConfigResolver, ResolvedAgentConfig
|
|
|
|
|
|
def _make_row(
|
|
agent_id: str = "agent-1",
|
|
variant_id: str | None = None,
|
|
model_name: str = "qwen3.5:9b",
|
|
system_prompt: str = "test prompt",
|
|
**kwargs,
|
|
) -> dict:
|
|
"""Build a mock DB row for the resolver."""
|
|
return {
|
|
"agent_id": agent_id,
|
|
"variant_id": variant_id,
|
|
"model_provider": kwargs.get("model_provider", "ollama"),
|
|
"model_name": model_name,
|
|
"system_prompt": system_prompt,
|
|
"user_prompt_template": kwargs.get("user_prompt_template", ""),
|
|
"prompt_version": kwargs.get("prompt_version", "v1"),
|
|
"temperature": kwargs.get("temperature", 0.0),
|
|
"max_tokens": kwargs.get("max_tokens", 32768),
|
|
"context_window": kwargs.get("context_window", 0),
|
|
"input_token_limit": kwargs.get("input_token_limit", 0),
|
|
"token_budget": kwargs.get("token_budget", 0),
|
|
"timeout_seconds": kwargs.get("timeout_seconds", 120),
|
|
"max_retries": kwargs.get("max_retries", 2),
|
|
}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_returns_config_from_db():
    """Resolver returns a ResolvedAgentConfig when the DB has a matching agent."""
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(return_value=_make_row(model_name="qwen3.5:9b-fast"))

    resolver = AgentConfigResolver(pool, ttl_seconds=60)
    config = await resolver.resolve("document-extractor")

    # The docstring promises a ResolvedAgentConfig, so pin the type as well as
    # the field values copied over from the row (isinstance also rules out None).
    assert isinstance(config, ResolvedAgentConfig)
    assert config.model_name == "qwen3.5:9b-fast"
    assert config.agent_id == "agent-1"
    assert config.variant_id is None
    # A single cold resolve should hit the DB exactly once.
    pool.fetchrow.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_returns_variant_when_active():
    """A row carrying an active variant yields that variant's id, model and prompt."""
    variant_row = _make_row(
        variant_id="variant-1",
        model_name="llama3.1:8b",
        system_prompt="variant prompt",
    )
    db = AsyncMock()
    db.fetchrow = AsyncMock(return_value=variant_row)

    resolved = await AgentConfigResolver(db, ttl_seconds=60).resolve("document-extractor")

    assert resolved is not None
    assert resolved.variant_id == "variant-1"
    assert resolved.model_name == "llama3.1:8b"
    assert resolved.system_prompt == "variant prompt"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_returns_none_when_not_found():
    """An unknown slug (the DB returns no row) resolves to None."""
    db = AsyncMock()
    db.fetchrow = AsyncMock(return_value=None)

    resolved = await AgentConfigResolver(db, ttl_seconds=60).resolve("nonexistent-agent")

    assert resolved is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_returns_none_on_db_error():
    """A failing DB query yields None instead of propagating the exception."""
    db = AsyncMock()
    db.fetchrow = AsyncMock(side_effect=Exception("connection refused"))

    resolved = await AgentConfigResolver(db, ttl_seconds=60).resolve("document-extractor")

    assert resolved is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_uses_cache_within_ttl():
    """Two resolves of the same slug inside the TTL share a single DB query."""
    db = AsyncMock()
    db.fetchrow = AsyncMock(return_value=_make_row(model_name="cached-model"))
    resolver = AgentConfigResolver(db, ttl_seconds=300)

    first = await resolver.resolve("document-extractor")
    second = await resolver.resolve("document-extractor")

    assert first is not None
    assert second is not None
    assert first.model_name == "cached-model"
    assert second.model_name == "cached-model"
    # The second call must be served from the cache, not the DB.
    assert db.fetchrow.call_count == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_resolve_refreshes_after_ttl():
    """Once the TTL has lapsed, the resolver goes back to the DB."""
    db = AsyncMock()
    db.fetchrow = AsyncMock(return_value=_make_row(model_name="model-v1"))
    # NOTE(review): ttl_seconds=0 is intended to mean "always expired". That only
    # holds if the resolver's freshness check is strict (age < ttl); confirm
    # against AgentConfigResolver before relying on it.
    resolver = AgentConfigResolver(db, ttl_seconds=0)

    before = await resolver.resolve("document-extractor")
    # Swap in a mock that returns a newer row; seeing it proves a re-query.
    db.fetchrow = AsyncMock(return_value=_make_row(model_name="model-v2"))
    after = await resolver.resolve("document-extractor")

    assert before is not None
    assert before.model_name == "model-v1"
    assert after is not None
    assert after.model_name == "model-v2"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalidate_clears_cache():
    """After invalidate(slug), the next resolve of that slug hits the DB again."""
    db = AsyncMock()
    db.fetchrow = AsyncMock(return_value=_make_row(model_name="original"))
    resolver = AgentConfigResolver(db, ttl_seconds=300)

    warm = await resolver.resolve("document-extractor")
    assert warm is not None

    resolver.invalidate("document-extractor")
    # With a 300s TTL, only a real re-query can surface this updated row.
    db.fetchrow = AsyncMock(return_value=_make_row(model_name="updated"))

    refreshed = await resolver.resolve("document-extractor")
    assert refreshed is not None
    assert refreshed.model_name == "updated"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalidate_all_clears_entire_cache():
    """invalidate() with no slug clears every cached entry."""
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(return_value=_make_row())

    resolver = AgentConfigResolver(pool, ttl_seconds=300)

    await resolver.resolve("agent-a")
    await resolver.resolve("agent-b")
    # Distinct slugs are cached independently: one DB query each.
    assert len(resolver._cache) == 2
    assert pool.fetchrow.call_count == 2

    resolver.invalidate()
    assert len(resolver._cache) == 0

    # Observable effect, not just internal state: with a 300s TTL, a resolve
    # after the wipe must go back to the DB.
    await resolver.resolve("agent-a")
    assert pool.fetchrow.call_count == 3
|