Files
stonks-oracle/services/symbol_registry/app.py
T

262 lines
8.7 KiB
Python

"""Symbol Registry API - FastAPI application."""
import re
from contextlib import asynccontextmanager
from typing import List, Optional
from urllib.parse import urlparse
import asyncpg
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, field_validator
from services.shared.config import load_config
from services.shared.db import get_pg_pool
from services.shared.logging import setup_logging
# Service configuration, loaded once when the module is imported.
config = load_config()

# Shared asyncpg connection pool used by every endpoint. It stays None until
# `lifespan` has run its startup phase.
pool: Optional[asyncpg.Pool] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Own process-wide resources for the FastAPI application.

    Startup: configure logging and open the database pool via ``get_pg_pool``.
    Shutdown: close that pool.

    The ``yield`` sits inside ``try``/``finally`` so the pool is also closed
    when an exception is propagated into this context manager, rather than
    leaving its connections open.
    """
    global pool
    setup_logging("symbol_registry", level=config.log_level, json_output=config.json_logs)
    pool = await get_pg_pool(config)
    try:
        yield
    finally:
        await pool.close()


app = FastAPI(title="Stonks Oracle - Symbol Registry", lifespan=lifespan)
@app.get("/health")
async def health():
    """Health probe that succeeds only if the database answers ``SELECT 1``.

    Raises:
        HTTPException(503): on any failure while querying, including the
            pool not having been created yet (``pool`` is still None).
    """
    try:
        await pool.fetchval("SELECT 1")
    except Exception:
        raise HTTPException(503, "Database unavailable")
    return {"status": "ok"}
# Accepted ticker format: 1-10 ASCII uppercase letters. CompanyCreate
# upper-cases and strips the input before matching against this.
TICKER_PATTERN = re.compile(r"^[A-Z]{1,10}$")
# Allowed values for SourceCreate.source_type.
VALID_SOURCE_TYPES = {"market_api", "news_api", "filings_api", "web_scrape", "broker"}
# Allowed values for SourceCreate.access_policy.
VALID_ACCESS_POLICIES = {"internal", "public", "restricted"}
# --- Request/Response Models ---
class CompanyCreate(BaseModel):
    """Request body used both to create a company and to fully replace one."""

    ticker: str
    legal_name: str
    exchange: Optional[str] = None
    sector: Optional[str] = None
    industry: Optional[str] = None
    market_cap_bucket: Optional[str] = None

    @field_validator("ticker")
    @classmethod
    def validate_ticker(cls, v: str) -> str:
        """Upper-case and strip the ticker, then require it to match TICKER_PATTERN."""
        normalised = v.upper().strip()
        if TICKER_PATTERN.match(normalised) is None:
            raise ValueError(f"Ticker must be 1-10 uppercase letters, got: {normalised}")
        return normalised
class CompanyResponse(BaseModel):
    """Company row as returned by the /companies endpoints."""

    id: str
    ticker: str
    legal_name: str
    exchange: Optional[str]
    sector: Optional[str]
    industry: Optional[str]
    market_cap_bucket: Optional[str]
    active: bool

    @field_validator("id", mode="before")
    @classmethod
    def coerce_id_to_str(cls, v: object) -> object:
        """Stringify non-string ids before field validation.

        Pydantic v2 does not coerce values such as ``uuid.UUID`` or ``int``
        into a ``str`` field. If asyncpg decodes the ``id`` column to one of
        those types, response validation would fail without this hook.
        Strings pass through unchanged. None is left as-is, so it still fails
        validation exactly as before.
        """
        if v is None or isinstance(v, str):
            return v
        return str(v)
class AliasCreate(BaseModel):
    """Request body for attaching an alternative name to a company."""

    # The alternative name to record for the company.
    alias: str
    # Category of the alias; defaults to "brand" and is not validated here.
    # NOTE(review): confirm whether a fixed set of alias types is intended.
    alias_type: str = "brand"
class WatchlistCreate(BaseModel):
    """Request body for creating a named watchlist."""

    # Watchlist name; create_watchlist reports a unique-constraint hit on insert as 409.
    name: str
    # Optional free-text description.
    description: Optional[str] = None
class SourceCreate(BaseModel):
    """Request body for registering a data source against a company."""

    source_type: str
    source_name: str
    # Source-specific settings; "base_url", when present, is checked below.
    config: dict = {}
    credibility_score: float = 0.5
    retention_days: int = 365
    access_policy: str = "internal"

    @field_validator("source_type")
    @classmethod
    def validate_source_type(cls, v: str) -> str:
        """Reject any source_type not in VALID_SOURCE_TYPES."""
        if v not in VALID_SOURCE_TYPES:
            # sorted(): set iteration order varies between processes (string
            # hash randomisation), which made this message nondeterministic.
            raise ValueError(f"source_type must be one of {sorted(VALID_SOURCE_TYPES)}")
        return v

    @field_validator("access_policy")
    @classmethod
    def validate_access_policy(cls, v: str) -> str:
        """Reject any access_policy not in VALID_ACCESS_POLICIES."""
        if v not in VALID_ACCESS_POLICIES:
            raise ValueError(f"access_policy must be one of {sorted(VALID_ACCESS_POLICIES)}")
        return v

    @field_validator("config")
    @classmethod
    def validate_config_urls(cls, v: dict) -> dict:
        """Require a truthy ``config["base_url"]`` to be an absolute HTTP(S) URL.

        Raises:
            ValueError: base_url has a scheme other than http/https, or no host.
        """
        base_url = v.get("base_url")
        if base_url:
            parsed = urlparse(str(base_url))
            # A host is required too: scheme-only values such as "https://"
            # are not usable URLs even though their scheme is allowed.
            if parsed.scheme not in ("http", "https") or not parsed.netloc:
                raise ValueError("config.base_url must be a valid HTTP(S) URL")
        # NOTE(review): "endpoint" and "url" keys are not validated. "endpoint"
        # may legitimately be a relative path. Consider enforcing http(s) on
        # "url" if it is fetched by a scraper.
        return v
# --- Company Endpoints ---
@app.post("/companies", response_model=CompanyResponse, status_code=201)
async def create_company(body: CompanyCreate):
    """Insert a new company and return the stored row.

    Raises:
        HTTPException(409): the insert hit a unique constraint. The message
            names the ticker and, only when one was supplied, the exchange.
    """
    try:
        row = await pool.fetchrow(
            """INSERT INTO companies (ticker, legal_name, exchange, sector, industry, market_cap_bucket)
            VALUES ($1, $2, $3, $4, $5, $6)
            RETURNING id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active""",
            body.ticker.upper(), body.legal_name, body.exchange, body.sector,
            body.industry, body.market_cap_bucket,
        )
    except asyncpg.UniqueViolationError:
        # Avoid "Company X on None already exists" when no exchange was given.
        listing = f"{body.ticker} on {body.exchange}" if body.exchange else body.ticker
        raise HTTPException(409, f"Company {listing} already exists")
    return dict(row)
@app.get("/companies", response_model=List[CompanyResponse])
async def list_companies(active: bool = True):
    """List companies whose ``active`` flag equals *active*, sorted by ticker."""
    query = (
        "SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active "
        "FROM companies WHERE active = $1 ORDER BY ticker"
    )
    records = await pool.fetch(query, active)
    return list(map(dict, records))
@app.get("/companies/{company_id}", response_model=CompanyResponse)
async def get_company(company_id: str):
    """Fetch one company by id.

    Raises:
        HTTPException(404): no company row has ``company_id``.
    """
    record = await pool.fetchrow(
        "SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active FROM companies WHERE id = $1",
        company_id,
    )
    if record is None:
        raise HTTPException(404, "Company not found")
    return dict(record)
@app.put("/companies/{company_id}", response_model=CompanyResponse)
async def update_company(company_id: str, body: CompanyCreate):
    """Replace every editable field of a company and set ``updated_at`` to NOW().

    Raises:
        HTTPException(404): no company row has ``company_id``.
        HTTPException(409): the new values hit a unique constraint (e.g. the
            ticker/exchange pair is already used by another company). This
            mirrors the handling in create_company instead of surfacing an
            unhandled database error.
    """
    try:
        row = await pool.fetchrow(
            """UPDATE companies SET ticker=$2, legal_name=$3, exchange=$4, sector=$5, industry=$6, market_cap_bucket=$7, updated_at=NOW()
            WHERE id=$1
            RETURNING id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active""",
            company_id, body.ticker.upper(), body.legal_name, body.exchange,
            body.sector, body.industry, body.market_cap_bucket,
        )
    except asyncpg.UniqueViolationError:
        listing = f"{body.ticker} on {body.exchange}" if body.exchange else body.ticker
        raise HTTPException(409, f"Company {listing} already exists")
    if not row:
        raise HTTPException(404, "Company not found")
    return dict(row)
# --- Alias Endpoints ---
@app.post("/companies/{company_id}/aliases", status_code=201)
async def add_alias(company_id: str, body: AliasCreate):
    """Attach an alias to a company and return the stored alias row.

    Raises:
        HTTPException(404): the insert violated a foreign key, i.e. the
            company does not exist. This matches add_watchlist_member.
        HTTPException(409): the insert violated a unique constraint.
    """
    try:
        row = await pool.fetchrow(
            "INSERT INTO company_aliases (company_id, alias, alias_type) VALUES ($1, $2, $3) RETURNING id, alias, alias_type",
            company_id, body.alias, body.alias_type,
        )
    except asyncpg.ForeignKeyViolationError:
        # Only raised if company_aliases.company_id references companies.
        # NOTE(review): confirm the schema; without an FK, unknown ids are accepted.
        raise HTTPException(404, "Company not found")
    except asyncpg.UniqueViolationError:
        # Only raised if the schema declares a uniqueness constraint on aliases.
        raise HTTPException(409, f"Alias '{body.alias}' already exists")
    return dict(row)
@app.get("/companies/{company_id}/aliases")
async def list_aliases(company_id: str):
    """Return all aliases recorded for a company.

    An unknown company id yields an empty list, not a 404.
    """
    alias_rows = await pool.fetch(
        "SELECT id, alias, alias_type FROM company_aliases WHERE company_id = $1",
        company_id,
    )
    return list(map(dict, alias_rows))
# --- Watchlist Endpoints ---
@app.post("/watchlists", status_code=201)
async def create_watchlist(body: WatchlistCreate):
    """Create a watchlist and return the stored row.

    Raises:
        HTTPException(409): the insert hit a unique constraint (reported as
            a watchlist with that name already existing).
    """
    insert_sql = "INSERT INTO watchlists (name, description) VALUES ($1, $2) RETURNING id, name, description, active"
    try:
        created = await pool.fetchrow(insert_sql, body.name, body.description)
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, f"Watchlist '{body.name}' already exists")
    return dict(created)
@app.get("/watchlists")
async def list_watchlists():
    """Return every watchlist, active or not, sorted by name."""
    records = await pool.fetch("SELECT id, name, description, active FROM watchlists ORDER BY name")
    return list(map(dict, records))
@app.post("/watchlists/{watchlist_id}/members/{company_id}", status_code=201)
async def add_watchlist_member(watchlist_id: str, company_id: str):
    """Add a company to a watchlist.

    Raises:
        HTTPException(409): the insert hit a unique constraint ("Already a member").
        HTTPException(404): the insert hit a foreign key, i.e. the watchlist
            or the company does not exist.
    """
    membership_sql = "INSERT INTO watchlist_members (watchlist_id, company_id) VALUES ($1, $2)"
    try:
        await pool.execute(membership_sql, watchlist_id, company_id)
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, "Already a member")
    except asyncpg.ForeignKeyViolationError:
        raise HTTPException(404, "Watchlist or company not found")
    else:
        return {"status": "added"}
@app.get("/watchlists/{watchlist_id}/members")
async def list_watchlist_members(watchlist_id: str):
    """Return the companies belonging to a watchlist, sorted by ticker.

    An unknown watchlist id yields an empty list rather than a 404.
    """
    members = await pool.fetch(
        """SELECT c.id, c.ticker, c.legal_name, c.exchange, c.sector, c.industry, c.market_cap_bucket, c.active
        FROM companies c JOIN watchlist_members wm ON c.id = wm.company_id
        WHERE wm.watchlist_id = $1 ORDER BY c.ticker""",
        watchlist_id,
    )
    return [dict(member) for member in members]
# --- Source Endpoints ---
@app.post("/companies/{company_id}/sources", status_code=201)
async def add_source(company_id: str, body: SourceCreate):
    """Register a data source for an existing company and return a summary row.

    Raises:
        HTTPException(404): no company row has ``company_id``.
    """
    # Check up front so a missing company is reported as a 404.
    company_found = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id)
    if not company_found:
        raise HTTPException(404, "Company not found")

    # NOTE(review): body.config is a dict. asyncpg accepts a dict for a
    # json/jsonb parameter only if a type codec is registered on the
    # connection. Presumably get_pg_pool does this; confirm.
    insert_args = (
        company_id,
        body.source_type,
        body.source_name,
        body.config,
        body.credibility_score,
        body.retention_days,
        body.access_policy,
    )
    created = await pool.fetchrow(
        """INSERT INTO sources (company_id, source_type, source_name, config, credibility_score, retention_days, access_policy)
        VALUES ($1, $2, $3, $4, $5, $6, $7)
        RETURNING id, source_type, source_name, credibility_score, active""",
        *insert_args,
    )
    return dict(created)
@app.get("/companies/{company_id}/sources")
async def list_sources(company_id: str):
    """Return every source registered for a company, including inactive ones.

    An unknown company id yields an empty list rather than a 404.
    """
    source_rows = await pool.fetch(
        "SELECT id, source_type, source_name, config, credibility_score, retention_days, access_policy, active FROM sources WHERE company_id = $1",
        company_id,
    )
    return list(map(dict, source_rows))