"""Competitor Relationship management endpoints for the Symbol Registry API."""

import uuid
from datetime import datetime
from typing import Any, List

import asyncpg
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field, field_validator

from services.shared.audit import record_audit_event

router = APIRouter()

# --- Valid values ---
VALID_RELATIONSHIP_TYPES = {"direct_rival", "same_sector", "overlapping_products", "supply_chain_adjacent"}
VALID_SOURCES = {"manual", "inferred"}

# Column list shared by every SELECT/RETURNING below so the response shape
# (CompetitorRelationship) is defined in exactly one place. Only this constant
# is ever interpolated into SQL; all caller-supplied values go through $n params.
_RELATIONSHIP_COLUMNS = (
    "id, company_a_id, company_b_id, relationship_type, strength, "
    "bidirectional, source, active, created_at, updated_at"
)


# --- Request/Response Models ---


class CompetitorRelationshipCreate(BaseModel):
    """Request body for creating (POST) or updating (PUT) a competitor relationship."""

    company_b_id: str
    relationship_type: str
    strength: float = Field(default=0.5, ge=0, le=1)
    bidirectional: bool = True
    source: str = "manual"

    @field_validator("relationship_type")
    @classmethod
    def validate_relationship_type(cls, v: str) -> str:
        """Reject relationship types outside VALID_RELATIONSHIP_TYPES."""
        if v not in VALID_RELATIONSHIP_TYPES:
            raise ValueError(f"relationship_type must be one of {VALID_RELATIONSHIP_TYPES}")
        return v

    @field_validator("source")
    @classmethod
    def validate_source(cls, v: str) -> str:
        """Reject sources outside VALID_SOURCES."""
        if v not in VALID_SOURCES:
            raise ValueError(f"source must be one of {VALID_SOURCES}")
        return v


class CompetitorRelationship(BaseModel):
    """Response model for a competitor relationship."""

    id: str
    company_a_id: str
    company_b_id: str
    relationship_type: str
    strength: float
    bidirectional: bool
    source: str
    active: bool
    created_at: datetime
    updated_at: datetime


def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
    """Convert an asyncpg Record to a plain dict, coercing uuid.UUID values to str."""
    return {k: str(v) if isinstance(v, uuid.UUID) else v for k, v in dict(row).items()}


def _get_pool(request: Request) -> asyncpg.Pool:
    """Return the database pool held by the symbol_registry app module.

    ``request`` is currently unused; it is kept so callers do not change if the
    pool is later read from the request/app state instead.
    """
    # Imported lazily so the module-level ``pool`` is read at call time
    # (after app startup has assigned it), not at import time.
    from services.symbol_registry.app import pool

    return pool


async def _company_exists(pool: asyncpg.Pool, company_id: str) -> bool:
    """Return True if a row with ``company_id`` exists in ``companies``."""
    return await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id) is not None


# --- Endpoints ---


@router.post("/companies/{company_id}/competitors", response_model=CompetitorRelationship, status_code=201)
async def create_competitor(company_id: str, body: CompetitorRelationshipCreate, request: Request):
    """Create a competitor relationship with ``company_id`` as company A.

    Raises:
        HTTPException(400): ``company_id`` equals ``body.company_b_id``.
        HTTPException(404): either company does not exist.
        HTTPException(409): the insert hits a unique constraint (an active
            relationship already exists between the two companies).
    """
    pool = _get_pool(request)

    # Self-referencing check
    if company_id == body.company_b_id:
        raise HTTPException(400, "A company cannot be its own competitor")

    # Check both companies exist
    if not await _company_exists(pool, company_id):
        raise HTTPException(404, "Company not found")
    if not await _company_exists(pool, body.company_b_id):
        raise HTTPException(404, "Competitor company not found")

    try:
        row = await pool.fetchrow(
            f"""INSERT INTO competitor_relationships
                (company_a_id, company_b_id, relationship_type, strength, bidirectional, source)
            VALUES ($1, $2, $3, $4, $5, $6)
            RETURNING {_RELATIONSHIP_COLUMNS}""",
            company_id,
            body.company_b_id,
            body.relationship_type,
            body.strength,
            body.bidirectional,
            body.source,
        )
    except asyncpg.UniqueViolationError as exc:
        # Chain the DB error so the cause survives in logs/tracebacks.
        raise HTTPException(
            409, "An active competitor relationship already exists between these companies"
        ) from exc

    result = _row_dict(row)

    await record_audit_event(
        pool,
        event_type="competitor_relationship.created",
        entity_type="competitor_relationship",
        entity_id=result["id"],
        data={
            "company_a_id": company_id,
            "company_b_id": body.company_b_id,
            "relationship_type": body.relationship_type,
            "strength": body.strength,
            "bidirectional": body.bidirectional,
            "source": body.source,
        },
        actor="operator",
    )

    return result


@router.get("/companies/{company_id}/competitors", response_model=List[CompetitorRelationship])
async def list_competitors(company_id: str, request: Request):
    """List active competitor relationships involving a company, strongest first.

    Relationships are matched with the company on either side (company_a_id or
    company_b_id), independent of the ``bidirectional`` flag.

    Raises:
        HTTPException(404): the company does not exist.
    """
    pool = _get_pool(request)

    if not await _company_exists(pool, company_id):
        raise HTTPException(404, "Company not found")

    rows = await pool.fetch(
        f"""SELECT {_RELATIONSHIP_COLUMNS}
        FROM competitor_relationships
        WHERE (company_a_id = $1 OR company_b_id = $1) AND active = TRUE
        ORDER BY strength DESC""",
        company_id,
    )
    return [_row_dict(r) for r in rows]


@router.put("/companies/{company_id}/competitors/{relationship_id}", response_model=CompetitorRelationship)
async def update_competitor(company_id: str, relationship_id: str, body: CompetitorRelationshipCreate, request: Request):
    """Update a competitor relationship and record an audit event with its previous state.

    Only relationship_type, strength, bidirectional and source are updated;
    ``body.company_b_id`` is ignored, so the counterpart company cannot be
    changed here. The relationship is not required to be active.

    Raises:
        HTTPException(404): no relationship with this id involves ``company_id``
            (or it disappeared before the update ran).
    """
    pool = _get_pool(request)

    # Fetch existing relationship so its prior values can be audited.
    existing = await pool.fetchrow(
        f"""SELECT {_RELATIONSHIP_COLUMNS}
        FROM competitor_relationships
        WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2)""",
        relationship_id,
        company_id,
    )
    if not existing:
        raise HTTPException(404, "Competitor relationship not found")

    previous_state = _row_dict(existing)

    # Scope the UPDATE by company as well, consistent with the SELECT above
    # and with delete_competitor.
    row = await pool.fetchrow(
        f"""UPDATE competitor_relationships
        SET relationship_type = $3, strength = $4, bidirectional = $5, source = $6,
            updated_at = NOW()
        WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2)
        RETURNING {_RELATIONSHIP_COLUMNS}""",
        relationship_id,
        company_id,
        body.relationship_type,
        body.strength,
        body.bidirectional,
        body.source,
    )
    if row is None:
        # The row was removed between the SELECT and the UPDATE (e.g. by another
        # process); previously this crashed in _row_dict(None) with a 500.
        raise HTTPException(404, "Competitor relationship not found")

    result = _row_dict(row)

    await record_audit_event(
        pool,
        event_type="competitor_relationship.updated",
        entity_type="competitor_relationship",
        entity_id=result["id"],
        data={
            "previous_state": {
                "relationship_type": previous_state["relationship_type"],
                "strength": previous_state["strength"],
                "bidirectional": previous_state["bidirectional"],
                "source": previous_state["source"],
            },
            "new_state": {
                "relationship_type": body.relationship_type,
                "strength": body.strength,
                "bidirectional": body.bidirectional,
                "source": body.source,
            },
        },
        actor="operator",
    )

    return result


@router.delete("/companies/{company_id}/competitors/{relationship_id}", status_code=200)
async def delete_competitor(company_id: str, relationship_id: str, request: Request):
    """Soft-delete a competitor relationship by setting active=FALSE; the row is kept.

    Raises:
        HTTPException(404): no active relationship with this id involves ``company_id``.
    """
    pool = _get_pool(request)

    row = await pool.fetchrow(
        """UPDATE competitor_relationships
        SET active = FALSE, updated_at = NOW()
        WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2) AND active = TRUE
        RETURNING id""",
        relationship_id,
        company_id,
    )
    if not row:
        raise HTTPException(404, "Active competitor relationship not found")

    relationship_key = str(row["id"])

    await record_audit_event(
        pool,
        event_type="competitor_relationship.deleted",
        entity_type="competitor_relationship",
        entity_id=relationship_key,
        data={"company_id": company_id, "soft_deleted": True},
        actor="operator",
    )

    return {"status": "deleted", "id": relationship_key}