Files
stonks-oracle/services/symbol_registry/exposure.py
T

184 lines
6.9 KiB
Python

"""Exposure Profile management endpoints for the Symbol Registry API."""
import json
import uuid
from datetime import datetime
from typing import Any, List
import asyncpg
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field, field_validator
# Router on which this module's exposure-profile endpoints are registered.
router = APIRouter()
# --- Valid values ---
# Values accepted for ExposureProfileCreate.market_position_tier.
VALID_MARKET_POSITION_TIERS = {"global_leader", "multinational", "regional", "domestic"}
# Values accepted for ExposureProfileCreate.source.
VALID_SOURCES = {"manual", "inferred"}
# --- Request/Response Models ---
class ExposureProfileCreate(BaseModel):
    """Request body for creating/updating an exposure profile.

    Every field has a default, so an empty JSON object is a valid body.

    Validation rules:
        * ``market_position_tier`` must be a member of
          ``VALID_MARKET_POSITION_TIERS``.
        * ``source`` must be a member of ``VALID_SOURCES``.
        * ``export_dependency_pct`` and ``confidence`` must lie in the closed
          interval [0.0, 1.0] (fractions, not 0-100 percentages).
    """

    # NOTE(review): presumably region -> share of revenue; the values are not
    # range-checked here and are not required to sum to 1.
    geographic_revenue_mix: dict[str, float] = Field(default_factory=dict)
    supply_chain_regions: List[str] = Field(default_factory=list)
    key_input_commodities: List[str] = Field(default_factory=list)
    regulatory_jurisdictions: List[str] = Field(default_factory=list)
    market_position_tier: str = "regional"
    export_dependency_pct: float = 0.0
    source: str = "manual"
    confidence: float = 1.0

    @field_validator("market_position_tier")
    @classmethod
    def validate_tier(cls, v: str) -> str:
        """Reject tiers that are not in ``VALID_MARKET_POSITION_TIERS``."""
        if v not in VALID_MARKET_POSITION_TIERS:
            # Sort before interpolating: a set's iteration order for strings
            # varies between processes, which made the message unstable.
            allowed = sorted(VALID_MARKET_POSITION_TIERS)
            raise ValueError(f"market_position_tier must be one of {allowed}")
        return v

    @field_validator("source")
    @classmethod
    def validate_source(cls, v: str) -> str:
        """Reject sources that are not in ``VALID_SOURCES``."""
        if v not in VALID_SOURCES:
            # Sorted for a deterministic message (see validate_tier).
            allowed = sorted(VALID_SOURCES)
            raise ValueError(f"source must be one of {allowed}")
        return v

    @field_validator("export_dependency_pct", "confidence")
    @classmethod
    def validate_pct(cls, v: float) -> float:
        """Require the value to be within [0.0, 1.0], bounds included."""
        # NaN fails both comparisons, so it is rejected as well.
        if not 0.0 <= v <= 1.0:
            raise ValueError("Value must be between 0.0 and 1.0")
        return v
class ExposureProfileResponse(BaseModel):
    """Shape of one exposure profile version as returned by the API.

    Carries the same descriptive fields as the create/update request body,
    plus the identifiers, version number, active flag and timestamps read
    back from the ``exposure_profiles`` table.
    """

    # Identifiers are exposed as strings.
    id: str
    company_id: str

    # Descriptive fields supplied by the client.
    geographic_revenue_mix: dict[str, float]
    supply_chain_regions: list[str]
    key_input_commodities: list[str]
    regulatory_jurisdictions: list[str]
    market_position_tier: str
    export_dependency_pct: float
    source: str
    confidence: float

    # Versioning and bookkeeping columns.
    version: int
    active: bool
    created_at: datetime
    updated_at: datetime
def _row_to_profile(row: asyncpg.Record) -> dict[str, Any]:
    """Build a plain response dict from a database row.

    Args:
        row: A record (or any mapping-like object accepted by ``dict()``)
            holding the profile columns.

    Returns:
        A new dict with the same keys in the same order, where every
        ``uuid.UUID`` value is replaced by its string form and a string
        ``geographic_revenue_mix`` is decoded from JSON.
    """
    profile = {
        column: str(value) if isinstance(value, uuid.UUID) else value
        for column, value in dict(row).items()
    }

    # The JSONB column can come back as raw JSON text; decode it only then.
    revenue_mix = profile.get("geographic_revenue_mix")
    if isinstance(revenue_mix, str):
        profile["geographic_revenue_mix"] = json.loads(revenue_mix)

    return profile
def _get_pool(request: Request) -> asyncpg.Pool:
    """Get the database pool from the app module.

    Args:
        request: The incoming request; not used by the current body.

    Returns:
        The ``pool`` attribute of ``services.symbol_registry.app`` as it is
        at the moment of the call.
    """
    # Imported inside the function rather than at module top, presumably to
    # avoid a circular import with the app module (TODO confirm).
    from services.symbol_registry.app import pool
    return pool
# --- Endpoints ---
@router.get("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
async def get_exposure_profile(company_id: str, request: Request):
    """Return the active exposure profile of a company.

    When several rows are flagged active, the one with the highest version
    is returned.

    Raises:
        HTTPException: 404 when the company has no active profile row.
    """
    active_profile_sql = """SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
                  key_input_commodities, regulatory_jurisdictions, market_position_tier,
                  export_dependency_pct, source, confidence, version, active,
                  created_at, updated_at
           FROM exposure_profiles
           WHERE company_id = $1 AND active = TRUE
           ORDER BY version DESC
           LIMIT 1"""

    db = _get_pool(request)
    record = await db.fetchrow(active_profile_sql, company_id)
    if record:
        return _row_to_profile(record)

    raise HTTPException(
        status_code=404,
        detail="No active exposure profile found for this company",
    )
@router.put("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
async def upsert_exposure_profile(company_id: str, body: ExposureProfileCreate, request: Request):
    """Create or update an exposure profile. Archives the previous active version.

    Profiles are never edited in place: each call inserts a new row with the
    next version number and ``active = TRUE``, after marking any currently
    active rows of the company inactive. Both statements run in a single
    transaction.

    Args:
        company_id: Identifier of the company, taken from the URL path.
        body: Validated profile payload.
        request: The incoming request, used to obtain the database pool.

    Returns:
        The newly inserted profile row as a response dict.

    Raises:
        HTTPException: 404 when no row in ``companies`` has this id.
    """
    pool = _get_pool(request)

    # Verify company exists
    exists = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id)
    if not exists:
        raise HTTPException(404, "Company not found")

    async with pool.acquire() as conn, conn.transaction():
        # Number the new row after the highest version ever stored for the
        # company, active or not. Looking only at the active row would
        # restart at 1 (and collide with archived history) whenever a
        # company has archived versions but no active one.
        # NOTE(review): two concurrent PUTs for the same company can still
        # compute the same version; confirm whether the schema has a unique
        # constraint on (company_id, version) or add row locking.
        new_version = await conn.fetchval(
            """SELECT COALESCE(MAX(version), 0) + 1
               FROM exposure_profiles
               WHERE company_id = $1""",
            company_id,
        )

        # Archive whatever is currently active; a no-op for a first profile.
        await conn.execute(
            """UPDATE exposure_profiles
               SET active = FALSE, updated_at = NOW()
               WHERE company_id = $1 AND active = TRUE""",
            company_id,
        )

        # Insert new profile
        row = await conn.fetchrow(
            """INSERT INTO exposure_profiles
                   (company_id, geographic_revenue_mix, supply_chain_regions,
                    key_input_commodities, regulatory_jurisdictions, market_position_tier,
                    export_dependency_pct, source, confidence, version, active)
               VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, TRUE)
               RETURNING id, company_id, geographic_revenue_mix, supply_chain_regions,
                         key_input_commodities, regulatory_jurisdictions, market_position_tier,
                         export_dependency_pct, source, confidence, version, active,
                         created_at, updated_at""",
            company_id,
            # The JSONB parameter is sent as JSON text.
            json.dumps(body.geographic_revenue_mix),
            body.supply_chain_regions,
            body.key_input_commodities,
            body.regulatory_jurisdictions,
            body.market_position_tier,
            body.export_dependency_pct,
            body.source,
            body.confidence,
            new_version,
        )

    return _row_to_profile(row)
@router.get("/companies/{company_id}/exposure/history", response_model=List[ExposureProfileResponse])
async def get_exposure_history(company_id: str, request: Request):
    """List every stored exposure profile version of a company, newest first.

    Active and archived rows are both included. A company with no profiles
    yields an empty list rather than an error.
    """
    history_sql = """SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
                  key_input_commodities, regulatory_jurisdictions, market_position_tier,
                  export_dependency_pct, source, confidence, version, active,
                  created_at, updated_at
           FROM exposure_profiles
           WHERE company_id = $1
           ORDER BY version DESC"""

    records = await _get_pool(request).fetch(history_sql, company_id)
    return [_row_to_profile(record) for record in records]