feat: risk tier selector on Trading page + confidence filter on Recommendations

- Trading page: added conservative/moderate/aggressive selector that
  updates the trading engine config via PUT /api/trading/config
- Recommendations page: added risk tier dropdown that defaults to the
  engine's current tier and filters recs by the tier's min_confidence
- Backend: added min_confidence query param to GET /api/recommendations
- Risk tier thresholds: conservative ≥0.75, moderate ≥0.55, aggressive ≥0.40
This commit is contained in:
Celes Renata
2026-04-17 05:08:54 +00:00
parent 49e3955fab
commit 734bf001a7
4 changed files with 532 additions and 5 deletions
+454 -1
View File
@@ -15,6 +15,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import re
import time as _time
from contextlib import asynccontextmanager
from dataclasses import asdict
@@ -548,6 +549,7 @@ async def list_recommendations(
action: Optional[str] = None,
mode: Optional[str] = None,
since: Optional[str] = None,
min_confidence: Optional[float] = Query(default=None, ge=0.0, le=1.0),
limit: int = Query(default=50, le=200),
offset: int = 0,
latest: bool = Query(default=True, description="Return only the latest recommendation per ticker"),
@@ -557,6 +559,9 @@ async def list_recommendations(
By default (latest=true), returns only the most recent recommendation
per ticker to avoid showing duplicate/stale entries. Set latest=false
to see the full history.
min_confidence filters to recommendations at or above the given threshold,
useful for showing only recs that would pass a specific risk tier gate.
"""
conditions: list[str] = []
params: list[Any] = []
@@ -578,6 +583,10 @@ async def list_recommendations(
conditions.append(f"r.generated_at >= ${idx}::timestamptz")
params.append(since)
idx += 1
if min_confidence is not None:
conditions.append(f"r.confidence >= ${idx}")
params.append(min_confidence)
idx += 1
where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
@@ -2002,7 +2011,6 @@ async def pg_query(body: dict[str, Any]):
# Safety: only allow SELECT statements
# Strip SQL comments (-- and /* */) and whitespace before checking
import re
stripped = re.sub(r'--[^\n]*', '', sql) # remove -- comments
stripped = re.sub(r'/\*.*?\*/', '', stripped, flags=re.DOTALL) # remove /* */ comments
stripped = stripped.strip()
@@ -2678,6 +2686,70 @@ class AgentCreateBody(BaseModel):
max_retries: int = 2
# ---------------------------------------------------------------------------
# Variant Pydantic Models (Requirement 2, 3)
# ---------------------------------------------------------------------------
class VariantCreateBody(BaseModel):
variant_name: str
variant_slug: str | None = None
description: str = ""
model_provider: str = "ollama"
model_name: str
system_prompt: str = ""
user_prompt_template: str = ""
prompt_version: str = ""
temperature: float = 0.0
max_tokens: int = 32768
context_window: int = 0
input_token_limit: int = 0
token_budget: int = 0
timeout_seconds: int = 120
max_retries: int = 2
class VariantUpdateBody(BaseModel):
variant_name: str | None = None
description: str | None = None
model_provider: str | None = None
model_name: str | None = None
system_prompt: str | None = None
user_prompt_template: str | None = None
prompt_version: str | None = None
temperature: float | None = None
max_tokens: int | None = None
context_window: int | None = None
input_token_limit: int | None = None
token_budget: int | None = None
timeout_seconds: int | None = None
max_retries: int | None = None
class VariantCloneBody(BaseModel):
variant_name: str
variant_slug: str | None = None
description: str | None = None
model_provider: str | None = None
model_name: str | None = None
system_prompt: str | None = None
user_prompt_template: str | None = None
prompt_version: str | None = None
temperature: float | None = None
max_tokens: int | None = None
context_window: int | None = None
input_token_limit: int | None = None
token_budget: int | None = None
timeout_seconds: int | None = None
max_retries: int | None = None
def _slugify(name: str) -> str:
"""Generate a URL-safe slug from a variant name."""
slug = re.sub(r"[^a-z0-9]+", "-", name.lower())
return slug.strip("-")
@app.get("/api/agents")
async def list_agents(active_only: bool = False):
"""List all AI agent configurations."""
@@ -2827,3 +2899,384 @@ async def get_agent_performance_history(
agent_id, hours,
)
return [_row_to_dict(r) for r in rows]
# ---------------------------------------------------------------------------
# Agent Variants (Requirements 2, 3, 4, 6, 10)
# ---------------------------------------------------------------------------
@app.get("/api/agents/{agent_id}/variants")
async def list_variants(agent_id: str):
"""List all variants for an agent, ordered by created_at ascending.
Requirement 3.1
"""
rows = await pool.fetch(
"""SELECT id, agent_id, variant_name, variant_slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries,
is_active, created_at, updated_at
FROM agent_variants
WHERE agent_id = $1
ORDER BY created_at ASC""",
agent_id,
)
return [_row_to_dict(r) for r in rows]
@app.get("/api/agents/{agent_id}/variants/{variant_id}")
async def get_variant(agent_id: str, variant_id: str):
"""Get a single variant. Returns 404 if not found or agent mismatch.
Requirement 3.2
"""
row = await pool.fetchrow(
"""SELECT id, agent_id, variant_name, variant_slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries,
is_active, created_at, updated_at
FROM agent_variants
WHERE id = $1 AND agent_id = $2""",
variant_id, agent_id,
)
if not row:
raise HTTPException(404, "Variant not found")
return _row_to_dict(row)
@app.post("/api/agents/{agent_id}/variants", status_code=201)
async def create_variant(agent_id: str, body: VariantCreateBody):
"""Create a new variant for an agent.
Auto-generates slug from variant_name if not provided.
Returns 409 on duplicate slug.
Requirement 3
"""
slug = body.variant_slug or _slugify(body.variant_name)
try:
row = await pool.fetchrow(
"""INSERT INTO agent_variants (
agent_id, variant_name, variant_slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
RETURNING id, agent_id, variant_name, variant_slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries,
is_active, created_at, updated_at""",
agent_id, body.variant_name, slug, body.description,
body.model_provider, body.model_name, body.system_prompt,
body.user_prompt_template, body.prompt_version, body.temperature,
body.max_tokens, body.context_window, body.input_token_limit,
body.token_budget, body.timeout_seconds, body.max_retries,
)
except asyncpg.UniqueViolationError:
raise HTTPException(409, f"Variant slug '{slug}' already exists for this agent")
return _row_to_dict(row)
@app.put("/api/agents/{agent_id}/variants/{variant_id}")
async def update_variant(agent_id: str, variant_id: str, body: VariantUpdateBody):
"""Partial update a variant. Sets updated_at = NOW().
Requirement 3.4
"""
updates: list[str] = []
params: list[Any] = []
idx = 1
for field_name, value in body.model_dump(exclude_none=True).items():
updates.append(f"{field_name} = ${idx}")
params.append(value)
idx += 1
if not updates:
raise HTTPException(400, "No fields to update")
updates.append("updated_at = NOW()")
set_clause = ", ".join(updates)
params.append(variant_id)
params.append(agent_id)
row = await pool.fetchrow(
f"""UPDATE agent_variants SET {set_clause}
WHERE id = ${idx} AND agent_id = ${idx + 1}
RETURNING id, agent_id, variant_name, variant_slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries,
is_active, created_at, updated_at""",
*params,
)
if not row:
raise HTTPException(404, "Variant not found")
return _row_to_dict(row)
@app.delete("/api/agents/{agent_id}/variants/{variant_id}")
async def delete_variant(agent_id: str, variant_id: str):
"""Delete a variant. Returns 400 if variant is currently active.
Requirement 3.5, 3.6
"""
row = await pool.fetchrow(
"SELECT is_active FROM agent_variants WHERE id = $1 AND agent_id = $2",
variant_id, agent_id,
)
if not row:
raise HTTPException(404, "Variant not found")
if row["is_active"]:
raise HTTPException(400, "Cannot delete active variant — deactivate it first")
await pool.execute(
"DELETE FROM agent_variants WHERE id = $1 AND agent_id = $2",
variant_id, agent_id,
)
return {"deleted": True}
# ---------------------------------------------------------------------------
# Clone Endpoints (Requirement 2)
# ---------------------------------------------------------------------------
@app.post("/api/agents/{agent_id}/clone", status_code=201)
async def clone_agent_as_variant(agent_id: str, body: VariantCloneBody):
"""Clone an agent's configuration as a new variant.
Copies the agent's model/prompt/parameter fields into a new variant,
with optional overrides from the request body.
Requirement 2.1, 2.3, 2.4, 2.5, 2.6
"""
agent = await pool.fetchrow(
"""SELECT model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, timeout_seconds, max_retries
FROM ai_agents WHERE id = $1""",
agent_id,
)
if not agent:
raise HTTPException(404, "Agent not found")
slug = body.variant_slug or _slugify(body.variant_name)
description = body.description if body.description is not None else ""
model_provider = body.model_provider if body.model_provider is not None else agent["model_provider"]
model_name = body.model_name if body.model_name is not None else agent["model_name"]
system_prompt = body.system_prompt if body.system_prompt is not None else agent["system_prompt"]
user_prompt_template = body.user_prompt_template if body.user_prompt_template is not None else agent["user_prompt_template"]
prompt_version = body.prompt_version if body.prompt_version is not None else agent["prompt_version"]
temperature = body.temperature if body.temperature is not None else agent["temperature"]
max_tokens = body.max_tokens if body.max_tokens is not None else agent["max_tokens"]
timeout_seconds = body.timeout_seconds if body.timeout_seconds is not None else agent["timeout_seconds"]
max_retries = body.max_retries if body.max_retries is not None else agent["max_retries"]
# ai_agents table doesn't have these columns — default to 0
context_window = body.context_window if body.context_window is not None else 0
input_token_limit = body.input_token_limit if body.input_token_limit is not None else 0
token_budget = body.token_budget if body.token_budget is not None else 0
try:
row = await pool.fetchrow(
"""INSERT INTO agent_variants (
agent_id, variant_name, variant_slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
RETURNING id, agent_id, variant_name, variant_slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries,
is_active, created_at, updated_at""",
agent_id, body.variant_name, slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries,
)
except asyncpg.UniqueViolationError:
raise HTTPException(409, f"Variant slug '{slug}' already exists for this agent")
return _row_to_dict(row)
@app.post("/api/agents/{agent_id}/variants/{variant_id}/clone", status_code=201)
async def clone_variant(agent_id: str, variant_id: str, body: VariantCloneBody):
"""Clone an existing variant as a new variant under the same agent.
Copies all configuration fields from the source variant,
with optional overrides from the request body.
Requirement 2.2, 2.3, 2.4, 2.5, 2.6
"""
source = await pool.fetchrow(
"""SELECT model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries,
description
FROM agent_variants
WHERE id = $1 AND agent_id = $2""",
variant_id, agent_id,
)
if not source:
raise HTTPException(404, "Source variant not found")
slug = body.variant_slug or _slugify(body.variant_name)
description = body.description if body.description is not None else source["description"]
model_provider = body.model_provider if body.model_provider is not None else source["model_provider"]
model_name = body.model_name if body.model_name is not None else source["model_name"]
system_prompt = body.system_prompt if body.system_prompt is not None else source["system_prompt"]
user_prompt_template = body.user_prompt_template if body.user_prompt_template is not None else source["user_prompt_template"]
prompt_version = body.prompt_version if body.prompt_version is not None else source["prompt_version"]
temperature = body.temperature if body.temperature is not None else source["temperature"]
max_tokens = body.max_tokens if body.max_tokens is not None else source["max_tokens"]
context_window = body.context_window if body.context_window is not None else source["context_window"]
input_token_limit = body.input_token_limit if body.input_token_limit is not None else source["input_token_limit"]
token_budget = body.token_budget if body.token_budget is not None else source["token_budget"]
timeout_seconds = body.timeout_seconds if body.timeout_seconds is not None else source["timeout_seconds"]
max_retries = body.max_retries if body.max_retries is not None else source["max_retries"]
try:
row = await pool.fetchrow(
"""INSERT INTO agent_variants (
agent_id, variant_name, variant_slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
RETURNING id, agent_id, variant_name, variant_slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries,
is_active, created_at, updated_at""",
agent_id, body.variant_name, slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries,
)
except asyncpg.UniqueViolationError:
raise HTTPException(409, f"Variant slug '{slug}' already exists for this agent")
return _row_to_dict(row)
# ---------------------------------------------------------------------------
# Activate / Deactivate Endpoints (Requirement 4)
# ---------------------------------------------------------------------------
@app.post("/api/agents/{agent_id}/variants/{variant_id}/activate")
async def activate_variant(agent_id: str, variant_id: str):
"""Set a variant as the active variant for its agent.
Within a single transaction: deactivate any currently active variant,
then activate the target variant.
Requirement 4.1, 4.5
"""
async with pool.acquire() as conn:
async with conn.transaction():
# Deactivate any currently active variant for this agent
await conn.execute(
"""UPDATE agent_variants SET is_active = FALSE, updated_at = NOW()
WHERE agent_id = $1 AND is_active = TRUE""",
agent_id,
)
# Activate the target variant
row = await conn.fetchrow(
"""UPDATE agent_variants SET is_active = TRUE, updated_at = NOW()
WHERE id = $1 AND agent_id = $2
RETURNING id, agent_id, variant_name, variant_slug, description,
model_provider, model_name, system_prompt, user_prompt_template,
prompt_version, temperature, max_tokens, context_window,
input_token_limit, token_budget, timeout_seconds, max_retries,
is_active, created_at, updated_at""",
variant_id, agent_id,
)
if not row:
raise HTTPException(404, "Variant not found")
return _row_to_dict(row)
@app.post("/api/agents/{agent_id}/variants/deactivate")
async def deactivate_variants(agent_id: str):
"""Deactivate the currently active variant for an agent.
The agent falls back to its base configuration.
Requirement 4.2
"""
await pool.execute(
"""UPDATE agent_variants SET is_active = FALSE, updated_at = NOW()
WHERE agent_id = $1 AND is_active = TRUE""",
agent_id,
)
return {"deactivated": True}
# ---------------------------------------------------------------------------
# Per-Variant Performance Endpoints (Requirement 6)
# ---------------------------------------------------------------------------
@app.get("/api/agents/{agent_id}/variants/{variant_id}/performance")
async def get_variant_performance(
agent_id: str,
variant_id: str,
hours: int = Query(default=24, le=720),
):
"""Aggregated performance metrics for a specific variant.
Requirement 6.3
"""
row = await pool.fetchrow(
"""SELECT
COUNT(*) AS total_invocations,
COUNT(*) FILTER (WHERE success) AS successes,
COUNT(*) FILTER (WHERE NOT success) AS failures,
ROUND(AVG(duration_ms)::numeric) AS avg_duration_ms,
ROUND(PERCENTILE_CONT(0.95) WITHIN GROUP (ORDER BY duration_ms)::numeric) AS p95_duration_ms,
ROUND(AVG(confidence)::numeric, 4) AS avg_confidence,
ROUND(AVG(retry_count)::numeric, 2) AS avg_retries,
SUM(input_tokens) AS total_input_tokens,
SUM(output_tokens) AS total_output_tokens
FROM agent_performance_log
WHERE agent_id = $1
AND variant_id = $2
AND recorded_at >= NOW() - make_interval(hours => $3)""",
agent_id, variant_id, hours,
)
d = _row_to_dict(row) if row else {}
total = int(d.get("total_invocations", 0) or 0)
successes = int(d.get("successes", 0) or 0)
d["success_rate"] = round(successes / total, 4) if total > 0 else None
return d
@app.get("/api/agents/{agent_id}/variants/{variant_id}/performance/history")
async def get_variant_performance_history(
agent_id: str,
variant_id: str,
hours: int = Query(default=24, le=720),
):
"""Hourly performance time-series for a specific variant.
Requirement 6.4
"""
rows = await pool.fetch(
"""SELECT
date_trunc('hour', recorded_at) AS hour,
COUNT(*) AS invocations,
COUNT(*) FILTER (WHERE success) AS successes,
ROUND(AVG(duration_ms)::numeric) AS avg_duration_ms,
ROUND(AVG(confidence)::numeric, 4) AS avg_confidence
FROM agent_performance_log
WHERE agent_id = $1
AND variant_id = $2
AND recorded_at >= NOW() - make_interval(hours => $3)
GROUP BY 1 ORDER BY 1""",
agent_id, variant_id, hours,
)
return [_row_to_dict(r) for r in rows]