Files
stonks-oracle/services/symbol_registry/app.py
T
Celes Renata 5acb2fb43e fix: resolve 6 integration test failures
1. patterns endpoint: fix query referencing non-existent column
   di.catalyst_type → dir.catalyst_type (column is on document_impact_records)
2. lockouts seed: use relative timestamps (now + 7d) so active lockout
   is always in the future regardless of when tests run
3. create_agent: make slug optional with auto-generation from name
4. create_source: json.dumps(config) + ::jsonb cast for asyncpg JSONB compat
5. approval_expiry: return count as int (len(expired)) not the list itself
6. metrics_consistency: fix test assertion to match API contract
   (total >= active + reserve, not total == active + reserve + unrealized)
2026-04-20 04:30:13 +00:00

279 lines
9.4 KiB
Python

"""Symbol Registry API - FastAPI application."""
import re
import uuid
from contextlib import asynccontextmanager
from typing import Any, List, Optional
from urllib.parse import urlparse
import asyncpg
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, field_validator
from services.shared.config import load_config
from services.shared.db import get_pg_pool
from services.shared.logging import setup_logging
from services.symbol_registry.competitor_inference import router as inference_router
from services.symbol_registry.competitors import router as competitors_router
from services.symbol_registry.exposure import router as exposure_router
# Service configuration, loaded once when this module is imported.
config = load_config()
# Shared asyncpg connection pool. It stays None until ``lifespan`` creates it on
# application startup, and ``lifespan`` closes it again on shutdown.
pool: Optional[asyncpg.Pool] = None
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
    """Turn an asyncpg record into a plain dict suitable for a JSON response.

    Values that are ``uuid.UUID`` instances are rendered with ``str()``; every
    other value is passed through unchanged. Column order is preserved.
    """
    return {
        column: str(value) if isinstance(value, uuid.UUID) else value
        for column, value in dict(row).items()
    }
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own the process-wide resources for the lifetime of the application.

    On startup this configures logging and creates the shared Postgres pool,
    storing it in the module-level ``pool``. On exit the pool is closed.

    The ``try``/``finally`` around ``yield`` guarantees the pool is closed even
    when the lifespan is left via an exception or cancellation rather than a
    clean shutdown. Previously such an exit skipped ``pool.close()``.
    """
    global pool
    setup_logging("symbol_registry", level=config.log_level, json_output=config.json_logs)
    pool = await get_pg_pool(config)
    try:
        yield
    finally:
        await pool.close()
# The FastAPI application. ``lifespan`` sets up logging and the DB pool on
# startup and closes the pool on shutdown.
app = FastAPI(title="Stonks Oracle - Symbol Registry", lifespan=lifespan)
# Mount the exposure, competitor and competitor-inference sub-routers.
# They are included before the routes declared below in this module.
app.include_router(exposure_router)
app.include_router(competitors_router)
app.include_router(inference_router)
@app.get("/health")
async def health():
    """Health check: report ``ok`` only if Postgres answers a trivial query.

    Any failure while querying is reported as HTTP 503. This includes the
    pool not having been created yet.
    """
    try:
        await pool.fetchval("SELECT 1")
    except Exception:
        raise HTTPException(503, "Database unavailable")
    return {"status": "ok"}
# Accepted ticker format: 1-10 ASCII uppercase letters. CompanyCreate upper-cases
# and strips input before matching. Symbols containing "." or "-" are rejected.
TICKER_PATTERN = re.compile(r"^[A-Z]{1,10}$")
# Allowed values for SourceCreate.source_type.
VALID_SOURCE_TYPES = {"market_api", "news_api", "filings_api", "web_scrape", "broker"}
# Allowed values for SourceCreate.access_policy.
VALID_ACCESS_POLICIES = {"internal", "public", "restricted"}
# --- Request/Response Models ---
class CompanyCreate(BaseModel):
    """Request body used both to create a company and to replace one via PUT."""

    ticker: str
    legal_name: str
    exchange: Optional[str] = None
    sector: Optional[str] = None
    industry: Optional[str] = None
    market_cap_bucket: Optional[str] = None

    @field_validator("ticker")
    @classmethod
    def validate_ticker(cls, v: str) -> str:
        """Upper-case and strip the ticker, then require it to match TICKER_PATTERN."""
        normalized = v.upper().strip()
        if TICKER_PATTERN.match(normalized) is None:
            raise ValueError(f"Ticker must be 1-10 uppercase letters, got: {normalized}")
        return normalized
class CompanyResponse(BaseModel):
    """Company record as returned by the ``/companies`` endpoints."""

    id: str  # _row_dict stringifies UUID values, so a UUID key validates as str
    ticker: str
    legal_name: str
    exchange: Optional[str]  # nullable; no default, so the key must be present
    sector: Optional[str]
    industry: Optional[str]
    market_cap_bucket: Optional[str]
    active: bool
class AliasCreate(BaseModel):
    """Request body for ``POST /companies/{company_id}/aliases``."""

    alias: str
    alias_type: str = "brand"  # free-form; not checked against a fixed set here
class WatchlistCreate(BaseModel):
    """Request body for ``POST /watchlists``."""

    name: str  # create_watchlist maps a unique violation on insert to HTTP 409
    description: Optional[str] = None
class SourceCreate(BaseModel):
    """Request body for ``POST /companies/{company_id}/sources``."""

    source_type: str
    source_name: str
    config: dict = {}
    credibility_score: float = 0.5
    retention_days: int = 365
    access_policy: str = "internal"

    @field_validator("source_type")
    @classmethod
    def validate_source_type(cls, v: str) -> str:
        """Reject source types outside VALID_SOURCE_TYPES."""
        if v not in VALID_SOURCE_TYPES:
            raise ValueError(f"source_type must be one of {VALID_SOURCE_TYPES}")
        return v

    @field_validator("access_policy")
    @classmethod
    def validate_access_policy(cls, v: str) -> str:
        """Reject access policies outside VALID_ACCESS_POLICIES."""
        if v not in VALID_ACCESS_POLICIES:
            raise ValueError(f"access_policy must be one of {VALID_ACCESS_POLICIES}")
        return v

    @field_validator("config")
    @classmethod
    def validate_config_urls(cls, v: dict) -> dict:
        """Sanity-check URL-like entries in ``config``.

        Every present, truthy ``base_url``/``endpoint``/``url`` value is run
        through ``urlparse``, so values it cannot parse are rejected. An
        example is a malformed IPv6 host.

        ``base_url`` must also be an absolute HTTP(S) URL: an ``http`` or
        ``https`` scheme *and* a host. ``endpoint`` and ``url`` are otherwise
        unconstrained.
        """
        for key in ("base_url", "endpoint", "url"):
            if key in v and v[key]:
                parsed = urlparse(str(v[key]))
                # A scheme without a host (e.g. "https:" or "https:example.com")
                # parses with an empty netloc and cannot serve as a base URL.
                if key == "base_url" and (
                    parsed.scheme not in ("http", "https") or not parsed.netloc
                ):
                    raise ValueError(f"config.{key} must be a valid HTTP(S) URL")
        return v
# --- Company Endpoints ---
@app.post("/companies", response_model=CompanyResponse, status_code=201)
async def create_company(body: CompanyCreate):
    """Register a new company and return the stored record.

    Responds with 409 when the insert hits a unique-constraint violation.
    """
    insert_args = (
        body.ticker.upper(),
        body.legal_name,
        body.exchange,
        body.sector,
        body.industry,
        body.market_cap_bucket,
    )
    try:
        created = await pool.fetchrow(
            """INSERT INTO companies (ticker, legal_name, exchange, sector, industry, market_cap_bucket)
               VALUES ($1, $2, $3, $4, $5, $6)
               RETURNING id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active""",
            *insert_args,
        )
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, f"Company {body.ticker} on {body.exchange} already exists")
    return _row_dict(created)
@app.get("/companies", response_model=List[CompanyResponse])
async def list_companies(active: bool = True):
    """List companies filtered by their ``active`` flag, ordered by ticker."""
    query = (
        "SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active "
        "FROM companies WHERE active = $1 ORDER BY ticker"
    )
    records = await pool.fetch(query, active)
    return list(map(_row_dict, records))
@app.get("/companies/{company_id}", response_model=CompanyResponse)
async def get_company(company_id: str):
    """Fetch a single company by id; 404 if no row matches."""
    query = (
        "SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active "
        "FROM companies WHERE id = $1"
    )
    company = await pool.fetchrow(query, company_id)
    if company is None:
        raise HTTPException(404, "Company not found")
    return _row_dict(company)
@app.put("/companies/{company_id}", response_model=CompanyResponse)
async def update_company(company_id: str, body: CompanyCreate):
    """Replace every editable field of an existing company and bump ``updated_at``.

    Returns 404 if no company has ``company_id``. Returns 409 if the new values
    violate a unique constraint, using the same message as ``POST /companies``.
    """
    try:
        row = await pool.fetchrow(
            """UPDATE companies SET ticker=$2, legal_name=$3, exchange=$4, sector=$5, industry=$6, market_cap_bucket=$7, updated_at=NOW()
               WHERE id=$1
               RETURNING id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active""",
            company_id, body.ticker.upper(), body.legal_name, body.exchange,
            body.sector, body.industry, body.market_cap_bucket,
        )
    except asyncpg.UniqueViolationError:
        # Without this, colliding with another company's unique key surfaced as
        # an unhandled 500; create_company already maps the same error to 409.
        raise HTTPException(409, f"Company {body.ticker} on {body.exchange} already exists")
    if not row:
        raise HTTPException(404, "Company not found")
    return _row_dict(row)
# --- Alias Endpoints ---
@app.post("/companies/{company_id}/aliases", status_code=201)
async def add_alias(company_id: str, body: AliasCreate):
    """Attach an alias to a company and return the stored alias row.

    Database constraint violations are mapped to client errors instead of
    surfacing as 500s, in the same style as ``add_watchlist_member``:

    - a foreign-key violation becomes 404 (presumably ``company_id`` does not
      reference an existing company; confirm against the schema);
    - a unique violation becomes 409.
    """
    try:
        row = await pool.fetchrow(
            "INSERT INTO company_aliases (company_id, alias, alias_type) VALUES ($1, $2, $3) RETURNING id, alias, alias_type",
            company_id, body.alias, body.alias_type,
        )
    except asyncpg.ForeignKeyViolationError:
        raise HTTPException(404, "Company not found")
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, f"Alias '{body.alias}' already exists")
    return _row_dict(row)
@app.get("/companies/{company_id}/aliases")
async def list_aliases(company_id: str):
    """List a company's aliases.

    No existence check is made: the result is an empty list both when the
    company has no aliases and when the company does not exist.
    """
    alias_rows = await pool.fetch(
        "SELECT id, alias, alias_type FROM company_aliases WHERE company_id = $1",
        company_id,
    )
    return list(map(_row_dict, alias_rows))
# --- Watchlist Endpoints ---
@app.post("/watchlists", status_code=201)
async def create_watchlist(body: WatchlistCreate):
    """Create a named watchlist; 409 if the insert hits a unique violation."""
    insert_sql = "INSERT INTO watchlists (name, description) VALUES ($1, $2) RETURNING id, name, description, active"
    try:
        created = await pool.fetchrow(insert_sql, body.name, body.description)
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, f"Watchlist '{body.name}' already exists")
    return _row_dict(created)
@app.get("/watchlists")
async def list_watchlists():
    """Return every watchlist, ordered by name."""
    query = "SELECT id, name, description, active FROM watchlists ORDER BY name"
    watchlists = await pool.fetch(query)
    return list(map(_row_dict, watchlists))
@app.post("/watchlists/{watchlist_id}/members/{company_id}", status_code=201)
async def add_watchlist_member(watchlist_id: str, company_id: str):
    """Add a company to a watchlist.

    A unique violation (the pair is already present) becomes 409. A
    foreign-key violation (missing watchlist or company) becomes 404.
    """
    insert_sql = "INSERT INTO watchlist_members (watchlist_id, company_id) VALUES ($1, $2)"
    try:
        await pool.execute(insert_sql, watchlist_id, company_id)
    except asyncpg.ForeignKeyViolationError:
        raise HTTPException(404, "Watchlist or company not found")
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, "Already a member")
    return {"status": "added"}
@app.get("/watchlists/{watchlist_id}/members")
async def list_watchlist_members(watchlist_id: str):
    """Return the companies on a watchlist, ordered by ticker."""
    members = await pool.fetch(
        """SELECT c.id, c.ticker, c.legal_name, c.exchange, c.sector, c.industry, c.market_cap_bucket, c.active
           FROM companies c JOIN watchlist_members wm ON c.id = wm.company_id
           WHERE wm.watchlist_id = $1 ORDER BY c.ticker""",
        watchlist_id,
    )
    return list(map(_row_dict, members))
# --- Source Endpoints ---
@app.post("/companies/{company_id}/sources", status_code=201)
async def add_source(company_id: str, body: SourceCreate):
    """Register a data source for an existing company; 404 if the company is unknown."""
    import json

    company_found = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id)
    if not company_found:
        raise HTTPException(404, "Company not found")

    # The config dict is sent as JSON text and cast to jsonb in SQL.
    config_json = json.dumps(body.config)
    created = await pool.fetchrow(
        """INSERT INTO sources (company_id, source_type, source_name, config, credibility_score, retention_days, access_policy)
           VALUES ($1, $2, $3, $4::jsonb, $5, $6, $7)
           RETURNING id, source_type, source_name, credibility_score, active""",
        company_id,
        body.source_type,
        body.source_name,
        config_json,
        body.credibility_score,
        body.retention_days,
        body.access_policy,
    )
    return _row_dict(created)
@app.get("/companies/{company_id}/sources")
async def list_sources(company_id: str):
    """List a company's registered sources.

    ``config`` is returned as a JSON object, mirroring the ``config`` dict
    accepted by ``POST /companies/{company_id}/sources``. Previously it came
    back as a JSON-encoded string.
    """
    import json

    rows = await pool.fetch(
        "SELECT id, source_type, source_name, config, credibility_score, retention_days, access_policy, active FROM sources WHERE company_id = $1",
        company_id,
    )
    sources = []
    for r in rows:
        source = _row_dict(r)
        config_value = source.get("config")
        # asyncpg returns json/jsonb columns as text unless a type codec is
        # registered on the connection. add_source writes config via
        # json.dumps(...)::jsonb, which implies none is. Decode text here, and
        # leave the value untouched if a codec has already decoded it.
        if isinstance(config_value, str):
            source["config"] = json.loads(config_value)
        sources.append(source)
    return sources