Files
stonks-oracle/services/symbol_registry/competitors.py
T

227 lines
8.0 KiB
Python

"""Competitor Relationship management endpoints for the Symbol Registry API."""
import uuid
from datetime import datetime
from typing import Any, List
import asyncpg
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field, field_validator
from services.shared.audit import record_audit_event
# Router on which the competitor-relationship endpoints in this module are registered.
router = APIRouter()
# --- Valid values ---
# Values accepted for the ``relationship_type`` field of a relationship payload.
VALID_RELATIONSHIP_TYPES = {
    "direct_rival",
    "same_sector",
    "overlapping_products",
    "supply_chain_adjacent",
}
# Values accepted for the ``source`` field of a relationship payload.
VALID_SOURCES = {
    "manual",
    "inferred",
}
# --- Request/Response Models ---
class CompetitorRelationshipCreate(BaseModel):
    """Request body describing a competitor relationship.

    Used as the payload for both the create (POST) and update (PUT)
    endpoints in this module. ``relationship_type`` and ``source`` are
    checked against ``VALID_RELATIONSHIP_TYPES`` and ``VALID_SOURCES``;
    a value outside those sets raises ``ValueError`` from the validator.
    """

    # Id of the other company in the relationship.
    company_b_id: str
    # Must be a member of VALID_RELATIONSHIP_TYPES.
    relationship_type: str
    # Constrained to the closed range [0, 1]; defaults to 0.5.
    strength: float = Field(default=0.5, ge=0, le=1)
    # Defaults to True when omitted.
    bidirectional: bool = True
    # Must be a member of VALID_SOURCES; defaults to "manual".
    source: str = "manual"

    @field_validator("relationship_type")
    @classmethod
    def validate_relationship_type(cls, v: str) -> str:
        """Return ``v`` unchanged if it is a known relationship type, else raise ``ValueError``."""
        if v not in VALID_RELATIONSHIP_TYPES:
            # sorted() gives a stable listing; interpolating the raw set would
            # make the message order vary between processes.
            raise ValueError(f"relationship_type must be one of {sorted(VALID_RELATIONSHIP_TYPES)}")
        return v

    @field_validator("source")
    @classmethod
    def validate_source(cls, v: str) -> str:
        """Return ``v`` unchanged if it is a known source, else raise ``ValueError``."""
        if v not in VALID_SOURCES:
            # Same reasoning as above: keep the error text deterministic.
            raise ValueError(f"source must be one of {sorted(VALID_SOURCES)}")
        return v
class CompetitorRelationship(BaseModel):
    """Shape of a competitor relationship as returned by the endpoints in this module.

    The fields mirror the columns selected from ``competitor_relationships``
    by the create, list and update queries.
    """

    # Identifiers are exposed as strings.
    id: str
    company_a_id: str
    company_b_id: str

    # Descriptive attributes of the relationship.
    relationship_type: str
    strength: float
    bidirectional: bool
    source: str

    # Set to FALSE by the delete endpoint instead of removing the row.
    active: bool

    # Row timestamps.
    created_at: datetime
    updated_at: datetime
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
    """Return ``row`` as a plain dict in which every ``uuid.UUID`` value is rendered as ``str``.

    Values of any other type are passed through unchanged.
    """
    return {
        column: str(value) if isinstance(value, uuid.UUID) else value
        for column, value in dict(row).items()
    }
def _get_pool(request: Request) -> asyncpg.Pool:
    """Return the ``pool`` object exposed by ``services.symbol_registry.app``.

    ``request`` is accepted but not consulted; the pool is looked up on the
    app module each time this function is called.
    """
    # Imported at call time rather than at module import time.
    # NOTE(review): presumably deferred to avoid a circular import with the app module - confirm.
    from services.symbol_registry.app import pool as app_pool

    return app_pool
async def _company_exists(pool: asyncpg.Pool, company_id: str) -> bool:
    """Report whether ``companies`` contains a row whose ``id`` equals ``company_id``."""
    # fetchval yields None when the query produces no row.
    marker = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id)
    return marker is not None
# --- Endpoints ---
@router.post("/companies/{company_id}/competitors", response_model=CompetitorRelationship, status_code=201)
async def create_competitor(company_id: str, body: CompetitorRelationshipCreate, request: Request):
    """Create a competitor relationship for a company and record an audit event.

    Args:
        company_id: Path parameter; stored as ``company_a_id`` of the new row.
        body: Validated payload; ``body.company_b_id`` is stored as ``company_b_id``.
        request: Incoming request, passed to ``_get_pool``.

    Returns:
        The inserted row as a dict (UUID values rendered as strings).

    Raises:
        HTTPException: 400 if ``company_id`` equals ``body.company_b_id``;
            404 if either company is missing from ``companies``;
            409 if the insert raises ``asyncpg.UniqueViolationError``.
    """
    pool = _get_pool(request)

    # Self-referencing check
    if company_id == body.company_b_id:
        raise HTTPException(400, "A company cannot be its own competitor")

    # Check both companies exist
    if not await _company_exists(pool, company_id):
        raise HTTPException(404, "Company not found")
    if not await _company_exists(pool, body.company_b_id):
        raise HTTPException(404, "Competitor company not found")

    # id, active, created_at and updated_at are not supplied here; they are
    # filled in by the database and read back through RETURNING.
    insert_sql = (
        "INSERT INTO competitor_relationships\n"
        "(company_a_id, company_b_id, relationship_type, strength, bidirectional, source)\n"
        "VALUES ($1, $2, $3, $4, $5, $6)\n"
        "RETURNING id, company_a_id, company_b_id, relationship_type, strength,\n"
        "bidirectional, source, active, created_at, updated_at"
    )
    try:
        row = await pool.fetchrow(
            insert_sql,
            company_id, body.company_b_id, body.relationship_type,
            body.strength, body.bidirectional, body.source,
        )
    except asyncpg.UniqueViolationError as exc:
        # Chain the database error explicitly so the original cause stays
        # attached to the HTTP error in tracebacks and logs.
        raise HTTPException(
            409, "An active competitor relationship already exists between these companies"
        ) from exc

    result = _row_dict(row)

    # The audit event is written after the insert has completed.
    await record_audit_event(
        pool,
        event_type="competitor_relationship.created",
        entity_type="competitor_relationship",
        entity_id=result["id"],
        data={
            "company_a_id": company_id,
            "company_b_id": body.company_b_id,
            "relationship_type": body.relationship_type,
            "strength": body.strength,
            "bidirectional": body.bidirectional,
            "source": body.source,
        },
        actor="operator",
    )
    return result
@router.get("/companies/{company_id}/competitors", response_model=List[CompetitorRelationship])
async def list_competitors(company_id: str, request: Request):
    """Return the active competitor relationships of a company, strongest first.

    A relationship is included when the company appears on either side
    (``company_a_id`` or ``company_b_id``) and the row has ``active = TRUE``.
    Responds with 404 when the company is not present in ``companies``.
    """
    db = _get_pool(request)

    company_known = await _company_exists(db, company_id)
    if not company_known:
        raise HTTPException(404, "Company not found")

    query = (
        "SELECT id, company_a_id, company_b_id, relationship_type, strength,\n"
        "bidirectional, source, active, created_at, updated_at\n"
        "FROM competitor_relationships\n"
        "WHERE (company_a_id = $1 OR company_b_id = $1) AND active = TRUE\n"
        "ORDER BY strength DESC"
    )
    records = await db.fetch(query, company_id)

    # Convert each record to a plain dict with string ids.
    return list(map(_row_dict, records))
@router.put("/companies/{company_id}/competitors/{relationship_id}", response_model=CompetitorRelationship)
async def update_competitor(company_id: str, relationship_id: str, body: CompetitorRelationshipCreate, request: Request):
    """Update a competitor relationship and record an audit event with the previous and new state.

    Args:
        company_id: Path parameter; must match ``company_a_id`` or ``company_b_id`` of the row.
        relationship_id: Path parameter; id of the relationship row to update.
        body: Validated payload supplying ``relationship_type``, ``strength``,
            ``bidirectional`` and ``source``.
        request: Incoming request, passed to ``_get_pool``.

    Returns:
        The updated row as a dict (UUID values rendered as strings).

    Raises:
        HTTPException: 404 if no matching relationship row is found, either at
            lookup time or when the UPDATE returns no row.
    """
    pool = _get_pool(request)

    # Fetch existing relationship
    # NOTE(review): this lookup has no ``active`` filter, so a soft-deleted
    # relationship can still be updated - confirm that is intended.
    select_sql = (
        "SELECT id, company_a_id, company_b_id, relationship_type, strength,\n"
        "bidirectional, source, active, created_at, updated_at\n"
        "FROM competitor_relationships\n"
        "WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2)"
    )
    existing = await pool.fetchrow(select_sql, relationship_id, company_id)
    if existing is None:
        raise HTTPException(404, "Competitor relationship not found")
    previous_state = _row_dict(existing)

    # NOTE(review): ``body.company_b_id`` is not written by this UPDATE; the
    # companies of an existing relationship are left as they are.
    update_sql = (
        "UPDATE competitor_relationships\n"
        "SET relationship_type = $2, strength = $3, bidirectional = $4, source = $5, updated_at = NOW()\n"
        "WHERE id = $1\n"
        "RETURNING id, company_a_id, company_b_id, relationship_type, strength,\n"
        "bidirectional, source, active, created_at, updated_at"
    )
    row = await pool.fetchrow(
        update_sql,
        relationship_id, body.relationship_type, body.strength, body.bidirectional, body.source,
    )
    if row is None:
        # The lookup and the UPDATE are separate statements; if the row is gone
        # by now, answer 404 rather than failing while converting None.
        raise HTTPException(404, "Competitor relationship not found")
    result = _row_dict(row)

    # Audit payload carries only the four updatable fields, before and after.
    await record_audit_event(
        pool,
        event_type="competitor_relationship.updated",
        entity_type="competitor_relationship",
        entity_id=result["id"],
        data={
            "previous_state": {
                "relationship_type": previous_state["relationship_type"],
                "strength": previous_state["strength"],
                "bidirectional": previous_state["bidirectional"],
                "source": previous_state["source"],
            },
            "new_state": {
                "relationship_type": body.relationship_type,
                "strength": body.strength,
                "bidirectional": body.bidirectional,
                "source": body.source,
            },
        },
        actor="operator",
    )
    return result
@router.delete("/companies/{company_id}/competitors/{relationship_id}", status_code=200)
async def delete_competitor(company_id: str, relationship_id: str, request: Request):
    """Soft-delete a competitor relationship: mark it inactive and keep the row.

    Only a row that is currently active and involves ``company_id`` on either
    side is affected; otherwise the endpoint responds with 404. An audit
    event is recorded for each successful soft-delete.
    """
    db = _get_pool(request)

    deactivate_sql = (
        "UPDATE competitor_relationships\n"
        "SET active = FALSE, updated_at = NOW()\n"
        "WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2) AND active = TRUE\n"
        "RETURNING id"
    )
    deactivated = await db.fetchrow(deactivate_sql, relationship_id, company_id)

    # No row returned means nothing matched the id/company/active filter.
    if not deactivated:
        raise HTTPException(404, "Active competitor relationship not found")

    deleted_id = str(deactivated["id"])
    audit_payload = {"company_id": company_id, "soft_deleted": True}
    await record_audit_event(
        db,
        event_type="competitor_relationship.deleted",
        entity_type="competitor_relationship",
        entity_id=deleted_id,
        data=audit_payload,
        actor="operator",
    )
    return {"status": "deleted", "id": deleted_id}