184 lines
6.9 KiB
Python
184 lines
6.9 KiB
Python
"""Exposure Profile management endpoints for the Symbol Registry API."""
|
|
import json
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Any, List
|
|
|
|
import asyncpg
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
# Router on which the endpoint decorators in this module register their routes.
router = APIRouter()
|
|
|
|
# --- Valid values ---
|
|
# Values accepted for the ``market_position_tier`` field of a profile.
VALID_MARKET_POSITION_TIERS = {
    "global_leader",
    "multinational",
    "regional",
    "domestic",
}

# Values accepted for the ``source`` field of a profile.
VALID_SOURCES = {
    "manual",
    "inferred",
}
|
|
|
|
|
|
# --- Request/Response Models ---
|
|
|
|
class ExposureProfileCreate(BaseModel):
    """Payload accepted when an exposure profile is created or replaced.

    Every field has a default, so an empty JSON object is a valid body.
    ``market_position_tier`` and ``source`` are checked against the module's
    allowed-value sets; ``export_dependency_pct`` and ``confidence`` must lie
    in the closed interval [0.0, 1.0].
    """

    geographic_revenue_mix: dict[str, float] = Field(default_factory=dict)
    supply_chain_regions: list[str] = Field(default_factory=list)
    key_input_commodities: list[str] = Field(default_factory=list)
    regulatory_jurisdictions: list[str] = Field(default_factory=list)
    market_position_tier: str = "regional"
    export_dependency_pct: float = 0.0
    source: str = "manual"
    confidence: float = 1.0

    @field_validator("market_position_tier")
    @classmethod
    def validate_tier(cls, v: str) -> str:
        """Accept only tiers listed in ``VALID_MARKET_POSITION_TIERS``."""
        if v in VALID_MARKET_POSITION_TIERS:
            return v
        raise ValueError(f"market_position_tier must be one of {VALID_MARKET_POSITION_TIERS}")

    @field_validator("source")
    @classmethod
    def validate_source(cls, v: str) -> str:
        """Accept only sources listed in ``VALID_SOURCES``."""
        if v in VALID_SOURCES:
            return v
        raise ValueError(f"source must be one of {VALID_SOURCES}")

    @field_validator("export_dependency_pct", "confidence")
    @classmethod
    def validate_pct(cls, v: float) -> float:
        """Accept only fractions in the closed interval [0.0, 1.0]."""
        # The chained comparison is False for NaN, so NaN is rejected as well.
        if 0.0 <= v <= 1.0:
            return v
        raise ValueError("Value must be between 0.0 and 1.0")
|
|
|
|
|
|
class ExposureProfileResponse(BaseModel):
    """Shape of a single exposure profile as returned by the endpoints."""

    # Identifiers, serialized as strings.
    id: str
    company_id: str

    # Exposure attributes, mirroring the fields of the request body.
    geographic_revenue_mix: dict[str, float]
    supply_chain_regions: list[str]
    key_input_commodities: list[str]
    regulatory_jurisdictions: list[str]
    market_position_tier: str
    export_dependency_pct: float
    source: str
    confidence: float

    # Versioning state and row timestamps.
    version: int
    active: bool
    created_at: datetime
    updated_at: datetime
|
|
|
|
|
|
def _row_to_profile(row: asyncpg.Record) -> dict[str, Any]:
    """Build a plain profile dict from a database row.

    UUID values are rendered as strings, and ``geographic_revenue_mix`` is
    decoded from JSON when it arrives as text. The input row is not modified.
    """
    profile = {
        column: str(value) if isinstance(value, uuid.UUID) else value
        for column, value in dict(row).items()
    }

    # The revenue mix may come back as a JSON string; decode it only in that case.
    revenue_mix = profile.get("geographic_revenue_mix")
    if isinstance(revenue_mix, str):
        profile["geographic_revenue_mix"] = json.loads(revenue_mix)

    return profile
|
|
|
|
|
|
def _get_pool(request: Request) -> asyncpg.Pool:
    """Return the ``pool`` object exposed by ``services.symbol_registry.app``.

    The import runs at call time rather than at module import time, so the
    value is looked up on every call. ``request`` is accepted but not used.
    """
    from services.symbol_registry.app import pool as db_pool

    return db_pool
|
|
|
|
|
|
# --- Endpoints ---
|
|
|
|
@router.get("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
async def get_exposure_profile(company_id: str, request: Request):
    """Return the active exposure profile of a company.

    If more than one active row exists, the one with the highest version wins.

    Raises:
        HTTPException: 404 when the company has no active exposure profile.
    """
    active_profile_sql = """SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
                  key_input_commodities, regulatory_jurisdictions, market_position_tier,
                  export_dependency_pct, source, confidence, version, active,
                  created_at, updated_at
           FROM exposure_profiles
           WHERE company_id = $1 AND active = TRUE
           ORDER BY version DESC
           LIMIT 1"""

    db_pool = _get_pool(request)
    record = await db_pool.fetchrow(active_profile_sql, company_id)
    if record:
        return _row_to_profile(record)

    raise HTTPException(404, "No active exposure profile found for this company")
|
|
|
|
|
|
@router.put("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
async def upsert_exposure_profile(company_id: str, body: ExposureProfileCreate, request: Request):
    """Create a new active exposure profile version for a company.

    Any currently active profile of the company is archived (``active = FALSE``)
    and a new row holding ``body`` is inserted as the active one. The whole
    operation, including the company existence check, runs in one transaction.

    Args:
        company_id: Identifier of the company, taken from the URL path.
        body: Validated profile data to store.
        request: Incoming request, used to obtain the database pool.

    Returns:
        The newly inserted profile as a dict matching ``ExposureProfileResponse``.

    Raises:
        HTTPException: 404 when no company with ``company_id`` exists.
    """
    pool = _get_pool(request)

    async with pool.acquire() as conn:
        async with conn.transaction():
            # Check the company inside the transaction and lock its row, so that
            # concurrent upserts for the same company run one after another and
            # cannot both compute the same version or leave two active rows.
            # FOR NO KEY UPDATE conflicts with itself but does not block
            # foreign-key checks from inserts that reference the company.
            exists = await conn.fetchval(
                "SELECT 1 FROM companies WHERE id = $1 FOR NO KEY UPDATE",
                company_id,
            )
            if not exists:
                raise HTTPException(404, "Company not found")

            # Derive the next version from every stored row, not only the active
            # one, so numbering never restarts at 1 while archived rows exist.
            latest_version = await conn.fetchval(
                """SELECT COALESCE(MAX(version), 0)
                   FROM exposure_profiles
                   WHERE company_id = $1""",
                company_id,
            )
            new_version = latest_version + 1

            # Archive whatever is currently active; a no-op when nothing is.
            await conn.execute(
                """UPDATE exposure_profiles
                   SET active = FALSE, updated_at = NOW()
                   WHERE company_id = $1 AND active = TRUE""",
                company_id,
            )

            # Insert the new active profile and read back the stored row.
            row = await conn.fetchrow(
                """INSERT INTO exposure_profiles
                       (company_id, geographic_revenue_mix, supply_chain_regions,
                        key_input_commodities, regulatory_jurisdictions, market_position_tier,
                        export_dependency_pct, source, confidence, version, active)
                   VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, TRUE)
                   RETURNING id, company_id, geographic_revenue_mix, supply_chain_regions,
                             key_input_commodities, regulatory_jurisdictions, market_position_tier,
                             export_dependency_pct, source, confidence, version, active,
                             created_at, updated_at""",
                company_id,
                # Passed as JSON text for the geographic_revenue_mix column.
                json.dumps(body.geographic_revenue_mix),
                body.supply_chain_regions,
                body.key_input_commodities,
                body.regulatory_jurisdictions,
                body.market_position_tier,
                body.export_dependency_pct,
                body.source,
                body.confidence,
                new_version,
            )

    return _row_to_profile(row)
|
|
|
|
|
|
@router.get("/companies/{company_id}/exposure/history", response_model=List[ExposureProfileResponse])
async def get_exposure_history(company_id: str, request: Request):
    """Return every stored exposure profile version of a company, newest version first.

    The result is an empty list when no rows match ``company_id``.
    """
    history_sql = """SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
                  key_input_commodities, regulatory_jurisdictions, market_position_tier,
                  export_dependency_pct, source, confidence, version, active,
                  created_at, updated_at
           FROM exposure_profiles
           WHERE company_id = $1
           ORDER BY version DESC"""

    db_pool = _get_pool(request)
    records = await db_pool.fetch(history_sql, company_id)
    return list(map(_row_to_profile, records))
|