150 lines
5.0 KiB
Python
150 lines
5.0 KiB
Python
"""Competitor auto-inference engine for the Symbol Registry API.
|
|
|
|
Identifies candidate competitors by sector/industry match and
|
|
document co-mention frequency, then upserts inferred relationships.
|
|
"""
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Any, List
|
|
|
|
import asyncpg
|
|
from fastapi import APIRouter, HTTPException, Request
|
|
from pydantic import BaseModel
|
|
|
|
# Router that this module's competitor-inference endpoint is registered on
# (see the @router.post decorator further down in this file).
router = APIRouter()
|
|
|
|
|
|
# --- Response Model ---
|
|
|
|
class CompetitorRelationship(BaseModel):
    """Response schema describing one competitor relationship.

    The field names match the columns returned by the upsert's RETURNING
    clause in ``infer_competitors``; the three identifier fields are typed
    as strings because ``_row_dict`` renders UUID values with ``str()``.
    """

    # Identifiers of the relationship row and of the two companies it links.
    id: str
    company_a_id: str
    company_b_id: str

    # Relationship label and its score.
    relationship_type: str
    strength: float

    # Flags and provenance as stored on the row.
    bidirectional: bool
    source: str
    active: bool

    # Row timestamps.
    created_at: datetime
    updated_at: datetime
|
|
|
|
|
|
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
    """Return *row* as a plain dict, rendering every UUID value as a string.

    Keys keep the order they have in ``dict(row)``; values that are not
    ``uuid.UUID`` instances are passed through unchanged. The input itself
    is not modified because the work is done on a copy.
    """
    return {
        column: str(value) if isinstance(value, uuid.UUID) else value
        for column, value in dict(row).items()
    }
|
|
|
|
|
|
def _get_pool(request: Request) -> asyncpg.Pool:
    """Return the ``pool`` object exported by ``services.symbol_registry.app``.

    ``request`` is accepted for the caller's convenience but is not read.
    """
    # The import runs on each call rather than at module import time.
    # NOTE(review): presumably this avoids a circular import with the app
    # module and picks up the pool once it has been created - confirm.
    from services.symbol_registry.app import pool as shared_pool

    return shared_pool
|
|
|
|
|
|
async def infer_competitors(
    pool: asyncpg.Pool, company_id: str
) -> list[dict[str, Any]]:
    """Infer competitor relationships from sector/industry match and co-mentions.

    Steps:
        1. Validate ``company_id`` and fetch the active target company's
           sector and industry.
        2. Find the other active companies with the same sector AND industry.
        3. Count, per candidate, the distinct documents in
           ``document_company_mentions`` that mention both companies.
        4. Compute ``strength = 0.3 * sector_match + 0.7 * normalized_co_mentions``
           where counts are divided by the largest count among the candidates.
        5. Upsert one relationship per candidate with ``source='inferred'``.

    Args:
        pool: asyncpg pool used for every query.
        company_id: UUID of the target company, in any textual form that
            ``uuid.UUID`` accepts.

    Returns:
        The upserted relationship rows as dicts (UUID values rendered as
        strings), sorted by strength in descending order. An empty list is
        returned when no other company shares the sector and industry.

    Raises:
        HTTPException: 422 if ``company_id`` is not a valid UUID, 404 if no
            active company has that id, 400 if the company lacks a sector
            or an industry.
    """
    # Validate and canonicalise the id before touching the database. A
    # malformed value would otherwise be rejected by the driver when the first
    # query is bound and surface as a server error rather than a client error.
    try:
        target_id = str(uuid.UUID(str(company_id)))
    except ValueError:
        raise HTTPException(422, "Invalid company id") from None

    # Fetch target company
    target = await pool.fetchrow(
        "SELECT id, sector, industry FROM companies WHERE id = $1 AND active = TRUE",
        target_id,
    )
    if target is None:
        raise HTTPException(404, "Company not found")

    sector, industry = target["sector"], target["industry"]
    if sector is None or industry is None:
        raise HTTPException(
            400,
            "Company must have both sector and industry defined for auto-inference",
        )

    # Find candidates: other active companies with same sector AND industry
    candidates = await pool.fetch(
        """SELECT id FROM companies
           WHERE sector = $1 AND industry = $2 AND active = TRUE AND id != $3""",
        sector,
        industry,
        target_id,
    )
    if not candidates:
        return []

    candidate_ids = [record["id"] for record in candidates]

    # Count co-mentions for each candidate. Candidates that never share a
    # document with the target produce no row and default to 0 below.
    co_mention_rows = await pool.fetch(
        """SELECT dcm2.company_id AS candidate_id,
                  COUNT(DISTINCT dcm1.document_id) AS co_count
           FROM document_company_mentions dcm1
           JOIN document_company_mentions dcm2
             ON dcm1.document_id = dcm2.document_id
           WHERE dcm1.company_id = $1
             AND dcm2.company_id = ANY($2::uuid[])
           GROUP BY dcm2.company_id""",
        target_id,
        candidate_ids,
    )
    co_mention_map: dict[Any, int] = {
        record["candidate_id"]: record["co_count"] for record in co_mention_rows
    }

    # Normalise against the largest count; fall back to 1 so the division is
    # always defined when there are no co-mentions at all.
    max_count = max(co_mention_map.values(), default=0) or 1

    sector_weight = 0.3
    co_mention_weight = 0.7

    upsert_sql = """INSERT INTO competitor_relationships
               (company_a_id, company_b_id, relationship_type, strength,
                bidirectional, source)
           VALUES ($1, $2, 'same_sector', $3, TRUE, 'inferred')
           ON CONFLICT (LEAST(company_a_id, company_b_id), GREATEST(company_a_id, company_b_id))
               WHERE active = TRUE
           DO UPDATE SET strength = EXCLUDED.strength, updated_at = NOW()
           RETURNING id, company_a_id, company_b_id, relationship_type, strength,
                     bidirectional, source, active, created_at, updated_at"""

    # Compute strength and upsert for each candidate. Each upsert is issued as
    # its own statement through the pool; no explicit transaction is opened.
    results: list[dict[str, Any]] = []
    for cid in candidate_ids:
        normalized = co_mention_map.get(cid, 0) / max_count
        # sector_match is always 1.0 since we filter by sector+industry
        strength = sector_weight * 1.0 + co_mention_weight * normalized

        # Store the pair in a stable (smaller, larger) order to mirror the
        # LEAST/GREATEST conflict target. Both sides are canonical lower-case
        # UUID strings, so plain string ordering is well defined here.
        a_id, b_id = sorted((target_id, str(cid)))

        row = await pool.fetchrow(upsert_sql, a_id, b_id, strength)
        results.append(_row_dict(row))

    # Sort by strength descending before returning
    results.sort(key=lambda rel: rel["strength"], reverse=True)
    return results
|
|
|
|
|
|
@router.post(
    "/companies/{company_id}/competitors/infer",
    response_model=List[CompetitorRelationship],
)
async def infer_competitors_endpoint(company_id: str, request: Request):
    """Run competitor auto-inference for ``company_id`` and return the rows.

    Looks up the database pool via ``_get_pool`` and delegates all of the
    work, including raising HTTP errors, to ``infer_competitors``.
    """
    return await infer_competitors(_get_pool(request), company_id)
|