"""Competitor auto-inference engine for the Symbol Registry API. Identifies candidate competitors by sector/industry match and document co-mention frequency, then upserts inferred relationships. """ import uuid from datetime import datetime from typing import Any, List import asyncpg from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel router = APIRouter() # --- Response Model --- class CompetitorRelationship(BaseModel): """Response model for a competitor relationship.""" id: str company_a_id: str company_b_id: str relationship_type: str strength: float bidirectional: bool source: str active: bool created_at: datetime updated_at: datetime def _row_dict(row: asyncpg.Record) -> dict[str, Any]: """Convert asyncpg Record to dict with UUID→str coercion.""" d = dict(row) for k, v in d.items(): if isinstance(v, uuid.UUID): d[k] = str(v) return d def _get_pool(request: Request) -> asyncpg.Pool: """Get the database pool from the app module.""" from services.symbol_registry.app import pool return pool async def infer_competitors( pool: asyncpg.Pool, company_id: str ) -> list[dict[str, Any]]: """Infer competitor relationships based on sector/industry match and co-mentions. 1. Fetch target company's sector and industry. 2. Find other active companies with the same sector AND industry. 3. Count co-mentions in document_company_mentions for each candidate. 4. Compute strength = 0.3 * sector_match + 0.7 * normalized_co_mention_count. 5. Upsert relationships with source='inferred'. Returns the list of upserted relationship rows. """ # Fetch target company target = await pool.fetchrow( "SELECT id, sector, industry FROM companies WHERE id = $1 AND active = TRUE", company_id, ) if not target: raise HTTPException(404, "Company not found") if target["sector"] is None or target["industry"] is None: raise HTTPException( 400, "Company must have both sector and industry defined for auto-inference", ) sector = target["sector"] industry = target["industry"] # Find candidates: other active companies with same sector AND industry candidates = await pool.fetch( """SELECT id FROM companies WHERE sector = $1 AND industry = $2 AND active = TRUE AND id != $3""", sector, industry, company_id, ) if not candidates: return [] candidate_ids = [r["id"] for r in candidates] # Count co-mentions for each candidate co_mention_rows = await pool.fetch( """SELECT dcm2.company_id AS candidate_id, COUNT(DISTINCT dcm1.document_id) AS co_count FROM document_company_mentions dcm1 JOIN document_company_mentions dcm2 ON dcm1.document_id = dcm2.document_id WHERE dcm1.company_id = $1 AND dcm2.company_id = ANY($2::uuid[]) GROUP BY dcm2.company_id""", company_id, candidate_ids, ) co_mention_map: dict[Any, int] = {} for row in co_mention_rows: co_mention_map[row["candidate_id"]] = row["co_count"] # Normalize co-mention counts max_count = max(co_mention_map.values()) if co_mention_map else 1 if max_count == 0: max_count = 1 # Compute strength and upsert for each candidate results: list[dict[str, Any]] = [] for cid in candidate_ids: co_count = co_mention_map.get(cid, 0) normalized = co_count / max_count # sector_match is always 1.0 since we filter by sector+industry strength = 0.3 * 1.0 + 0.7 * normalized # Order IDs for the unique index: LEAST/GREATEST a_id = min(company_id, str(cid), key=lambda x: x) b_id = max(company_id, str(cid), key=lambda x: x) row = await pool.fetchrow( """INSERT INTO competitor_relationships (company_a_id, company_b_id, relationship_type, strength, bidirectional, source) VALUES ($1, $2, 'same_sector', $3, TRUE, 'inferred') ON CONFLICT (LEAST(company_a_id, company_b_id), GREATEST(company_a_id, company_b_id)) WHERE active = TRUE DO UPDATE SET strength = EXCLUDED.strength, updated_at = NOW() RETURNING id, company_a_id, company_b_id, relationship_type, strength, bidirectional, source, active, created_at, updated_at""", a_id, b_id, strength, ) results.append(_row_dict(row)) # Sort by strength descending before returning results.sort(key=lambda r: r["strength"], reverse=True) return results @router.post( "/companies/{company_id}/competitors/infer", response_model=List[CompetitorRelationship], ) async def infer_competitors_endpoint(company_id: str, request: Request): """Trigger auto-inference of competitor relationships for a company.""" pool = _get_pool(request) return await infer_competitors(pool, company_id)