feat: risk tier selector on Trading page + confidence filter on Recommendations

- Trading page: added conservative/moderate/aggressive selector that
  updates the trading engine config via PUT /api/trading/config
- Recommendations page: added a risk tier dropdown that defaults to the
  engine's current tier and filters recommendations by that tier's minimum confidence (min_confidence)
- Backend: added min_confidence query param to GET /api/recommendations
- Risk tier thresholds: conservative ≥0.75, moderate ≥0.55, aggressive ≥0.40
This commit is contained in:
Celes Renata
2026-04-17 05:08:54 +00:00
parent 49e3955fab
commit 734bf001a7
4 changed files with 532 additions and 5 deletions
+2 -1
View File
@@ -299,12 +299,13 @@ export interface RecommendationDetail extends Recommendation {
risk_evaluation: { id: string; eligible: boolean; allowed_mode: string; rejection_reasons: string[] | null; risk_checks: Record<string, unknown> | null; evaluated_at: string } | null; risk_evaluation: { id: string; eligible: boolean; allowed_mode: string; rejection_reasons: string[] | null; risk_checks: Record<string, unknown> | null; evaluated_at: string } | null;
} }
export function useRecommendations(params?: { ticker?: string; action?: string; mode?: string; since?: string; limit?: number; offset?: number }) { export function useRecommendations(params?: { ticker?: string; action?: string; mode?: string; since?: string; min_confidence?: number; limit?: number; offset?: number }) {
const qs = new URLSearchParams(); const qs = new URLSearchParams();
if (params?.ticker) qs.set('ticker', params.ticker); if (params?.ticker) qs.set('ticker', params.ticker);
if (params?.action) qs.set('action', params.action); if (params?.action) qs.set('action', params.action);
if (params?.mode) qs.set('mode', params.mode); if (params?.mode) qs.set('mode', params.mode);
if (params?.since) qs.set('since', params.since); if (params?.since) qs.set('since', params.since);
if (params?.min_confidence != null) qs.set('min_confidence', String(params.min_confidence));
if (params?.limit) qs.set('limit', String(params.limit)); if (params?.limit) qs.set('limit', String(params.limit));
if (params?.offset) qs.set('offset', String(params.offset)); if (params?.offset) qs.set('offset', String(params.offset));
const path = `/api/recommendations${qs.toString() ? '?' + qs : ''}`; const path = `/api/recommendations${qs.toString() ? '?' + qs : ''}`;
+38 -2
View File
@@ -1,14 +1,33 @@
import { useState } from 'react'; import { useState } from 'react';
import { useNavigate } from '@tanstack/react-router'; import { useNavigate } from '@tanstack/react-router';
import { useRecommendations } from '../api/hooks'; import { useRecommendations } from '../api/hooks';
import { useTradingStatus } from '../api/tradingHooks';
import { DataTable, type Column } from '../components/DataTable'; import { DataTable, type Column } from '../components/DataTable';
import { StatusBadge, ConfidenceBar, LoadingSpinner, TickerFilter } from '../components/ui'; import { StatusBadge, ConfidenceBar, LoadingSpinner, TickerFilter } from '../components/ui';
import type { Recommendation } from '../api/hooks'; import type { Recommendation } from '../api/hooks';
/**
 * Minimum recommendation confidence required by each risk tier.
 *
 * Entries are listed from strictest to loosest; the Risk Tier dropdown renders
 * its options by iterating this object, so insertion order is the display order.
 * The string-keyed record type lets a tier name received from the API be looked
 * up directly (the page falls back to a default when the lookup misses).
 * `Readonly` makes accidental writes to this shared table a compile-time error.
 */
const RISK_TIER_CONFIDENCE: Readonly<Record<string, number>> = {
  conservative: 0.75,
  moderate: 0.55,
  aggressive: 0.40,
};
export function RecommendationsPage() { export function RecommendationsPage() {
const navigate = useNavigate(); const navigate = useNavigate();
const [ticker, setTicker] = useState(''); const [ticker, setTicker] = useState('');
const { data, isLoading } = useRecommendations({ ticker: ticker || undefined, limit: 100 }); const { data: tradingStatus } = useTradingStatus();
const engineTier = tradingStatus?.risk_tier ?? 'moderate';
const [riskTier, setRiskTier] = useState<string | null>(null);
// Use engine tier as default, allow override
const activeTier = riskTier ?? engineTier;
const minConfidence = RISK_TIER_CONFIDENCE[activeTier] ?? 0.55;
const { data, isLoading } = useRecommendations({
ticker: ticker || undefined,
min_confidence: minConfidence,
limit: 100,
});
const columns: Column<Recommendation>[] = [ const columns: Column<Recommendation>[] = [
{ key: 'ticker', header: 'Ticker', className: 'font-mono font-semibold text-brand-300' }, { key: 'ticker', header: 'Ticker', className: 'font-mono font-semibold text-brand-300' },
@@ -25,7 +44,24 @@ export function RecommendationsPage() {
<div> <div>
<div className="mb-4 flex items-center justify-between"> <div className="mb-4 flex items-center justify-between">
<h1 className="text-xl font-semibold text-gray-100">Recommendations</h1> <h1 className="text-xl font-semibold text-gray-100">Recommendations</h1>
<TickerFilter value={ticker} onChange={setTicker} /> <div className="flex items-center gap-3">
<div className="flex items-center gap-2">
<label htmlFor="risk-tier-filter" className="text-xs text-gray-500">Risk Tier</label>
<select
id="risk-tier-filter"
value={activeTier}
onChange={(e) => setRiskTier(e.target.value)}
className="rounded-md border border-surface-700 bg-surface-900 px-2 py-1.5 text-sm text-gray-200 focus:border-brand-500 focus:outline-none"
>
{Object.entries(RISK_TIER_CONFIDENCE).map(([tier, conf]) => (
<option key={tier} value={tier}>
{tier.charAt(0).toUpperCase() + tier.slice(1)} ({conf})
</option>
))}
</select>
</div>
<TickerFilter value={ticker} onChange={setTicker} />
</div>
</div> </div>
<DataTable<Recommendation> <DataTable<Recommendation>
data={data ?? []} data={data ?? []}
+38 -1
View File
@@ -10,7 +10,7 @@ import {
useCompetitiveStatus, useCompetitiveStatus,
useToggleCompetitive, useToggleCompetitive,
} from '../api/hooks'; } from '../api/hooks';
import { useResetPaperTrading } from '../api/tradingHooks'; import { useResetPaperTrading, useTradingStatus, useUpdateTradingConfig } from '../api/tradingHooks';
import { StatusBadge, LoadingSpinner, Card } from '../components/ui'; import { StatusBadge, LoadingSpinner, Card } from '../components/ui';
export function TradingPage() { export function TradingPage() {
@@ -21,6 +21,8 @@ export function TradingPage() {
const { data: competitiveStatus } = useCompetitiveStatus(); const { data: competitiveStatus } = useCompetitiveStatus();
const setMode = useSetTradingMode(); const setMode = useSetTradingMode();
const resetTrading = useResetPaperTrading(); const resetTrading = useResetPaperTrading();
const { data: tradingStatus } = useTradingStatus();
const updateConfig = useUpdateTradingConfig();
const reviewApproval = useReviewApproval(); const reviewApproval = useReviewApproval();
const toggleMacro = useToggleMacro(); const toggleMacro = useToggleMacro();
const toggleCompetitive = useToggleCompetitive(); const toggleCompetitive = useToggleCompetitive();
@@ -95,6 +97,41 @@ export function TradingPage() {
isResetting={resetTrading.isPending} isResetting={resetTrading.isPending}
/> />
{/* Risk Tier */}
<Card>
<h2 className="mb-3 text-sm font-medium text-gray-400">Risk Tier</h2>
<p className="mb-3 text-[10px] text-gray-600">
Controls confidence gates, position sizing, and portfolio heat limits for the trading engine.
</p>
<div className="flex items-center gap-3">
{(['conservative', 'moderate', 'aggressive'] as const).map((tier) => {
const currentTier = tradingStatus?.risk_tier ?? 'moderate';
const descriptions: Record<string, string> = {
conservative: 'Min confidence 0.75, max 5% position, 10% heat',
moderate: 'Min confidence 0.55, max 10% position, 20% heat',
aggressive: 'Min confidence 0.40, max 15% position, 30% heat',
};
return (
<button
key={tier}
onClick={() => {
if (tier !== currentTier) updateConfig.mutate({ risk_tier: tier });
}}
className={`rounded-md px-4 py-2 text-sm font-medium capitalize transition-colors ${
currentTier === tier
? 'bg-brand-600 text-white'
: 'border border-surface-700 bg-surface-900 text-gray-400 hover:bg-surface-800'
}`}
aria-pressed={currentTier === tier}
title={descriptions[tier]}
>
{tier}
</button>
);
})}
</div>
</Card>
{/* Macro Signal Layer Toggle */} {/* Macro Signal Layer Toggle */}
<Card> <Card>
<h2 className="mb-3 text-sm font-medium text-gray-400">Macro Signal Layer</h2> <h2 className="mb-3 text-sm font-medium text-gray-400">Macro Signal Layer</h2>
+454 -1
View File
@@ -15,6 +15,7 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging import logging
import re
import time as _time import time as _time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import asdict from dataclasses import asdict
@@ -548,6 +549,7 @@ async def list_recommendations(
action: Optional[str] = None, action: Optional[str] = None,
mode: Optional[str] = None, mode: Optional[str] = None,
since: Optional[str] = None, since: Optional[str] = None,
min_confidence: Optional[float] = Query(default=None, ge=0.0, le=1.0),
limit: int = Query(default=50, le=200), limit: int = Query(default=50, le=200),
offset: int = 0, offset: int = 0,
latest: bool = Query(default=True, description="Return only the latest recommendation per ticker"), latest: bool = Query(default=True, description="Return only the latest recommendation per ticker"),
@@ -557,6 +559,9 @@ async def list_recommendations(
By default (latest=true), returns only the most recent recommendation By default (latest=true), returns only the most recent recommendation
per ticker to avoid showing duplicate/stale entries. Set latest=false per ticker to avoid showing duplicate/stale entries. Set latest=false
to see the full history. to see the full history.
min_confidence filters to recommendations at or above the given threshold,
useful for showing only recs that would pass a specific risk tier gate.
""" """
conditions: list[str] = [] conditions: list[str] = []
params: list[Any] = [] params: list[Any] = []
@@ -578,6 +583,10 @@ async def list_recommendations(
conditions.append(f"r.generated_at >= ${idx}::timestamptz") conditions.append(f"r.generated_at >= ${idx}::timestamptz")
params.append(since) params.append(since)
idx += 1 idx += 1
if min_confidence is not None:
conditions.append(f"r.confidence >= ${idx}")
params.append(min_confidence)
idx += 1
where = ("WHERE " + " AND ".join(conditions)) if conditions else "" where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
@@ -2002,7 +2011,6 @@ async def pg_query(body: dict[str, Any]):
# Safety: only allow SELECT statements # Safety: only allow SELECT statements
# Strip SQL comments (-- and /* */) and whitespace before checking # Strip SQL comments (-- and /* */) and whitespace before checking
import re
stripped = re.sub(r'--[^\n]*', '', sql) # remove -- comments stripped = re.sub(r'--[^\n]*', '', sql) # remove -- comments
stripped = re.sub(r'/\*.*?\*/', '', stripped, flags=re.DOTALL) # remove /* */ comments stripped = re.sub(r'/\*.*?\*/', '', stripped, flags=re.DOTALL) # remove /* */ comments
stripped = stripped.strip() stripped = stripped.strip()
@@ -2678,6 +2686,70 @@ class AgentCreateBody(BaseModel):
max_retries: int = 2 max_retries: int = 2
# ---------------------------------------------------------------------------
# Variant Pydantic Models (Requirement 2, 3)
# ---------------------------------------------------------------------------
class VariantCreateBody(BaseModel):
    """Request payload for creating an agent variant from scratch.

    ``variant_name`` and ``model_name`` are the only required fields; every
    other field falls back to the default declared below.
    """

    # --- identity ---
    variant_name: str
    # When omitted, the create endpoint derives the slug from variant_name.
    variant_slug: Optional[str] = None
    description: str = ""

    # --- model selection and prompts ---
    model_provider: str = "ollama"
    model_name: str
    system_prompt: str = ""
    user_prompt_template: str = ""
    prompt_version: str = ""

    # --- generation parameters and limits ---
    temperature: float = 0.0
    max_tokens: int = 32768
    context_window: int = 0
    input_token_limit: int = 0
    token_budget: int = 0
    timeout_seconds: int = 120
    max_retries: int = 2
class VariantUpdateBody(BaseModel):
    """Request payload for a partial variant update.

    Every field is optional and defaults to ``None``. The update endpoint
    dumps this model with ``exclude_none=True``, so a field left as ``None``
    is simply not written.
    """

    # --- identity ---
    variant_name: Optional[str] = None
    description: Optional[str] = None

    # --- model selection and prompts ---
    model_provider: Optional[str] = None
    model_name: Optional[str] = None
    system_prompt: Optional[str] = None
    user_prompt_template: Optional[str] = None
    prompt_version: Optional[str] = None

    # --- generation parameters and limits ---
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    context_window: Optional[int] = None
    input_token_limit: Optional[int] = None
    token_budget: Optional[int] = None
    timeout_seconds: Optional[int] = None
    max_retries: Optional[int] = None
class VariantCloneBody(BaseModel):
    """Request payload for cloning an agent or an existing variant.

    Only ``variant_name`` is required. Each remaining field is an optional
    override: the clone endpoints use the value from the clone source
    whenever a field is left as ``None``.
    """

    # --- identity of the new variant ---
    variant_name: str
    # When omitted, the clone endpoints derive the slug from variant_name.
    variant_slug: Optional[str] = None
    description: Optional[str] = None

    # --- optional overrides: model selection and prompts ---
    model_provider: Optional[str] = None
    model_name: Optional[str] = None
    system_prompt: Optional[str] = None
    user_prompt_template: Optional[str] = None
    prompt_version: Optional[str] = None

    # --- optional overrides: generation parameters and limits ---
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    context_window: Optional[int] = None
    input_token_limit: Optional[int] = None
    token_budget: Optional[int] = None
    timeout_seconds: Optional[int] = None
    max_retries: Optional[int] = None
def _slugify(name: str) -> str:
"""Generate a URL-safe slug from a variant name."""
slug = re.sub(r"[^a-z0-9]+", "-", name.lower())
return slug.strip("-")
@app.get("/api/agents") @app.get("/api/agents")
async def list_agents(active_only: bool = False): async def list_agents(active_only: bool = False):
"""List all AI agent configurations.""" """List all AI agent configurations."""
@@ -2827,3 +2899,384 @@ async def get_agent_performance_history(
agent_id, hours, agent_id, hours,
) )
return [_row_to_dict(r) for r in rows] return [_row_to_dict(r) for r in rows]
# ---------------------------------------------------------------------------
# Agent Variants (Requirements 2, 3, 4, 6, 10)
# ---------------------------------------------------------------------------
@app.get("/api/agents/{agent_id}/variants")
async def list_variants(agent_id: str):
    """Return every variant belonging to ``agent_id``, oldest first.

    The result is a list of dicts (one per row, converted with
    ``_row_to_dict``); it is empty when the agent has no variants.

    Requirement 3.1
    """
    records = await pool.fetch(
        """SELECT id, agent_id, variant_name, variant_slug, description,
        model_provider, model_name, system_prompt, user_prompt_template,
        prompt_version, temperature, max_tokens, context_window,
        input_token_limit, token_budget, timeout_seconds, max_retries,
        is_active, created_at, updated_at
        FROM agent_variants
        WHERE agent_id = $1
        ORDER BY created_at ASC""",
        agent_id,
    )
    variants = []
    for record in records:
        variants.append(_row_to_dict(record))
    return variants
@app.get("/api/agents/{agent_id}/variants/{variant_id}")
async def get_variant(agent_id: str, variant_id: str):
    """Fetch one variant, scoped to its owning agent.

    The lookup matches on both ``variant_id`` and ``agent_id``, so a variant
    that exists under a different agent is reported as missing.

    Raises:
        HTTPException(404): no row matches the id/agent pair.

    Requirement 3.2
    """
    record = await pool.fetchrow(
        """SELECT id, agent_id, variant_name, variant_slug, description,
        model_provider, model_name, system_prompt, user_prompt_template,
        prompt_version, temperature, max_tokens, context_window,
        input_token_limit, token_budget, timeout_seconds, max_retries,
        is_active, created_at, updated_at
        FROM agent_variants
        WHERE id = $1 AND agent_id = $2""",
        variant_id, agent_id,
    )
    if record is None:
        raise HTTPException(404, "Variant not found")
    return _row_to_dict(record)
@app.post("/api/agents/{agent_id}/variants", status_code=201)
async def create_variant(agent_id: str, body: VariantCreateBody):
    """Create a new variant for an agent and return the stored row.

    The slug is ``body.variant_slug`` when provided (and non-empty), otherwise
    it is generated from ``body.variant_name`` with ``_slugify``.

    Raises:
        HTTPException(400): no slug was supplied and none could be derived
            from the name (the name has no slug-safe characters).
        HTTPException(404): the insert violated a foreign key.
        HTTPException(409): the insert violated a unique constraint
            (duplicate slug for this agent).

    Requirement 3
    """
    slug = body.variant_slug or _slugify(body.variant_name)
    if not slug:
        # A name such as "!!!" slugifies to "", which would store an unusable
        # identifier; reject it instead of inserting an empty slug.
        raise HTTPException(
            400,
            "Could not derive a slug from variant_name — provide variant_slug explicitly",
        )

    try:
        row = await pool.fetchrow(
            """INSERT INTO agent_variants (
            agent_id, variant_name, variant_slug, description,
            model_provider, model_name, system_prompt, user_prompt_template,
            prompt_version, temperature, max_tokens, context_window,
            input_token_limit, token_budget, timeout_seconds, max_retries
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
            RETURNING id, agent_id, variant_name, variant_slug, description,
            model_provider, model_name, system_prompt, user_prompt_template,
            prompt_version, temperature, max_tokens, context_window,
            input_token_limit, token_budget, timeout_seconds, max_retries,
            is_active, created_at, updated_at""",
            agent_id, body.variant_name, slug, body.description,
            body.model_provider, body.model_name, body.system_prompt,
            body.user_prompt_template, body.prompt_version, body.temperature,
            body.max_tokens, body.context_window, body.input_token_limit,
            body.token_budget, body.timeout_seconds, body.max_retries,
        )
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, f"Variant slug '{slug}' already exists for this agent")
    except asyncpg.ForeignKeyViolationError:
        # NOTE(review): presumably agent_variants.agent_id references ai_agents;
        # confirm against the schema. Harmless if no such constraint exists.
        raise HTTPException(404, "Agent not found")
    return _row_to_dict(row)
@app.put("/api/agents/{agent_id}/variants/{variant_id}")
async def update_variant(agent_id: str, variant_id: str, body: VariantUpdateBody):
    """Apply a partial update to a variant and return the updated row.

    Only fields present in the request with a non-null value are written;
    ``updated_at`` is always refreshed to ``NOW()``.

    Raises:
        HTTPException(400): the request carried no non-null fields.
        HTTPException(404): no variant matches the id/agent pair.

    Requirement 3.4
    """
    # Column names come from the dumped model, values are always bound as
    # parameters. NOTE(review): assumes the default Pydantic config, under
    # which unknown request keys are not included in model_dump — confirm.
    changes = body.model_dump(exclude_none=True)
    if not changes:
        raise HTTPException(400, "No fields to update")

    assignments = [
        f"{column} = ${position}"
        for position, column in enumerate(changes, start=1)
    ]
    assignments.append("updated_at = NOW()")
    set_clause = ", ".join(assignments)

    # The two identifying parameters follow the field values.
    id_position = len(changes) + 1
    args = [*changes.values(), variant_id, agent_id]

    row = await pool.fetchrow(
        f"""UPDATE agent_variants SET {set_clause}
        WHERE id = ${id_position} AND agent_id = ${id_position + 1}
        RETURNING id, agent_id, variant_name, variant_slug, description,
        model_provider, model_name, system_prompt, user_prompt_template,
        prompt_version, temperature, max_tokens, context_window,
        input_token_limit, token_budget, timeout_seconds, max_retries,
        is_active, created_at, updated_at""",
        *args,
    )
    if not row:
        raise HTTPException(404, "Variant not found")
    return _row_to_dict(row)
@app.delete("/api/agents/{agent_id}/variants/{variant_id}")
async def delete_variant(agent_id: str, variant_id: str):
    """Delete a variant unless it is the agent's active one.

    The "not active" condition is part of the DELETE statement itself, so a
    variant that gets activated concurrently cannot slip through between a
    separate check and the delete.

    Returns:
        ``{"deleted": True}`` on success.

    Raises:
        HTTPException(404): no variant matches the id/agent pair.
        HTTPException(400): the variant exists but is currently active.

    Requirement 3.5, 3.6
    """
    # IS NOT TRUE (rather than = FALSE) also deletes rows whose is_active is
    # NULL, matching a plain truthiness check on the column value.
    deleted_id = await pool.fetchval(
        """DELETE FROM agent_variants
        WHERE id = $1 AND agent_id = $2 AND is_active IS NOT TRUE
        RETURNING id""",
        variant_id, agent_id,
    )
    if deleted_id is not None:
        return {"deleted": True}

    # Nothing was deleted: work out whether the variant is missing or active.
    exists = await pool.fetchval(
        "SELECT 1 FROM agent_variants WHERE id = $1 AND agent_id = $2",
        variant_id, agent_id,
    )
    if exists is None:
        raise HTTPException(404, "Variant not found")
    raise HTTPException(400, "Cannot delete active variant — deactivate it first")
# ---------------------------------------------------------------------------
# Clone Endpoints (Requirement 2)
# ---------------------------------------------------------------------------
@app.post("/api/agents/{agent_id}/clone", status_code=201)
async def clone_agent_as_variant(agent_id: str, body: VariantCloneBody):
    """Clone an agent's base configuration into a new variant.

    Model, prompt and parameter fields are copied from the ``ai_agents`` row;
    any field set (non-null) in the request body overrides the copied value.
    ``description`` defaults to ``""`` and the three token-limit fields that
    ``ai_agents`` does not carry default to ``0``.

    Raises:
        HTTPException(404): the agent does not exist.
        HTTPException(400): no slug was supplied and none could be derived
            from ``variant_name``.
        HTTPException(409): the slug already exists for this agent.

    Requirement 2.1, 2.3, 2.4, 2.5, 2.6
    """
    agent = await pool.fetchrow(
        """SELECT model_provider, model_name, system_prompt, user_prompt_template,
        prompt_version, temperature, max_tokens, timeout_seconds, max_retries
        FROM ai_agents WHERE id = $1""",
        agent_id,
    )
    if not agent:
        raise HTTPException(404, "Agent not found")

    slug = body.variant_slug or _slugify(body.variant_name)
    if not slug:
        # A name with no slug-safe characters would otherwise store "" as slug.
        raise HTTPException(
            400,
            "Could not derive a slug from variant_name — provide variant_slug explicitly",
        )

    def _resolve(field: str, inherited: Any) -> Any:
        """Return the request's override for ``field``, else ``inherited``."""
        override = getattr(body, field)
        return inherited if override is None else override

    try:
        row = await pool.fetchrow(
            """INSERT INTO agent_variants (
            agent_id, variant_name, variant_slug, description,
            model_provider, model_name, system_prompt, user_prompt_template,
            prompt_version, temperature, max_tokens, context_window,
            input_token_limit, token_budget, timeout_seconds, max_retries
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
            RETURNING id, agent_id, variant_name, variant_slug, description,
            model_provider, model_name, system_prompt, user_prompt_template,
            prompt_version, temperature, max_tokens, context_window,
            input_token_limit, token_budget, timeout_seconds, max_retries,
            is_active, created_at, updated_at""",
            agent_id,
            body.variant_name,
            slug,
            _resolve("description", ""),
            _resolve("model_provider", agent["model_provider"]),
            _resolve("model_name", agent["model_name"]),
            _resolve("system_prompt", agent["system_prompt"]),
            _resolve("user_prompt_template", agent["user_prompt_template"]),
            _resolve("prompt_version", agent["prompt_version"]),
            _resolve("temperature", agent["temperature"]),
            _resolve("max_tokens", agent["max_tokens"]),
            # ai_agents table doesn't have these columns — default to 0
            _resolve("context_window", 0),
            _resolve("input_token_limit", 0),
            _resolve("token_budget", 0),
            _resolve("timeout_seconds", agent["timeout_seconds"]),
            _resolve("max_retries", agent["max_retries"]),
        )
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, f"Variant slug '{slug}' already exists for this agent")
    return _row_to_dict(row)
@app.post("/api/agents/{agent_id}/variants/{variant_id}/clone", status_code=201)
async def clone_variant(agent_id: str, variant_id: str, body: VariantCloneBody):
    """Clone an existing variant into a new variant under the same agent.

    Every configuration field (including ``description``) is copied from the
    source variant; any field set (non-null) in the request body overrides
    the copied value.

    Raises:
        HTTPException(404): the source variant does not exist for this agent.
        HTTPException(400): no slug was supplied and none could be derived
            from ``variant_name``.
        HTTPException(409): the slug already exists for this agent.

    Requirement 2.2, 2.3, 2.4, 2.5, 2.6
    """
    source = await pool.fetchrow(
        """SELECT model_provider, model_name, system_prompt, user_prompt_template,
        prompt_version, temperature, max_tokens, context_window,
        input_token_limit, token_budget, timeout_seconds, max_retries,
        description
        FROM agent_variants
        WHERE id = $1 AND agent_id = $2""",
        variant_id, agent_id,
    )
    if not source:
        raise HTTPException(404, "Source variant not found")

    slug = body.variant_slug or _slugify(body.variant_name)
    if not slug:
        # A name with no slug-safe characters would otherwise store "" as slug.
        raise HTTPException(
            400,
            "Could not derive a slug from variant_name — provide variant_slug explicitly",
        )

    def _resolve(field: str) -> Any:
        """Return the request's override for ``field``, else the source variant's value."""
        override = getattr(body, field)
        return source[field] if override is None else override

    # Order matches the INSERT column list after agent_id/variant_name/variant_slug.
    copied_fields = (
        "description",
        "model_provider", "model_name", "system_prompt", "user_prompt_template",
        "prompt_version", "temperature", "max_tokens", "context_window",
        "input_token_limit", "token_budget", "timeout_seconds", "max_retries",
    )
    copied_values = [_resolve(field) for field in copied_fields]

    try:
        row = await pool.fetchrow(
            """INSERT INTO agent_variants (
            agent_id, variant_name, variant_slug, description,
            model_provider, model_name, system_prompt, user_prompt_template,
            prompt_version, temperature, max_tokens, context_window,
            input_token_limit, token_budget, timeout_seconds, max_retries
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
            RETURNING id, agent_id, variant_name, variant_slug, description,
            model_provider, model_name, system_prompt, user_prompt_template,
            prompt_version, temperature, max_tokens, context_window,
            input_token_limit, token_budget, timeout_seconds, max_retries,
            is_active, created_at, updated_at""",
            agent_id, body.variant_name, slug, *copied_values,
        )
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, f"Variant slug '{slug}' already exists for this agent")
    return _row_to_dict(row)
# ---------------------------------------------------------------------------
# Activate / Deactivate Endpoints (Requirement 4)
# ---------------------------------------------------------------------------
@app.post("/api/agents/{agent_id}/variants/{variant_id}/activate")
async def activate_variant(agent_id: str, variant_id: str):
    """Make ``variant_id`` the single active variant of its agent.

    Runs in one transaction: first every currently active variant of the
    agent is deactivated, then the target variant is activated. If the target
    does not exist the 404 is raised while the transaction is still open, so
    the deactivation is rolled back and the previously active variant stays
    active.

    Raises:
        HTTPException(404): no variant matches the id/agent pair.

    Requirement 4.1, 4.5
    """
    async with pool.acquire() as conn:
        async with conn.transaction():
            # Deactivate any currently active variant for this agent
            await conn.execute(
                """UPDATE agent_variants SET is_active = FALSE, updated_at = NOW()
                WHERE agent_id = $1 AND is_active = TRUE""",
                agent_id,
            )
            # Activate the target variant
            row = await conn.fetchrow(
                """UPDATE agent_variants SET is_active = TRUE, updated_at = NOW()
                WHERE id = $1 AND agent_id = $2
                RETURNING id, agent_id, variant_name, variant_slug, description,
                model_provider, model_name, system_prompt, user_prompt_template,
                prompt_version, temperature, max_tokens, context_window,
                input_token_limit, token_budget, timeout_seconds, max_retries,
                is_active, created_at, updated_at""",
                variant_id, agent_id,
            )
            if row is None:
                # Must stay inside the transaction block: the exception makes
                # the context manager roll back the deactivation above.
                raise HTTPException(404, "Variant not found")
    return _row_to_dict(row)
@app.post("/api/agents/{agent_id}/variants/deactivate")
async def deactivate_variants(agent_id: str):
    """Clear the active flag on whichever variant of the agent is active.

    Always returns ``{"deactivated": True}``, whether or not any row was
    actually active. With no active variant the agent is meant to fall back
    to its base configuration.

    Requirement 4.2
    """
    clear_active_sql = """UPDATE agent_variants SET is_active = FALSE, updated_at = NOW()
        WHERE agent_id = $1 AND is_active = TRUE"""
    await pool.execute(clear_active_sql, agent_id)
    return {"deactivated": True}
# ---------------------------------------------------------------------------
# Per-Variant Performance Endpoints (Requirement 6)
# ---------------------------------------------------------------------------
@app.get("/api/agents/{agent_id}/variants/{variant_id}/performance")
async def get_variant_performance(
    agent_id: str,
    variant_id: str,
    hours: int = Query(default=24, le=720),
):
    """Summarise a variant's logged invocations over the last ``hours`` hours.

    Aggregates ``agent_performance_log`` rows for the agent/variant pair into
    counts, average and p95 duration, average confidence and retries, and
    token totals. A ``success_rate`` key is added: successes divided by total
    invocations rounded to 4 places, or ``None`` when nothing was logged.

    Requirement 6.3
    """
    summary_row = await pool.fetchrow(
        """SELECT
        COUNT(*) AS total_invocations,
        COUNT(*) FILTER (WHERE success) AS successes,
        COUNT(*) FILTER (WHERE NOT success) AS failures,
        ROUND(AVG(duration_ms)::numeric) AS avg_duration_ms,
        ROUND(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY duration_ms)::numeric) AS p95_duration_ms,
        ROUND(AVG(confidence)::numeric, 4) AS avg_confidence,
        ROUND(AVG(retry_count)::numeric, 2) AS avg_retries,
        SUM(input_tokens) AS total_input_tokens,
        SUM(output_tokens) AS total_output_tokens
        FROM agent_performance_log
        WHERE agent_id = $1
        AND variant_id = $2
        AND recorded_at >= NOW() - make_interval(hours => $3)""",
        agent_id, variant_id, hours,
    )
    metrics = _row_to_dict(summary_row) if summary_row else {}

    # Missing or NULL counters are treated as zero.
    invocation_count = int(metrics.get("total_invocations", 0) or 0)
    success_count = int(metrics.get("successes", 0) or 0)

    if invocation_count > 0:
        metrics["success_rate"] = round(success_count / invocation_count, 4)
    else:
        metrics["success_rate"] = None
    return metrics
@app.get("/api/agents/{agent_id}/variants/{variant_id}/performance/history")
async def get_variant_performance_history(
    agent_id: str,
    variant_id: str,
    hours: int = Query(default=24, le=720),
):
    """Return a variant's performance bucketed by hour, oldest bucket first.

    Each entry covers one ``date_trunc('hour', recorded_at)`` bucket from the
    last ``hours`` hours and carries the invocation count, success count,
    average duration and average confidence. Only hours that have at least
    one logged row appear in the result.

    Requirement 6.4
    """
    hourly_rows = await pool.fetch(
        """SELECT
        date_trunc('hour', recorded_at) AS hour,
        COUNT(*) AS invocations,
        COUNT(*) FILTER (WHERE success) AS successes,
        ROUND(AVG(duration_ms)::numeric) AS avg_duration_ms,
        ROUND(AVG(confidence)::numeric, 4) AS avg_confidence
        FROM agent_performance_log
        WHERE agent_id = $1
        AND variant_id = $2
        AND recorded_at >= NOW() - make_interval(hours => $3)
        GROUP BY 1 ORDER BY 1""",
        agent_id, variant_id, hours,
    )
    series = []
    for bucket in hourly_rows:
        series.append(_row_to_dict(bucket))
    return series