7394d241c9
- Enhanced CompanyCreate with ticker format validation (1-10 uppercase letters)
- Enhanced SourceCreate with pydantic validators for source_type, access_policy, config URLs
- Added /health endpoint to symbol registry
- Seed data: 10 companies (AAPL, MSFT, NVDA, AMZN, GOOGL, JPM, JNJ, XOM, TSLA, META)
- Seed sources: Alpha Vantage (market), NewsAPI (news), SEC EDGAR (filings), Alpaca (paper trading)
- Seed watchlist: 'Starter 10' with all companies and aliases
- Added flake.nix dev shell (nixos-25.11) with Python 3.12, ruff, pytest, kubectl, helm
- 30 passing tests, lint clean, Docker build verified
260 lines
8.5 KiB
Python
260 lines
8.5 KiB
Python
"""Symbol Registry API - FastAPI application."""
|
|
import re
|
|
from contextlib import asynccontextmanager
|
|
from typing import List, Optional
|
|
from urllib.parse import urlparse
|
|
|
|
import asyncpg
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel, field_validator
|
|
|
|
from services.shared.config import load_config
|
|
from services.shared.db import get_pg_pool
|
|
|
|
# Service configuration, loaded once when this module is imported.
config = load_config()
|
|
# Shared connection pool: None at import time, assigned by lifespan() on startup.
pool: Optional[asyncpg.Pool] = None
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the shared asyncpg pool for the lifetime of the application.

    On entry the module-level ``pool`` is created with ``get_pg_pool(config)``;
    on exit it is closed.

    Args:
        app: The FastAPI application instance (not used in the body).

    Yields:
        None. Control returns to the framework while the app is serving.
    """
    global pool
    pool = await get_pg_pool(config)
    try:
        yield
    finally:
        # Close the pool even when the context is unwound with an exception
        # (e.g. cancellation during shutdown). Without the ``finally`` the
        # exception is raised at the ``yield`` and the close is skipped.
        await pool.close()
|
|
|
|
|
|
# Application object; lifespan() is registered as its startup/shutdown handler.
app = FastAPI(title="Stonks Oracle - Symbol Registry", lifespan=lifespan)
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Health probe that runs ``SELECT 1`` through the shared pool.

    Returns:
        ``{"status": "ok"}`` when the query completes.

    Raises:
        HTTPException: 503 when the probe raises any ``Exception`` (this also
            covers ``pool`` still being ``None``).
    """
    try:
        await pool.fetchval("SELECT 1")
    except Exception:
        raise HTTPException(503, "Database unavailable")
    return {"status": "ok"}
|
|
|
|
# Matches a string of one to ten ASCII capital letters (with ``match``, the
# trailing ``$`` also tolerates a single final newline).
TICKER_PATTERN = re.compile(r"^[A-Z]{1,10}$")

# Values accepted for ``SourceCreate.source_type``.
VALID_SOURCE_TYPES = {
    "market_api",
    "news_api",
    "filings_api",
    "web_scrape",
    "broker",
}

# Values accepted for ``SourceCreate.access_policy``.
VALID_ACCESS_POLICIES = {
    "internal",
    "public",
    "restricted",
}
|
|
|
|
|
|
# --- Request/Response Models ---
|
|
|
|
class CompanyCreate(BaseModel):
    """Request body used when creating or replacing a company record."""

    # Normalised and checked by ``validate_ticker``.
    ticker: str
    legal_name: str
    # The remaining attributes are optional and default to ``None``.
    exchange: Optional[str] = None
    sector: Optional[str] = None
    industry: Optional[str] = None
    market_cap_bucket: Optional[str] = None

    @field_validator("ticker")
    @classmethod
    def validate_ticker(cls, v: str) -> str:
        """Upper-case and trim the ticker, then enforce ``TICKER_PATTERN``.

        Returns:
            The normalised ticker.

        Raises:
            ValueError: If the normalised value does not match the pattern.
        """
        normalised = v.upper().strip()
        if TICKER_PATTERN.match(normalised) is None:
            raise ValueError(f"Ticker must be 1-10 uppercase letters, got: {normalised}")
        return normalised
|
|
|
|
|
|
class CompanyResponse(BaseModel):
    """Company record in the shape returned by the company endpoints."""

    id: str
    ticker: str
    legal_name: str
    # No defaults are declared for the nullable attributes below.
    exchange: Optional[str]
    sector: Optional[str]
    industry: Optional[str]
    market_cap_bucket: Optional[str]
    active: bool
|
|
|
|
|
|
class AliasCreate(BaseModel):
    """Request body for attaching an alternative name to a company."""

    alias: str
    # Falls back to "brand" when the caller omits it.
    alias_type: str = "brand"
|
|
|
|
|
|
class WatchlistCreate(BaseModel):
    """Request body for creating a watchlist."""

    name: str
    # Optional free text; ``None`` when not supplied.
    description: Optional[str] = None
|
|
|
|
|
|
class SourceCreate(BaseModel):
    """Request body for attaching a data source to a company.

    ``source_type`` and ``access_policy`` are restricted to the module-level
    ``VALID_SOURCE_TYPES`` / ``VALID_ACCESS_POLICIES`` sets, and URL-like
    entries of ``config`` are sanity-checked.
    """

    source_type: str
    source_name: str
    config: dict = {}
    credibility_score: float = 0.5
    retention_days: int = 365
    access_policy: str = "internal"

    @field_validator("source_type")
    @classmethod
    def validate_source_type(cls, v: str) -> str:
        """Reject a ``source_type`` that is not in ``VALID_SOURCE_TYPES``.

        Raises:
            ValueError: If ``v`` is not one of the allowed values.
        """
        if v not in VALID_SOURCE_TYPES:
            # Sorted so the message is stable; the iteration order of a set of
            # strings can differ from one interpreter run to the next.
            raise ValueError(f"source_type must be one of {sorted(VALID_SOURCE_TYPES)}")
        return v

    @field_validator("access_policy")
    @classmethod
    def validate_access_policy(cls, v: str) -> str:
        """Reject an ``access_policy`` that is not in ``VALID_ACCESS_POLICIES``.

        Raises:
            ValueError: If ``v`` is not one of the allowed values.
        """
        if v not in VALID_ACCESS_POLICIES:
            # Sorted for a deterministic message (see validate_source_type).
            raise ValueError(f"access_policy must be one of {sorted(VALID_ACCESS_POLICIES)}")
        return v

    @field_validator("config")
    @classmethod
    def validate_config_urls(cls, v: dict) -> dict:
        """Sanity-check the URL-like entries of the config dict.

        Every non-empty ``base_url``, ``endpoint`` and ``url`` value is passed
        through ``urlparse``. Only ``base_url`` is additionally required to be
        an absolute HTTP(S) URL: an ``http``/``https`` scheme and a non-empty
        network location.

        Returns:
            The config dict, unmodified.

        Raises:
            ValueError: If ``base_url`` is not an absolute HTTP(S) URL, or if
                ``urlparse`` itself rejects one of the inspected values.
        """
        for key in ("base_url", "endpoint", "url"):
            raw = v.get(key)
            if not raw:
                continue
            parsed = urlparse(str(raw))
            if key != "base_url":
                # NOTE(review): presumably these may hold relative paths, so no
                # scheme/host is demanded of them; confirm with callers.
                continue
            # A scheme alone is not enough: "http://" and "https:foo" both have
            # an accepted scheme but no host to connect to.
            if parsed.scheme not in ("http", "https") or not parsed.netloc:
                raise ValueError(f"config.{key} must be a valid HTTP(S) URL")
        return v
|
|
|
|
|
|
# NOTE(review): redundant re-definition; an identical VALID_SOURCE_TYPES set is
# already assigned earlier in this module. Consider deleting this line.
VALID_SOURCE_TYPES = {"market_api", "news_api", "filings_api", "web_scrape", "broker"}
|
|
|
|
|
|
# --- Company Endpoints ---
|
|
|
|
@app.post("/companies", response_model=CompanyResponse, status_code=201)
async def create_company(body: CompanyCreate):
    """Insert a new row into ``companies`` and return it.

    Args:
        body: Validated company attributes.

    Returns:
        The inserted row (the columns named in ``RETURNING``) as a dict.

    Raises:
        HTTPException: 409 when the insert violates a unique constraint.
    """
    insert_sql = """INSERT INTO companies (ticker, legal_name, exchange, sector, industry, market_cap_bucket)
           VALUES ($1, $2, $3, $4, $5, $6)
           RETURNING id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active"""
    params = (
        body.ticker.upper(),
        body.legal_name,
        body.exchange,
        body.sector,
        body.industry,
        body.market_cap_bucket,
    )
    try:
        created = await pool.fetchrow(insert_sql, *params)
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, f"Company {body.ticker} on {body.exchange} already exists")
    return dict(created)
|
|
|
|
|
|
@app.get("/companies", response_model=List[CompanyResponse])
async def list_companies(active: bool = True):
    """Return the companies whose ``active`` column equals *active*.

    Args:
        active: Value to filter on; defaults to ``True``.

    Returns:
        A list of row dicts ordered by ticker.
    """
    query = (
        "SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active "
        "FROM companies WHERE active = $1 ORDER BY ticker"
    )
    records = await pool.fetch(query, active)
    return list(map(dict, records))
|
|
|
|
|
|
@app.get("/companies/{company_id}", response_model=CompanyResponse)
async def get_company(company_id: str):
    """Fetch a single company by id.

    Args:
        company_id: Value compared against ``companies.id``.

    Returns:
        The matching row as a dict.

    Raises:
        HTTPException: 404 when no row is returned.
    """
    query = (
        "SELECT id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active "
        "FROM companies WHERE id = $1"
    )
    record = await pool.fetchrow(query, company_id)
    if record:
        return dict(record)
    raise HTTPException(404, "Company not found")
|
|
|
|
|
|
@app.put("/companies/{company_id}", response_model=CompanyResponse)
async def update_company(company_id: str, body: CompanyCreate):
    """Overwrite the editable columns of an existing company.

    Args:
        company_id: Value compared against ``companies.id``.
        body: New values; every column listed in the UPDATE is replaced and
            ``updated_at`` is set to ``NOW()``.

    Returns:
        The updated row (the columns named in ``RETURNING``) as a dict.

    Raises:
        HTTPException: 404 when no row has the given id; 409 when the new
            values violate a unique constraint.
    """
    try:
        row = await pool.fetchrow(
            """UPDATE companies SET ticker=$2, legal_name=$3, exchange=$4, sector=$5, industry=$6, market_cap_bucket=$7, updated_at=NOW()
               WHERE id=$1
               RETURNING id, ticker, legal_name, exchange, sector, industry, market_cap_bucket, active""",
            company_id, body.ticker.upper(), body.legal_name, body.exchange,
            body.sector, body.industry, body.market_cap_bucket,
        )
    except asyncpg.UniqueViolationError:
        # Same mapping as create_company: colliding with another company's
        # unique key is a client-side conflict, not an internal server error.
        raise HTTPException(409, f"Company {body.ticker} on {body.exchange} already exists")
    if not row:
        raise HTTPException(404, "Company not found")
    return dict(row)
|
|
|
|
|
|
# --- Alias Endpoints ---
|
|
|
|
@app.post("/companies/{company_id}/aliases", status_code=201)
async def add_alias(company_id: str, body: AliasCreate):
    """Attach an alias to a company.

    Args:
        company_id: Stored in ``company_aliases.company_id``.
        body: Alias text and its type.

    Returns:
        The inserted row (``id``, ``alias``, ``alias_type``) as a dict.

    Raises:
        HTTPException: 404 when the insert violates a foreign-key constraint;
            409 when it violates a unique constraint.
    """
    try:
        row = await pool.fetchrow(
            "INSERT INTO company_aliases (company_id, alias, alias_type) VALUES ($1, $2, $3) RETURNING id, alias, alias_type",
            company_id, body.alias, body.alias_type,
        )
    except asyncpg.ForeignKeyViolationError:
        # Same mapping as add_watchlist_member. Assumes company_id references
        # companies(id) (TODO confirm); if it does not, this never fires.
        raise HTTPException(404, "Company not found")
    except asyncpg.UniqueViolationError:
        # Only reachable if the table declares a unique constraint; verify schema.
        raise HTTPException(409, f"Alias '{body.alias}' already exists")
    return dict(row)
|
|
|
|
|
|
@app.get("/companies/{company_id}/aliases")
async def list_aliases(company_id: str):
    """Return every alias row stored for a company.

    Args:
        company_id: Value compared against ``company_aliases.company_id``.

    Returns:
        A list of dicts with ``id``, ``alias`` and ``alias_type``; empty when
        nothing matches.
    """
    query = "SELECT id, alias, alias_type FROM company_aliases WHERE company_id = $1"
    records = await pool.fetch(query, company_id)
    return list(map(dict, records))
|
|
|
|
|
|
# --- Watchlist Endpoints ---
|
|
|
|
@app.post("/watchlists", status_code=201)
async def create_watchlist(body: WatchlistCreate):
    """Insert a new watchlist and return it.

    Args:
        body: Watchlist name and optional description.

    Returns:
        The inserted row (``id``, ``name``, ``description``, ``active``) as a dict.

    Raises:
        HTTPException: 409 when the insert violates a unique constraint.
    """
    insert_sql = "INSERT INTO watchlists (name, description) VALUES ($1, $2) RETURNING id, name, description, active"
    try:
        created = await pool.fetchrow(insert_sql, body.name, body.description)
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, f"Watchlist '{body.name}' already exists")
    return dict(created)
|
|
|
|
|
|
@app.get("/watchlists")
async def list_watchlists():
    """Return all watchlists as dicts, ordered by name."""
    query = "SELECT id, name, description, active FROM watchlists ORDER BY name"
    records = await pool.fetch(query)
    return list(map(dict, records))
|
|
|
|
|
|
@app.post("/watchlists/{watchlist_id}/members/{company_id}", status_code=201)
async def add_watchlist_member(watchlist_id: str, company_id: str):
    """Add a company to a watchlist.

    Args:
        watchlist_id: Stored in ``watchlist_members.watchlist_id``.
        company_id: Stored in ``watchlist_members.company_id``.

    Returns:
        ``{"status": "added"}`` once the row is inserted.

    Raises:
        HTTPException: 409 on a unique-constraint violation; 404 on a
            foreign-key violation.
    """
    insert_sql = "INSERT INTO watchlist_members (watchlist_id, company_id) VALUES ($1, $2)"
    try:
        await pool.execute(insert_sql, watchlist_id, company_id)
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, "Already a member")
    except asyncpg.ForeignKeyViolationError:
        raise HTTPException(404, "Watchlist or company not found")
    else:
        return {"status": "added"}
|
|
|
|
|
|
@app.get("/watchlists/{watchlist_id}/members")
async def list_watchlist_members(watchlist_id: str):
    """Return the companies that belong to a watchlist.

    Args:
        watchlist_id: Value compared against ``watchlist_members.watchlist_id``.

    Returns:
        A list of company row dicts ordered by ticker; empty when nothing
        matches.
    """
    query = """SELECT c.id, c.ticker, c.legal_name, c.exchange, c.sector, c.industry, c.market_cap_bucket, c.active
           FROM companies c JOIN watchlist_members wm ON c.id = wm.company_id
           WHERE wm.watchlist_id = $1 ORDER BY c.ticker"""
    members = await pool.fetch(query, watchlist_id)
    return list(map(dict, members))
|
|
|
|
|
|
# --- Source Endpoints ---
|
|
|
|
@app.post("/companies/{company_id}/sources", status_code=201)
async def add_source(company_id: str, body: SourceCreate):
    """Register a data source for an existing company.

    Args:
        company_id: Value compared against ``companies.id`` and stored in
            ``sources.company_id``.
        body: Validated source attributes.

    Returns:
        The inserted row (``id``, ``source_type``, ``source_name``,
        ``credibility_score``, ``active``) as a dict.

    Raises:
        HTTPException: 404 when the company does not exist.
    """
    # Verify company exists
    exists = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id)
    if not exists:
        raise HTTPException(404, "Company not found")
    try:
        # NOTE(review): body.config is passed as a dict; presumably get_pg_pool
        # registers a codec for the column type. Confirm against the pool setup.
        row = await pool.fetchrow(
            """INSERT INTO sources (company_id, source_type, source_name, config, credibility_score, retention_days, access_policy)
               VALUES ($1, $2, $3, $4, $5, $6, $7)
               RETURNING id, source_type, source_name, credibility_score, active""",
            company_id, body.source_type, body.source_name,
            body.config, body.credibility_score, body.retention_days, body.access_policy,
        )
    except asyncpg.ForeignKeyViolationError:
        # The company can vanish between the check above and the insert; report
        # that the same way as the up-front check rather than as a 500.
        raise HTTPException(404, "Company not found")
    return dict(row)
|
|
|
|
|
|
@app.get("/companies/{company_id}/sources")
async def list_sources(company_id: str):
    """Return every source row stored for a company.

    Args:
        company_id: Value compared against ``sources.company_id``.

    Returns:
        A list of row dicts; empty when nothing matches.
    """
    query = (
        "SELECT id, source_type, source_name, config, credibility_score, retention_days, access_policy, active "
        "FROM sources WHERE company_id = $1"
    )
    records = await pool.fetch(query, company_id)
    return list(map(dict, records))
|