feat: competitive intelligence & historical pattern matching layer
This commit is contained in:
@@ -0,0 +1,183 @@
|
||||
"""Exposure Profile management endpoints for the Symbol Registry API."""
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
|
||||
import asyncpg
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# --- Valid values ---
|
||||
VALID_MARKET_POSITION_TIERS = {"global_leader", "multinational", "regional", "domestic"}
|
||||
VALID_SOURCES = {"manual", "inferred"}
|
||||
|
||||
|
||||
# --- Request/Response Models ---
|
||||
|
||||
class ExposureProfileCreate(BaseModel):
|
||||
"""Request body for creating/updating an exposure profile."""
|
||||
geographic_revenue_mix: dict[str, float] = Field(default_factory=dict)
|
||||
supply_chain_regions: List[str] = Field(default_factory=list)
|
||||
key_input_commodities: List[str] = Field(default_factory=list)
|
||||
regulatory_jurisdictions: List[str] = Field(default_factory=list)
|
||||
market_position_tier: str = "regional"
|
||||
export_dependency_pct: float = 0.0
|
||||
source: str = "manual"
|
||||
confidence: float = 1.0
|
||||
|
||||
@field_validator("market_position_tier")
|
||||
@classmethod
|
||||
def validate_tier(cls, v: str) -> str:
|
||||
if v not in VALID_MARKET_POSITION_TIERS:
|
||||
raise ValueError(f"market_position_tier must be one of {VALID_MARKET_POSITION_TIERS}")
|
||||
return v
|
||||
|
||||
@field_validator("source")
|
||||
@classmethod
|
||||
def validate_source(cls, v: str) -> str:
|
||||
if v not in VALID_SOURCES:
|
||||
raise ValueError(f"source must be one of {VALID_SOURCES}")
|
||||
return v
|
||||
|
||||
@field_validator("export_dependency_pct", "confidence")
|
||||
@classmethod
|
||||
def validate_pct(cls, v: float) -> float:
|
||||
if not 0.0 <= v <= 1.0:
|
||||
raise ValueError("Value must be between 0.0 and 1.0")
|
||||
return v
|
||||
|
||||
|
||||
class ExposureProfileResponse(BaseModel):
|
||||
"""Response model for an exposure profile."""
|
||||
id: str
|
||||
company_id: str
|
||||
geographic_revenue_mix: dict[str, float]
|
||||
supply_chain_regions: List[str]
|
||||
key_input_commodities: List[str]
|
||||
regulatory_jurisdictions: List[str]
|
||||
market_position_tier: str
|
||||
export_dependency_pct: float
|
||||
source: str
|
||||
confidence: float
|
||||
version: int
|
||||
active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
def _row_to_profile(row: asyncpg.Record) -> dict[str, Any]:
|
||||
"""Convert an asyncpg Record to a profile response dict."""
|
||||
d = dict(row)
|
||||
for k, v in d.items():
|
||||
if isinstance(v, uuid.UUID):
|
||||
d[k] = str(v)
|
||||
# geographic_revenue_mix is stored as JSONB string, parse if needed
|
||||
if isinstance(d.get("geographic_revenue_mix"), str):
|
||||
d["geographic_revenue_mix"] = json.loads(d["geographic_revenue_mix"])
|
||||
return d
|
||||
|
||||
|
||||
def _get_pool(request: Request) -> asyncpg.Pool:
|
||||
"""Get the database pool from the app module."""
|
||||
from services.symbol_registry.app import pool
|
||||
return pool
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
@router.get("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
|
||||
async def get_exposure_profile(company_id: str, request: Request):
|
||||
"""Get the current active exposure profile for a company."""
|
||||
pool = _get_pool(request)
|
||||
row = await pool.fetchrow(
|
||||
"""SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
|
||||
key_input_commodities, regulatory_jurisdictions, market_position_tier,
|
||||
export_dependency_pct, source, confidence, version, active,
|
||||
created_at, updated_at
|
||||
FROM exposure_profiles
|
||||
WHERE company_id = $1 AND active = TRUE
|
||||
ORDER BY version DESC
|
||||
LIMIT 1""",
|
||||
company_id,
|
||||
)
|
||||
if not row:
|
||||
raise HTTPException(404, "No active exposure profile found for this company")
|
||||
return _row_to_profile(row)
|
||||
|
||||
|
||||
@router.put("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
|
||||
async def upsert_exposure_profile(company_id: str, body: ExposureProfileCreate, request: Request):
|
||||
"""Create or update an exposure profile. Archives the previous active version."""
|
||||
pool = _get_pool(request)
|
||||
|
||||
# Verify company exists
|
||||
exists = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id)
|
||||
if not exists:
|
||||
raise HTTPException(404, "Company not found")
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
async with conn.transaction():
|
||||
# Fetch current active profile to get latest version
|
||||
current = await conn.fetchrow(
|
||||
"""SELECT version FROM exposure_profiles
|
||||
WHERE company_id = $1 AND active = TRUE
|
||||
ORDER BY version DESC LIMIT 1""",
|
||||
company_id,
|
||||
)
|
||||
|
||||
if current:
|
||||
new_version = current["version"] + 1
|
||||
# Archive the current active profile
|
||||
await conn.execute(
|
||||
"""UPDATE exposure_profiles
|
||||
SET active = FALSE, updated_at = NOW()
|
||||
WHERE company_id = $1 AND active = TRUE""",
|
||||
company_id,
|
||||
)
|
||||
else:
|
||||
new_version = 1
|
||||
|
||||
# Insert new profile
|
||||
row = await conn.fetchrow(
|
||||
"""INSERT INTO exposure_profiles
|
||||
(company_id, geographic_revenue_mix, supply_chain_regions,
|
||||
key_input_commodities, regulatory_jurisdictions, market_position_tier,
|
||||
export_dependency_pct, source, confidence, version, active)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, TRUE)
|
||||
RETURNING id, company_id, geographic_revenue_mix, supply_chain_regions,
|
||||
key_input_commodities, regulatory_jurisdictions, market_position_tier,
|
||||
export_dependency_pct, source, confidence, version, active,
|
||||
created_at, updated_at""",
|
||||
company_id,
|
||||
json.dumps(body.geographic_revenue_mix),
|
||||
body.supply_chain_regions,
|
||||
body.key_input_commodities,
|
||||
body.regulatory_jurisdictions,
|
||||
body.market_position_tier,
|
||||
body.export_dependency_pct,
|
||||
body.source,
|
||||
body.confidence,
|
||||
new_version,
|
||||
)
|
||||
|
||||
return _row_to_profile(row)
|
||||
|
||||
|
||||
@router.get("/companies/{company_id}/exposure/history", response_model=List[ExposureProfileResponse])
|
||||
async def get_exposure_history(company_id: str, request: Request):
|
||||
"""Get all exposure profile versions for a company, ordered by version descending."""
|
||||
pool = _get_pool(request)
|
||||
rows = await pool.fetch(
|
||||
"""SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
|
||||
key_input_commodities, regulatory_jurisdictions, market_position_tier,
|
||||
export_dependency_pct, source, confidence, version, active,
|
||||
created_at, updated_at
|
||||
FROM exposure_profiles
|
||||
WHERE company_id = $1
|
||||
ORDER BY version DESC""",
|
||||
company_id,
|
||||
)
|
||||
return [_row_to_profile(r) for r in rows]
|
||||
Reference in New Issue
Block a user