"""Exposure Profile management endpoints for the Symbol Registry API.""" import json import uuid from datetime import datetime from typing import Any, List import asyncpg from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel, Field, field_validator router = APIRouter() # --- Valid values --- VALID_MARKET_POSITION_TIERS = {"global_leader", "multinational", "regional", "domestic"} VALID_SOURCES = {"manual", "inferred"} # --- Request/Response Models --- class ExposureProfileCreate(BaseModel): """Request body for creating/updating an exposure profile.""" geographic_revenue_mix: dict[str, float] = Field(default_factory=dict) supply_chain_regions: List[str] = Field(default_factory=list) key_input_commodities: List[str] = Field(default_factory=list) regulatory_jurisdictions: List[str] = Field(default_factory=list) market_position_tier: str = "regional" export_dependency_pct: float = 0.0 source: str = "manual" confidence: float = 1.0 @field_validator("market_position_tier") @classmethod def validate_tier(cls, v: str) -> str: if v not in VALID_MARKET_POSITION_TIERS: raise ValueError(f"market_position_tier must be one of {VALID_MARKET_POSITION_TIERS}") return v @field_validator("source") @classmethod def validate_source(cls, v: str) -> str: if v not in VALID_SOURCES: raise ValueError(f"source must be one of {VALID_SOURCES}") return v @field_validator("export_dependency_pct", "confidence") @classmethod def validate_pct(cls, v: float) -> float: if not 0.0 <= v <= 1.0: raise ValueError("Value must be between 0.0 and 1.0") return v class ExposureProfileResponse(BaseModel): """Response model for an exposure profile.""" id: str company_id: str geographic_revenue_mix: dict[str, float] supply_chain_regions: List[str] key_input_commodities: List[str] regulatory_jurisdictions: List[str] market_position_tier: str export_dependency_pct: float source: str confidence: float version: int active: bool created_at: datetime updated_at: datetime def _row_to_profile(row: asyncpg.Record) -> dict[str, Any]: """Convert an asyncpg Record to a profile response dict.""" d = dict(row) for k, v in d.items(): if isinstance(v, uuid.UUID): d[k] = str(v) # geographic_revenue_mix is stored as JSONB string, parse if needed if isinstance(d.get("geographic_revenue_mix"), str): d["geographic_revenue_mix"] = json.loads(d["geographic_revenue_mix"]) return d def _get_pool(request: Request) -> asyncpg.Pool: """Get the database pool from the app module.""" from services.symbol_registry.app import pool return pool # --- Endpoints --- @router.get("/companies/{company_id}/exposure", response_model=ExposureProfileResponse) async def get_exposure_profile(company_id: str, request: Request): """Get the current active exposure profile for a company.""" pool = _get_pool(request) row = await pool.fetchrow( """SELECT id, company_id, geographic_revenue_mix, supply_chain_regions, key_input_commodities, regulatory_jurisdictions, market_position_tier, export_dependency_pct, source, confidence, version, active, created_at, updated_at FROM exposure_profiles WHERE company_id = $1 AND active = TRUE ORDER BY version DESC LIMIT 1""", company_id, ) if not row: raise HTTPException(404, "No active exposure profile found for this company") return _row_to_profile(row) @router.put("/companies/{company_id}/exposure", response_model=ExposureProfileResponse) async def upsert_exposure_profile(company_id: str, body: ExposureProfileCreate, request: Request): """Create or update an exposure profile. Archives the previous active version.""" pool = _get_pool(request) # Verify company exists exists = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id) if not exists: raise HTTPException(404, "Company not found") async with pool.acquire() as conn: async with conn.transaction(): # Fetch current active profile to get latest version current = await conn.fetchrow( """SELECT version FROM exposure_profiles WHERE company_id = $1 AND active = TRUE ORDER BY version DESC LIMIT 1""", company_id, ) if current: new_version = current["version"] + 1 # Archive the current active profile await conn.execute( """UPDATE exposure_profiles SET active = FALSE, updated_at = NOW() WHERE company_id = $1 AND active = TRUE""", company_id, ) else: new_version = 1 # Insert new profile row = await conn.fetchrow( """INSERT INTO exposure_profiles (company_id, geographic_revenue_mix, supply_chain_regions, key_input_commodities, regulatory_jurisdictions, market_position_tier, export_dependency_pct, source, confidence, version, active) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, TRUE) RETURNING id, company_id, geographic_revenue_mix, supply_chain_regions, key_input_commodities, regulatory_jurisdictions, market_position_tier, export_dependency_pct, source, confidence, version, active, created_at, updated_at""", company_id, json.dumps(body.geographic_revenue_mix), body.supply_chain_regions, body.key_input_commodities, body.regulatory_jurisdictions, body.market_position_tier, body.export_dependency_pct, body.source, body.confidence, new_version, ) return _row_to_profile(row) @router.get("/companies/{company_id}/exposure/history", response_model=List[ExposureProfileResponse]) async def get_exposure_history(company_id: str, request: Request): """Get all exposure profile versions for a company, ordered by version descending.""" pool = _get_pool(request) rows = await pool.fetch( """SELECT id, company_id, geographic_revenue_mix, supply_chain_regions, key_input_commodities, regulatory_jurisdictions, market_position_tier, export_dependency_pct, source, confidence, version, active, created_at, updated_at FROM exposure_profiles WHERE company_id = $1 ORDER BY version DESC""", company_id, ) return [_row_to_profile(r) for r in rows]