227 lines
8.0 KiB
Python
227 lines
8.0 KiB
Python
"""Competitor Relationship management endpoints for the Symbol Registry API."""
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Any, List
|
|
|
|
import asyncpg
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
from services.shared.audit import record_audit_event
|
|
|
|
router = APIRouter()

# --- Valid values ---

# Allowed values for CompetitorRelationshipCreate.relationship_type.
VALID_RELATIONSHIP_TYPES = {
    "direct_rival",
    "same_sector",
    "overlapping_products",
    "supply_chain_adjacent",
}

# Allowed values for CompetitorRelationshipCreate.source.
VALID_SOURCES = {"manual", "inferred"}
|
|
|
|
|
|
# --- Request/Response Models ---
|
|
|
|
class CompetitorRelationshipCreate(BaseModel):
    """Request body for creating a competitor relationship.

    ``relationship_type`` must be a member of ``VALID_RELATIONSHIP_TYPES``.
    ``source`` must be a member of ``VALID_SOURCES``. ``strength`` must lie in
    the closed interval [0, 1].
    """

    # The second company in the relationship; the first comes from the URL path.
    # NOTE(review): the PUT endpoint also accepts this model but does not apply
    # company_b_id — confirm that ignoring it on update is intended.
    company_b_id: str
    # Validated against VALID_RELATIONSHIP_TYPES below.
    relationship_type: str
    # Relationship weight; Field enforces 0 <= strength <= 1.
    strength: float = Field(default=0.5, ge=0, le=1)
    bidirectional: bool = True
    # Validated against VALID_SOURCES below.
    source: str = "manual"

    @field_validator("relationship_type")
    @classmethod
    def validate_relationship_type(cls, v: str) -> str:
        """Reject any relationship_type not in VALID_RELATIONSHIP_TYPES."""
        if v not in VALID_RELATIONSHIP_TYPES:
            # sorted() gives a stable message; a set's repr order depends on
            # per-process string hash randomization.
            raise ValueError(f"relationship_type must be one of {sorted(VALID_RELATIONSHIP_TYPES)}")
        return v

    @field_validator("source")
    @classmethod
    def validate_source(cls, v: str) -> str:
        """Reject any source not in VALID_SOURCES."""
        if v not in VALID_SOURCES:
            raise ValueError(f"source must be one of {sorted(VALID_SOURCES)}")
        return v
|
|
|
|
|
|
class CompetitorRelationship(BaseModel):
    """Response model for a competitor relationship.

    The fields match the columns selected/returned by this module's endpoints.
    Those rows pass through ``_row_dict`` first, which converts ``uuid.UUID``
    values to ``str``; that is why the id fields are typed ``str``.
    """

    # Relationship id.
    id: str
    # Company the relationship was created from (the path company on POST).
    company_a_id: str
    # The other company (``company_b_id`` from the POST body).
    company_b_id: str
    # A member of VALID_RELATIONSHIP_TYPES when written through this API.
    relationship_type: str
    # In [0, 1] when written through this API; list ordering uses it.
    strength: float
    bidirectional: bool
    # A member of VALID_SOURCES when written through this API.
    source: str
    # Set to FALSE by the soft-delete endpoint; the list endpoint filters on it.
    active: bool
    created_at: datetime
    # Set to NOW() by the update and delete endpoints.
    updated_at: datetime
|
|
|
|
|
|
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
    """Return a plain-dict copy of *row* with every ``uuid.UUID`` value stringified.

    Non-UUID values are passed through unchanged.
    """
    return {
        column: str(value) if isinstance(value, uuid.UUID) else value
        for column, value in dict(row).items()
    }
|
|
|
|
|
|
def _get_pool(request: Request) -> asyncpg.Pool:
    """Return the ``pool`` object exposed by ``services.symbol_registry.app``.

    ``request`` is not used by this implementation.
    """
    # The import runs on every call, so the module attribute's current value is
    # returned. Deferring it is presumably also to avoid a circular import with
    # the app module — TODO confirm.
    from services.symbol_registry.app import pool as shared_pool

    return shared_pool
|
|
|
|
|
|
async def _company_exists(pool: asyncpg.Pool, company_id: str) -> bool:
    """Report whether a ``companies`` row whose ``id`` equals *company_id* exists."""
    match = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id)
    # fetchval yields None when the query returns no row.
    return match is not None
|
|
|
|
|
|
# --- Endpoints ---
|
|
|
|
@router.post("/companies/{company_id}/competitors", response_model=CompetitorRelationship, status_code=201)
async def create_competitor(company_id: str, body: CompetitorRelationshipCreate, request: Request):
    """Create a competitor relationship from ``company_id`` to ``body.company_b_id``.

    The inserted row is returned with UUIDs converted to strings. A
    ``competitor_relationship.created`` audit event is recorded.

    Raises:
        HTTPException(400): ``company_id`` equals ``body.company_b_id``.
        HTTPException(404): either company does not exist. This is checked up
            front, and also reported if a foreign-key violation surfaces at
            INSERT time.
        HTTPException(409): the INSERT violated a unique constraint.
    """
    pool = _get_pool(request)

    # Self-referencing check (plain string comparison of the ids as given).
    if company_id == body.company_b_id:
        raise HTTPException(400, "A company cannot be its own competitor")

    # Check both companies exist so the caller learns which one is missing.
    if not await _company_exists(pool, company_id):
        raise HTTPException(404, "Company not found")
    if not await _company_exists(pool, body.company_b_id):
        raise HTTPException(404, "Competitor company not found")

    try:
        row = await pool.fetchrow(
            """INSERT INTO competitor_relationships
                   (company_a_id, company_b_id, relationship_type, strength, bidirectional, source)
               VALUES ($1, $2, $3, $4, $5, $6)
               RETURNING id, company_a_id, company_b_id, relationship_type, strength,
                         bidirectional, source, active, created_at, updated_at""",
            company_id, body.company_b_id, body.relationship_type,
            body.strength, body.bidirectional, body.source,
        )
    except asyncpg.UniqueViolationError:
        raise HTTPException(409, "An active competitor relationship already exists between these companies") from None
    except asyncpg.ForeignKeyViolationError:
        # A company can be removed between the existence checks above and this
        # INSERT. Report it as a 404 rather than an unhandled 500. This only
        # fires if the table declares FK constraints — TODO confirm schema.
        raise HTTPException(404, "Company or competitor company not found") from None

    result = _row_dict(row)

    await record_audit_event(
        pool,
        event_type="competitor_relationship.created",
        entity_type="competitor_relationship",
        entity_id=result["id"],
        data={
            "company_a_id": company_id,
            "company_b_id": body.company_b_id,
            "relationship_type": body.relationship_type,
            "strength": body.strength,
            "bidirectional": body.bidirectional,
            "source": body.source,
        },
        actor="operator",
    )

    return result
|
|
|
|
|
|
@router.get("/companies/{company_id}/competitors", response_model=List[CompetitorRelationship])
async def list_competitors(company_id: str, request: Request):
    """List active competitor relationships for a company, ordered by strength descending.

    A relationship is included when ``company_id`` is on either side of it.
    Rows with equal strength are ordered by ``created_at`` and then ``id``, so
    the response order is deterministic.

    Raises:
        HTTPException(404): the company does not exist.
    """
    pool = _get_pool(request)

    if not await _company_exists(pool, company_id):
        raise HTTPException(404, "Company not found")

    # Without a tie-break, PostgreSQL returns rows of equal strength in an
    # unspecified order. Ties are common because the default strength is 0.5.
    # NOTE(review): rows where company_id is company_b are returned even when
    # bidirectional is FALSE — confirm that is the intended meaning of the flag.
    rows = await pool.fetch(
        """SELECT id, company_a_id, company_b_id, relationship_type, strength,
                  bidirectional, source, active, created_at, updated_at
           FROM competitor_relationships
           WHERE (company_a_id = $1 OR company_b_id = $1) AND active = TRUE
           ORDER BY strength DESC, created_at ASC, id ASC""",
        company_id,
    )

    return [_row_dict(r) for r in rows]
|
|
|
|
|
|
@router.put("/companies/{company_id}/competitors/{relationship_id}", response_model=CompetitorRelationship)
async def update_competitor(company_id: str, relationship_id: str, body: CompetitorRelationshipCreate, request: Request):
    """Update a competitor relationship with audit event recording previous state.

    The relationship must involve ``company_id`` on either side. Only
    ``relationship_type``, ``strength``, ``bidirectional`` and ``source`` are
    written; ``body.company_b_id`` is not applied.

    The previous-state read and the UPDATE run in one transaction, with the row
    locked by ``SELECT ... FOR UPDATE``. As a result, the ``previous_state`` in
    the audit event is exactly what this UPDATE overwrote, even with concurrent
    writers.

    Raises:
        HTTPException(404): no relationship with this id involves ``company_id``.
        HTTPException(409): the UPDATE violated a unique constraint.
    """
    pool = _get_pool(request)

    async with pool.acquire() as conn:
        async with conn.transaction():
            # Lock the row so it cannot change between this read and the UPDATE.
            # NOTE(review): soft-deleted (active = FALSE) rows are not excluded,
            # matching prior behavior — confirm that is intended.
            existing = await conn.fetchrow(
                """SELECT id, company_a_id, company_b_id, relationship_type, strength,
                          bidirectional, source, active, created_at, updated_at
                   FROM competitor_relationships
                   WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2)
                   FOR UPDATE""",
                relationship_id, company_id,
            )
            if not existing:
                raise HTTPException(404, "Competitor relationship not found")

            previous_state = _row_dict(existing)

            try:
                row = await conn.fetchrow(
                    """UPDATE competitor_relationships
                       SET relationship_type = $2, strength = $3, bidirectional = $4,
                           source = $5, updated_at = NOW()
                       WHERE id = $1
                       RETURNING id, company_a_id, company_b_id, relationship_type, strength,
                                 bidirectional, source, active, created_at, updated_at""",
                    relationship_id, body.relationship_type, body.strength,
                    body.bidirectional, body.source,
                )
            except asyncpg.UniqueViolationError:
                # Same 409 that create_competitor returns for a unique violation.
                # Leaving the transaction block via this raise rolls it back.
                raise HTTPException(
                    409, "An active competitor relationship already exists between these companies"
                ) from None

    result = _row_dict(row)

    await record_audit_event(
        pool,
        event_type="competitor_relationship.updated",
        entity_type="competitor_relationship",
        entity_id=result["id"],
        data={
            "previous_state": {
                "relationship_type": previous_state["relationship_type"],
                "strength": previous_state["strength"],
                "bidirectional": previous_state["bidirectional"],
                "source": previous_state["source"],
            },
            "new_state": {
                "relationship_type": body.relationship_type,
                "strength": body.strength,
                "bidirectional": body.bidirectional,
                "source": body.source,
            },
        },
        actor="operator",
    )

    return result
|
|
|
|
|
|
@router.delete("/companies/{company_id}/competitors/{relationship_id}", status_code=200)
async def delete_competitor(company_id: str, relationship_id: str, request: Request):
    """Soft-delete a relationship involving ``company_id`` by setting ``active`` to FALSE.

    The row itself is kept. A ``competitor_relationship.deleted`` audit event
    is recorded, and ``{"status": "deleted", "id": <relationship id>}`` is
    returned.

    Raises:
        HTTPException(404): no *active* relationship with this id involves ``company_id``.
    """
    db = _get_pool(request)

    # The active = TRUE guard means a second delete of the same row finds nothing.
    deactivated = await db.fetchrow(
        """UPDATE competitor_relationships
           SET active = FALSE, updated_at = NOW()
           WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2) AND active = TRUE
           RETURNING id""",
        relationship_id, company_id,
    )
    if not deactivated:
        raise HTTPException(404, "Active competitor relationship not found")

    rel_id = str(deactivated["id"])

    await record_audit_event(
        db,
        event_type="competitor_relationship.deleted",
        entity_type="competitor_relationship",
        entity_id=rel_id,
        data={"company_id": company_id, "soft_deleted": True},
        actor="operator",
    )

    return {"status": "deleted", "id": rel_id}
|