"""Symbol Registry API - FastAPI application."""

from contextlib import asynccontextmanager
from typing import List, Optional

import asyncpg
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from services.shared.config import load_config
from services.shared.db import get_pg_pool

config = load_config()

# asyncpg pool shared by every endpoint. It is created in ``lifespan`` at
# startup, so it is None until the application has started.
pool: Optional[asyncpg.Pool] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared asyncpg pool on startup and close it on shutdown.

    The ``try/finally`` makes sure the pool is closed even when the
    lifespan context is exited with an exception instead of normally.
    """
    global pool
    pool = await get_pg_pool(config)
    try:
        yield
    finally:
        await pool.close()


app = FastAPI(title="Stonks Oracle - Symbol Registry", lifespan=lifespan)


# --- Request/Response Models ---


class CompanyCreate(BaseModel):
    """Payload for creating (POST) or fully replacing (PUT) a company."""

    ticker: str  # upper-cased by the endpoints before it is stored
    legal_name: str
    exchange: Optional[str] = None
    sector: Optional[str] = None
    industry: Optional[str] = None
    market_cap_bucket: Optional[str] = None


class CompanyResponse(BaseModel):
    """A row of the ``companies`` table as returned by the company endpoints."""

    id: str
    ticker: str
    legal_name: str
    exchange: Optional[str]
    sector: Optional[str]
    industry: Optional[str]
    market_cap_bucket: Optional[str]
    active: bool


class AliasCreate(BaseModel):
    """Payload for attaching an alternative name to a company."""

    alias: str
    alias_type: str = "brand"


class WatchlistCreate(BaseModel):
    """Payload for creating a named watchlist."""

    name: str
    description: Optional[str] = None


class SourceCreate(BaseModel):
    """Payload for registering a data source for a company.

    ``source_type`` must be one of ``VALID_SOURCE_TYPES``; this is checked
    in ``add_source`` rather than by the model.
    """

    source_type: str
    source_name: str
    # default_factory gives each instance its own dict instead of a shared literal.
    config: dict = Field(default_factory=dict)
    credibility_score: float = 0.5
    retention_days: int = 365
    access_policy: str = "internal"


VALID_SOURCE_TYPES = {"market_api", "news_api", "filings_api", "web_scrape", "broker"}


# --- Company Endpoints ---


@app.post("/companies", response_model=CompanyResponse, status_code=201)
async def create_company(body: CompanyCreate):
    """Insert a company (ticker upper-cased) and return the stored row.

    Raises:
        HTTPException(409): the insert violates a unique constraint on companies.
    """
    try:
        row = await pool.fetchrow(
            """INSERT INTO companies (ticker, legal_name, exchange, sector, industry, market_cap_bucket)
               VALUES ($1, $2, $3, $4, $5, $6)
               RETURNING id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active""",
            body.ticker.upper(),
            body.legal_name,
            body.exchange,
            body.sector,
            body.industry,
            body.market_cap_bucket,
        )
    except asyncpg.UniqueViolationError as exc:
        raise HTTPException(409, f"Company {body.ticker} on {body.exchange} already exists") from exc
    return dict(row)


@app.get("/companies", response_model=List[CompanyResponse])
async def list_companies(active: bool = True):
    """List companies whose ``active`` flag equals ``active``, ordered by ticker."""
    rows = await pool.fetch(
        "SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active "
        "FROM companies WHERE active = $1 ORDER BY ticker",
        active,
    )
    return [dict(r) for r in rows]


@app.get("/companies/{company_id}", response_model=CompanyResponse)
async def get_company(company_id: str):
    """Return one company by id.

    Raises:
        HTTPException(404): no company has this id.
    """
    row = await pool.fetchrow(
        "SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active "
        "FROM companies WHERE id = $1",
        company_id,
    )
    if not row:
        raise HTTPException(404, "Company not found")
    return dict(row)


@app.put("/companies/{company_id}", response_model=CompanyResponse)
async def update_company(company_id: str, body: CompanyCreate):
    """Replace every editable field of a company and bump ``updated_at``.

    Raises:
        HTTPException(404): no company has this id.
        HTTPException(409): the new values violate a unique constraint on
            companies (previously this surfaced as an unhandled 500).
    """
    try:
        row = await pool.fetchrow(
            """UPDATE companies
               SET ticker=$2, legal_name=$3, exchange=$4, sector=$5, industry=$6,
                   market_cap_bucket=$7, updated_at=NOW()
               WHERE id=$1
               RETURNING id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active""",
            company_id,
            body.ticker.upper(),
            body.legal_name,
            body.exchange,
            body.sector,
            body.industry,
            body.market_cap_bucket,
        )
    except asyncpg.UniqueViolationError as exc:
        raise HTTPException(409, f"Company {body.ticker} on {body.exchange} already exists") from exc
    if not row:
        raise HTTPException(404, "Company not found")
    return dict(row)


# --- Alias Endpoints ---


@app.post("/companies/{company_id}/aliases", status_code=201)
async def add_alias(company_id: str, body: AliasCreate):
    """Attach an alias to a company and return the new alias row.

    Raises:
        HTTPException(404): the insert violates a foreign key (unknown company).
        HTTPException(409): the insert violates a unique constraint, if one exists.
    """
    try:
        row = await pool.fetchrow(
            "INSERT INTO company_aliases (company_id, alias, alias_type) "
            "VALUES ($1, $2, $3) RETURNING id, alias, alias_type",
            company_id,
            body.alias,
            body.alias_type,
        )
    except asyncpg.ForeignKeyViolationError as exc:
        raise HTTPException(404, "Company not found") from exc
    except asyncpg.UniqueViolationError as exc:
        raise HTTPException(409, f"Alias '{body.alias}' already exists") from exc
    return dict(row)


@app.get("/companies/{company_id}/aliases")
async def list_aliases(company_id: str):
    """List a company's aliases (an empty list if it has none or does not exist)."""
    rows = await pool.fetch(
        "SELECT id, alias, alias_type FROM company_aliases WHERE company_id = $1",
        company_id,
    )
    return [dict(r) for r in rows]


# --- Watchlist Endpoints ---


@app.post("/watchlists", status_code=201)
async def create_watchlist(body: WatchlistCreate):
    """Create a watchlist and return the stored row.

    Raises:
        HTTPException(409): the insert violates a unique constraint on watchlists.
    """
    try:
        row = await pool.fetchrow(
            "INSERT INTO watchlists (name, description) VALUES ($1, $2) "
            "RETURNING id, name, description, active",
            body.name,
            body.description,
        )
    except asyncpg.UniqueViolationError as exc:
        raise HTTPException(409, f"Watchlist '{body.name}' already exists") from exc
    return dict(row)


@app.get("/watchlists")
async def list_watchlists():
    """List all watchlists ordered by name."""
    rows = await pool.fetch("SELECT id, name, description, active FROM watchlists ORDER BY name")
    return [dict(r) for r in rows]


@app.post("/watchlists/{watchlist_id}/members/{company_id}", status_code=201)
async def add_watchlist_member(watchlist_id: str, company_id: str):
    """Add a company to a watchlist.

    Raises:
        HTTPException(409): the pair violates a unique constraint (already a member).
        HTTPException(404): the insert violates a foreign key (unknown watchlist or company).
    """
    try:
        await pool.execute(
            "INSERT INTO watchlist_members (watchlist_id, company_id) VALUES ($1, $2)",
            watchlist_id,
            company_id,
        )
    except asyncpg.UniqueViolationError as exc:
        raise HTTPException(409, "Already a member") from exc
    except asyncpg.ForeignKeyViolationError as exc:
        raise HTTPException(404, "Watchlist or company not found") from exc
    return {"status": "added"}


@app.get("/watchlists/{watchlist_id}/members")
async def list_watchlist_members(watchlist_id: str):
    """List the companies in a watchlist, ordered by ticker."""
    rows = await pool.fetch(
        """SELECT c.id, c.ticker, c.legal_name, c.exchange, c.sector, c.industry,
                  c.market_cap_bucket, c.active
           FROM companies c
           JOIN watchlist_members wm ON c.id = wm.company_id
           WHERE wm.watchlist_id = $1
           ORDER BY c.ticker""",
        watchlist_id,
    )
    return [dict(r) for r in rows]


# --- Source Endpoints ---


@app.post("/companies/{company_id}/sources", status_code=201)
async def add_source(company_id: str, body: SourceCreate):
    """Register a data source for a company and return a summary of the new row.

    Raises:
        HTTPException(400): ``source_type`` is not in ``VALID_SOURCE_TYPES``.
        HTTPException(404): the insert violates a foreign key (unknown company).
        HTTPException(409): the insert violates a unique constraint, if one exists.
    """
    if body.source_type not in VALID_SOURCE_TYPES:
        # Sorted so the message is identical across processes; set iteration
        # order of str elements depends on hash randomization.
        allowed = ", ".join(sorted(VALID_SOURCE_TYPES))
        raise HTTPException(400, f"Invalid source_type. Must be one of: {allowed}")
    try:
        row = await pool.fetchrow(
            """INSERT INTO sources (company_id, source_type, source_name, config,
                                    credibility_score, retention_days, access_policy)
               VALUES ($1, $2, $3, $4, $5, $6, $7)
               RETURNING id, source_type, source_name, credibility_score, active""",
            company_id,
            body.source_type,
            body.source_name,
            # NOTE(review): a plain dict is passed as-is; this presumably relies on a
            # JSON/JSONB type codec registered on the pool (e.g. in get_pg_pool) — confirm.
            body.config,
            body.credibility_score,
            body.retention_days,
            body.access_policy,
        )
    except asyncpg.ForeignKeyViolationError as exc:
        raise HTTPException(404, "Company not found") from exc
    except asyncpg.UniqueViolationError as exc:
        raise HTTPException(409, f"Source '{body.source_name}' already exists") from exc
    return dict(row)


@app.get("/companies/{company_id}/sources")
async def list_sources(company_id: str):
    """List every source registered for a company."""
    rows = await pool.fetch(
        "SELECT id, source_type, source_name, config, credibility_score, retention_days, "
        "access_policy, active FROM sources WHERE company_id = $1",
        company_id,
    )
    return [dict(r) for r in rows]