"""Symbol Registry API - FastAPI application.""" import re from contextlib import asynccontextmanager from typing import List, Optional from urllib.parse import urlparse import asyncpg from fastapi import FastAPI, HTTPException from pydantic import BaseModel, field_validator from services.shared.config import load_config from services.shared.db import get_pg_pool config = load_config() pool: Optional[asyncpg.Pool] = None @asynccontextmanager async def lifespan(app: FastAPI): global pool pool = await get_pg_pool(config) yield await pool.close() app = FastAPI(title="Stonks Oracle - Symbol Registry", lifespan=lifespan) @app.get("/health") async def health(): try: await pool.fetchval("SELECT 1") return {"status": "ok"} except Exception: raise HTTPException(503, "Database unavailable") TICKER_PATTERN = re.compile(r"^[A-Z]{1,10}$") VALID_SOURCE_TYPES = {"market_api", "news_api", "filings_api", "web_scrape", "broker"} VALID_ACCESS_POLICIES = {"internal", "public", "restricted"} # --- Request/Response Models --- class CompanyCreate(BaseModel): ticker: str legal_name: str exchange: Optional[str] = None sector: Optional[str] = None industry: Optional[str] = None market_cap_bucket: Optional[str] = None @field_validator("ticker") @classmethod def validate_ticker(cls, v: str) -> str: v = v.upper().strip() if not TICKER_PATTERN.match(v): raise ValueError(f"Ticker must be 1-10 uppercase letters, got: {v}") return v class CompanyResponse(BaseModel): id: str ticker: str legal_name: str exchange: Optional[str] sector: Optional[str] industry: Optional[str] market_cap_bucket: Optional[str] active: bool class AliasCreate(BaseModel): alias: str alias_type: str = "brand" class WatchlistCreate(BaseModel): name: str description: Optional[str] = None class SourceCreate(BaseModel): source_type: str source_name: str config: dict = {} credibility_score: float = 0.5 retention_days: int = 365 access_policy: str = "internal" @field_validator("source_type") @classmethod def validate_source_type(cls, v: str) -> str: if v not in VALID_SOURCE_TYPES: raise ValueError(f"source_type must be one of {VALID_SOURCE_TYPES}") return v @field_validator("access_policy") @classmethod def validate_access_policy(cls, v: str) -> str: if v not in VALID_ACCESS_POLICIES: raise ValueError(f"access_policy must be one of {VALID_ACCESS_POLICIES}") return v @field_validator("config") @classmethod def validate_config_urls(cls, v: dict) -> dict: """Validate any URL fields in the config dict.""" for key in ("base_url", "endpoint", "url"): if key in v and v[key]: parsed = urlparse(str(v[key])) if key == "base_url" and parsed.scheme not in ("http", "https"): raise ValueError(f"config.{key} must be a valid HTTP(S) URL") return v VALID_SOURCE_TYPES = {"market_api", "news_api", "filings_api", "web_scrape", "broker"} # --- Company Endpoints --- @app.post("/companies", response_model=CompanyResponse, status_code=201) async def create_company(body: CompanyCreate): try: row = await pool.fetchrow( """INSERT INTO companies (ticker, legal_name, exchange, sector, industry, market_cap_bucket) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active""", body.ticker.upper(), body.legal_name, body.exchange, body.sector, body.industry, body.market_cap_bucket, ) except asyncpg.UniqueViolationError: raise HTTPException(409, f"Company {body.ticker} on {body.exchange} already exists") return dict(row) @app.get("/companies", response_model=List[CompanyResponse]) async def list_companies(active: bool = True): rows = await pool.fetch( "SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active FROM companies WHERE active = $1 ORDER BY ticker", active, ) return [dict(r) for r in rows] @app.get("/companies/{company_id}", response_model=CompanyResponse) async def get_company(company_id: str): row = await pool.fetchrow( "SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active FROM companies WHERE id = $1", company_id, ) if not row: raise HTTPException(404, "Company not found") return dict(row) @app.put("/companies/{company_id}", response_model=CompanyResponse) async def update_company(company_id: str, body: CompanyCreate): row = await pool.fetchrow( """UPDATE companies SET ticker=$2, legal_name=$3, exchange=$4, sector=$5, industry=$6, market_cap_bucket=$7, updated_at=NOW() WHERE id=$1 RETURNING id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active""", company_id, body.ticker.upper(), body.legal_name, body.exchange, body.sector, body.industry, body.market_cap_bucket, ) if not row: raise HTTPException(404, "Company not found") return dict(row) # --- Alias Endpoints --- @app.post("/companies/{company_id}/aliases", status_code=201) async def add_alias(company_id: str, body: AliasCreate): row = await pool.fetchrow( "INSERT INTO company_aliases (company_id, alias, alias_type) VALUES ($1, $2, $3) RETURNING id, alias, alias_type", company_id, body.alias, body.alias_type, ) return dict(row) @app.get("/companies/{company_id}/aliases") async def list_aliases(company_id: str): rows = await pool.fetch( "SELECT id, alias, alias_type FROM company_aliases WHERE company_id = $1", company_id, ) return [dict(r) for r in rows] # --- Watchlist Endpoints --- @app.post("/watchlists", status_code=201) async def create_watchlist(body: WatchlistCreate): try: row = await pool.fetchrow( "INSERT INTO watchlists (name, description) VALUES ($1, $2) RETURNING id, name, description, active", body.name, body.description, ) except asyncpg.UniqueViolationError: raise HTTPException(409, f"Watchlist '{body.name}' already exists") return dict(row) @app.get("/watchlists") async def list_watchlists(): rows = await pool.fetch("SELECT id, name, description, active FROM watchlists ORDER BY name") return [dict(r) for r in rows] @app.post("/watchlists/{watchlist_id}/members/{company_id}", status_code=201) async def add_watchlist_member(watchlist_id: str, company_id: str): try: await pool.execute( "INSERT INTO watchlist_members (watchlist_id, company_id) VALUES ($1, $2)", watchlist_id, company_id, ) except asyncpg.UniqueViolationError: raise HTTPException(409, "Already a member") except asyncpg.ForeignKeyViolationError: raise HTTPException(404, "Watchlist or company not found") return {"status": "added"} @app.get("/watchlists/{watchlist_id}/members") async def list_watchlist_members(watchlist_id: str): rows = await pool.fetch( """SELECT c.id, c.ticker, c.legal_name, c.exchange, c.sector, c.industry, c.market_cap_bucket, c.active FROM companies c JOIN watchlist_members wm ON c.id = wm.company_id WHERE wm.watchlist_id = $1 ORDER BY c.ticker""", watchlist_id, ) return [dict(r) for r in rows] # --- Source Endpoints --- @app.post("/companies/{company_id}/sources", status_code=201) async def add_source(company_id: str, body: SourceCreate): # Verify company exists exists = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id) if not exists: raise HTTPException(404, "Company not found") row = await pool.fetchrow( """INSERT INTO sources (company_id, source_type, source_name, config, credibility_score, retention_days, access_policy) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id, source_type, source_name, credibility_score, active""", company_id, body.source_type, body.source_name, body.config, body.credibility_score, body.retention_days, body.access_policy, ) return dict(row) @app.get("/companies/{company_id}/sources") async def list_sources(company_id: str): rows = await pool.fetch( "SELECT id, source_type, source_name, config, credibility_score, retention_days, access_policy, active FROM sources WHERE company_id = $1", company_id, ) return [dict(r) for r in rows]