Files

150 lines
5.0 KiB
Python

"""Competitor auto-inference engine for the Symbol Registry API.
Identifies candidate competitors by sector/industry match and
document co-mention frequency, then upserts inferred relationships.
"""
import uuid
from datetime import datetime
from typing import Any, List
import asyncpg
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
# Router holding this module's endpoint(s). It has to be included in the
# FastAPI application elsewhere (not visible here) for the route below to be served.
router = APIRouter()
# --- Response Model ---
class CompetitorRelationship(BaseModel):
    """Response model for a competitor relationship.

    Fields mirror the columns returned by the ``RETURNING`` clause of the
    upsert into ``competitor_relationships`` in ``infer_competitors``. UUID
    columns are delivered as strings (they are stringified before reaching
    this model).
    """
    # Id of the relationship row (presumably a UUID, serialized as a string).
    id: str
    # The two related companies. Uniqueness is enforced on the unordered pair:
    # the upsert conflicts on LEAST/GREATEST of these two ids.
    company_a_id: str
    company_b_id: str
    # Kind of relationship; auto-inferred rows use 'same_sector'.
    relationship_type: str
    # Relationship strength. Inferred rows use 0.3 + 0.7 * normalized
    # co-mention count, which gives a value in [0.3, 1.0].
    strength: float
    # Whether the relationship applies in both directions; inferred rows set TRUE.
    bidirectional: bool
    # Origin of the row; auto-inferred rows use 'inferred'.
    source: str
    # Active flag. The pair-uniqueness constraint used by the upsert only
    # covers active rows (presumably inactive rows are soft-deleted; confirm).
    active: bool
    # Row timestamps; updated_at is set to NOW() when the upsert updates an existing row.
    created_at: datetime
    updated_at: datetime
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
    """Return a plain-dict copy of *row* with every ``uuid.UUID`` value stringified.

    Keys and key order are preserved; values of any other type pass through
    unchanged.
    """
    return {
        column: str(value) if isinstance(value, uuid.UUID) else value
        for column, value in dict(row).items()
    }
def _get_pool(request: Request) -> asyncpg.Pool:
    """Return the shared asyncpg pool exposed by ``services.symbol_registry.app``.

    The lookup is an import performed at call time, so the pool object current
    at call time is returned (the deferral presumably avoids a circular import
    at module load; confirm). ``request`` is currently unused.
    """
    from services.symbol_registry.app import pool as shared_pool

    return shared_pool
async def infer_competitors(
    pool: asyncpg.Pool, company_id: str
) -> list[dict[str, Any]]:
    """Infer competitor relationships based on sector/industry match and co-mentions.

    Steps:
      1. Fetch the target company's sector and industry.
      2. Find other active companies with the same sector AND industry.
      3. For each candidate, count the distinct documents in
         ``document_company_mentions`` that mention both the target and the
         candidate.
      4. Compute ``strength = 0.3 * sector_match + 0.7 * normalized_co_mentions``.
         ``sector_match`` is always 1.0 because candidates are pre-filtered.
         Co-mention counts are divided by the largest count among candidates.
      5. Upsert one ``'same_sector'`` relationship with ``source='inferred'`` per
         candidate. All upserts run in a single transaction: either every
         candidate's relationship is written or none is.

    Args:
        pool: asyncpg connection pool.
        company_id: Id of the target company as a UUID string (a ``uuid.UUID``
            is accepted too).

    Returns:
        The upserted relationship rows as dicts (UUID values stringified),
        sorted by strength, descending. Empty if there are no candidates.

    Raises:
        HTTPException: 400 if ``company_id`` is not a valid UUID, or if the
            company lacks a sector or an industry. 404 if no active company has
            that id.
    """
    # Weights of the strength formula.
    sector_weight = 0.3
    co_mention_weight = 0.7

    # Normalize the id up front. A malformed id would otherwise only fail
    # inside asyncpg's UUID encoding and surface as an HTTP 500.
    try:
        target_id = uuid.UUID(str(company_id))
    except ValueError:
        raise HTTPException(400, "company_id must be a valid UUID") from None

    # One connection for the whole operation, so the upserts below can share
    # a single transaction.
    async with pool.acquire() as conn:
        target = await conn.fetchrow(
            "SELECT id, sector, industry FROM companies WHERE id = $1 AND active = TRUE",
            target_id,
        )
        if not target:
            raise HTTPException(404, "Company not found")
        if target["sector"] is None or target["industry"] is None:
            raise HTTPException(
                400,
                "Company must have both sector and industry defined for auto-inference",
            )

        # Candidates: other active companies with the same sector AND industry.
        candidates = await conn.fetch(
            """SELECT id FROM companies
               WHERE sector = $1 AND industry = $2 AND active = TRUE AND id != $3""",
            target["sector"], target["industry"], target_id,
        )
        if not candidates:
            return []
        candidate_ids = [r["id"] for r in candidates]

        # Documents mentioning both the target and each candidate.
        co_mention_rows = await conn.fetch(
            """SELECT dcm2.company_id AS candidate_id,
                      COUNT(DISTINCT dcm1.document_id) AS co_count
               FROM document_company_mentions dcm1
               JOIN document_company_mentions dcm2
                 ON dcm1.document_id = dcm2.document_id
               WHERE dcm1.company_id = $1
                 AND dcm2.company_id = ANY($2::uuid[])
               GROUP BY dcm2.company_id""",
            target_id, candidate_ids,
        )
        co_mention_map: dict[Any, int] = {
            row["candidate_id"]: row["co_count"] for row in co_mention_rows
        }
        # Candidates with no shared documents are absent from the map (count 0).
        # Falling back to 1 keeps the normalization free of division by zero.
        max_count = max(co_mention_map.values(), default=0) or 1

        results: list[dict[str, Any]] = []
        # All-or-nothing: a failure part-way through must not leave a partial
        # set of inferred relationships behind.
        async with conn.transaction():
            for cid in candidate_ids:
                normalized = co_mention_map.get(cid, 0) / max_count
                strength = sector_weight * 1.0 + co_mention_weight * normalized
                # Order the pair by UUID value. UUID.int follows the same
                # big-endian byte order PostgreSQL uses to compare uuids, so
                # company_a_id == LEAST(...). Comparing strings would mis-order
                # non-canonical input such as upper-case ids.
                a_id, b_id = sorted((target_id, cid), key=lambda u: u.int)
                # NOTE(review): DO UPDATE also overwrites the strength of an
                # existing active relationship for this pair whatever its
                # source (e.g. a manually curated one); confirm this is intended.
                row = await conn.fetchrow(
                    """INSERT INTO competitor_relationships
                       (company_a_id, company_b_id, relationship_type, strength,
                        bidirectional, source)
                       VALUES ($1, $2, 'same_sector', $3, TRUE, 'inferred')
                       ON CONFLICT (LEAST(company_a_id, company_b_id), GREATEST(company_a_id, company_b_id))
                       WHERE active = TRUE
                       DO UPDATE SET strength = EXCLUDED.strength, updated_at = NOW()
                       RETURNING id, company_a_id, company_b_id, relationship_type, strength,
                                 bidirectional, source, active, created_at, updated_at""",
                    a_id, b_id, strength,
                )
                results.append(_row_dict(row))

    # Strongest relationships first.
    results.sort(key=lambda r: r["strength"], reverse=True)
    return results
@router.post(
    "/companies/{company_id}/competitors/infer",
    response_model=List[CompetitorRelationship],
)
async def infer_competitors_endpoint(company_id: str, request: Request):
    """POST handler that runs competitor auto-inference for ``company_id``.

    Obtains the shared database pool, delegates to ``infer_competitors``, and
    returns the upserted relationships for serialization as
    ``CompetitorRelationship`` items.
    """
    db_pool = _get_pool(request)
    inferred_relationships = await infer_competitors(db_pool, company_id)
    return inferred_relationships