phase 16: fix UUID serialization in symbol registry responses

This commit is contained in:
Celes Renata
2026-04-11 19:22:13 -07:00
parent 6f5b2231a2
commit 5758a704ec
+22 -12
View File
@@ -1,7 +1,8 @@
"""Symbol Registry API - FastAPI application.""" """Symbol Registry API - FastAPI application."""
import re import re
import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import List, Optional from typing import Any, List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import asyncpg import asyncpg
@@ -16,6 +17,15 @@ config = load_config()
pool: Optional[asyncpg.Pool] = None pool: Optional[asyncpg.Pool] = None
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
"""Convert asyncpg Record to dict with UUID→str coercion."""
d = dict(row)
for k, v in d.items():
if isinstance(v, uuid.UUID):
d[k] = str(v)
return d
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
global pool global pool
@@ -132,7 +142,7 @@ async def create_company(body: CompanyCreate):
) )
except asyncpg.UniqueViolationError: except asyncpg.UniqueViolationError:
raise HTTPException(409, f"Company {body.ticker} on {body.exchange} already exists") raise HTTPException(409, f"Company {body.ticker} on {body.exchange} already exists")
return dict(row) return _row_dict(row)
@app.get("/companies", response_model=List[CompanyResponse]) @app.get("/companies", response_model=List[CompanyResponse])
@@ -141,7 +151,7 @@ async def list_companies(active: bool = True):
"SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active FROM companies WHERE active = $1 ORDER BY ticker", "SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active FROM companies WHERE active = $1 ORDER BY ticker",
active, active,
) )
return [dict(r) for r in rows] return [_row_dict(r) for r in rows]
@app.get("/companies/{company_id}", response_model=CompanyResponse) @app.get("/companies/{company_id}", response_model=CompanyResponse)
@@ -152,7 +162,7 @@ async def get_company(company_id: str):
) )
if not row: if not row:
raise HTTPException(404, "Company not found") raise HTTPException(404, "Company not found")
return dict(row) return _row_dict(row)
@app.put("/companies/{company_id}", response_model=CompanyResponse) @app.put("/companies/{company_id}", response_model=CompanyResponse)
@@ -166,7 +176,7 @@ async def update_company(company_id: str, body: CompanyCreate):
) )
if not row: if not row:
raise HTTPException(404, "Company not found") raise HTTPException(404, "Company not found")
return dict(row) return _row_dict(row)
# --- Alias Endpoints --- # --- Alias Endpoints ---
@@ -177,7 +187,7 @@ async def add_alias(company_id: str, body: AliasCreate):
"INSERT INTO company_aliases (company_id, alias, alias_type) VALUES ($1, $2, $3) RETURNING id, alias, alias_type", "INSERT INTO company_aliases (company_id, alias, alias_type) VALUES ($1, $2, $3) RETURNING id, alias, alias_type",
company_id, body.alias, body.alias_type, company_id, body.alias, body.alias_type,
) )
return dict(row) return _row_dict(row)
@app.get("/companies/{company_id}/aliases") @app.get("/companies/{company_id}/aliases")
@@ -186,7 +196,7 @@ async def list_aliases(company_id: str):
"SELECT id, alias, alias_type FROM company_aliases WHERE company_id = $1", "SELECT id, alias, alias_type FROM company_aliases WHERE company_id = $1",
company_id, company_id,
) )
return [dict(r) for r in rows] return [_row_dict(r) for r in rows]
# --- Watchlist Endpoints --- # --- Watchlist Endpoints ---
@@ -200,13 +210,13 @@ async def create_watchlist(body: WatchlistCreate):
) )
except asyncpg.UniqueViolationError: except asyncpg.UniqueViolationError:
raise HTTPException(409, f"Watchlist '{body.name}' already exists") raise HTTPException(409, f"Watchlist '{body.name}' already exists")
return dict(row) return _row_dict(row)
@app.get("/watchlists") @app.get("/watchlists")
async def list_watchlists(): async def list_watchlists():
rows = await pool.fetch("SELECT id, name, description, active FROM watchlists ORDER BY name") rows = await pool.fetch("SELECT id, name, description, active FROM watchlists ORDER BY name")
return [dict(r) for r in rows] return [_row_dict(r) for r in rows]
@app.post("/watchlists/{watchlist_id}/members/{company_id}", status_code=201) @app.post("/watchlists/{watchlist_id}/members/{company_id}", status_code=201)
@@ -231,7 +241,7 @@ async def list_watchlist_members(watchlist_id: str):
WHERE wm.watchlist_id = $1 ORDER BY c.ticker""", WHERE wm.watchlist_id = $1 ORDER BY c.ticker""",
watchlist_id, watchlist_id,
) )
return [dict(r) for r in rows] return [_row_dict(r) for r in rows]
# --- Source Endpoints --- # --- Source Endpoints ---
@@ -249,7 +259,7 @@ async def add_source(company_id: str, body: SourceCreate):
company_id, body.source_type, body.source_name, company_id, body.source_type, body.source_name,
body.config, body.credibility_score, body.retention_days, body.access_policy, body.config, body.credibility_score, body.retention_days, body.access_policy,
) )
return dict(row) return _row_dict(row)
@app.get("/companies/{company_id}/sources") @app.get("/companies/{company_id}/sources")
@@ -258,4 +268,4 @@ async def list_sources(company_id: str):
"SELECT id, source_type, source_name, config, credibility_score, retention_days, access_policy, active FROM sources WHERE company_id = $1", "SELECT id, source_type, source_name, config, credibility_score, retention_days, access_policy, active FROM sources WHERE company_id = $1",
company_id, company_id,
) )
return [dict(r) for r in rows] return [_row_dict(r) for r in rows]