feat: competitive intelligence & historical pattern matching layer

This commit is contained in:
Celes Renata
2026-04-14 19:42:48 +00:00
parent b478022ba3
commit f7a11d14ea
203 changed files with 20155 additions and 97 deletions
+6
View File
@@ -12,6 +12,9 @@ from pydantic import BaseModel, field_validator
from services.shared.config import load_config
from services.shared.db import get_pg_pool
from services.shared.logging import setup_logging
from services.symbol_registry.exposure import router as exposure_router
from services.symbol_registry.competitors import router as competitors_router
from services.symbol_registry.competitor_inference import router as inference_router
config = load_config()
pool: Optional[asyncpg.Pool] = None
@@ -36,6 +39,9 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="Stonks Oracle - Symbol Registry", lifespan=lifespan)
app.include_router(exposure_router)
app.include_router(competitors_router)
app.include_router(inference_router)
@app.get("/health")
@@ -0,0 +1,149 @@
"""Competitor auto-inference engine for the Symbol Registry API.
Identifies candidate competitors by sector/industry match and
document co-mention frequency, then upserts inferred relationships.
"""
import uuid
from datetime import datetime
from typing import Any, List
import asyncpg
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
router = APIRouter()
# --- Response Model ---
class CompetitorRelationship(BaseModel):
"""Response model for a competitor relationship."""
id: str
company_a_id: str
company_b_id: str
relationship_type: str
strength: float
bidirectional: bool
source: str
active: bool
created_at: datetime
updated_at: datetime
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
"""Convert asyncpg Record to dict with UUID→str coercion."""
d = dict(row)
for k, v in d.items():
if isinstance(v, uuid.UUID):
d[k] = str(v)
return d
def _get_pool(request: Request) -> asyncpg.Pool:
"""Get the database pool from the app module."""
from services.symbol_registry.app import pool
return pool
async def infer_competitors(
pool: asyncpg.Pool, company_id: str
) -> list[dict[str, Any]]:
"""Infer competitor relationships based on sector/industry match and co-mentions.
1. Fetch target company's sector and industry.
2. Find other active companies with the same sector AND industry.
3. Count co-mentions in document_company_mentions for each candidate.
4. Compute strength = 0.3 * sector_match + 0.7 * normalized_co_mention_count.
5. Upsert relationships with source='inferred'.
Returns the list of upserted relationship rows.
"""
# Fetch target company
target = await pool.fetchrow(
"SELECT id, sector, industry FROM companies WHERE id = $1 AND active = TRUE",
company_id,
)
if not target:
raise HTTPException(404, "Company not found")
if target["sector"] is None or target["industry"] is None:
raise HTTPException(
400,
"Company must have both sector and industry defined for auto-inference",
)
sector = target["sector"]
industry = target["industry"]
# Find candidates: other active companies with same sector AND industry
candidates = await pool.fetch(
"""SELECT id FROM companies
WHERE sector = $1 AND industry = $2 AND active = TRUE AND id != $3""",
sector, industry, company_id,
)
if not candidates:
return []
candidate_ids = [r["id"] for r in candidates]
# Count co-mentions for each candidate
co_mention_rows = await pool.fetch(
"""SELECT dcm2.company_id AS candidate_id, COUNT(DISTINCT dcm1.document_id) AS co_count
FROM document_company_mentions dcm1
JOIN document_company_mentions dcm2
ON dcm1.document_id = dcm2.document_id
WHERE dcm1.company_id = $1
AND dcm2.company_id = ANY($2::uuid[])
GROUP BY dcm2.company_id""",
company_id, candidate_ids,
)
co_mention_map: dict[Any, int] = {}
for row in co_mention_rows:
co_mention_map[row["candidate_id"]] = row["co_count"]
# Normalize co-mention counts
max_count = max(co_mention_map.values()) if co_mention_map else 1
if max_count == 0:
max_count = 1
# Compute strength and upsert for each candidate
results: list[dict[str, Any]] = []
for cid in candidate_ids:
co_count = co_mention_map.get(cid, 0)
normalized = co_count / max_count
# sector_match is always 1.0 since we filter by sector+industry
strength = 0.3 * 1.0 + 0.7 * normalized
# Order IDs for the unique index: LEAST/GREATEST
a_id = min(company_id, str(cid), key=lambda x: x)
b_id = max(company_id, str(cid), key=lambda x: x)
row = await pool.fetchrow(
"""INSERT INTO competitor_relationships
(company_a_id, company_b_id, relationship_type, strength,
bidirectional, source)
VALUES ($1, $2, 'same_sector', $3, TRUE, 'inferred')
ON CONFLICT (LEAST(company_a_id, company_b_id), GREATEST(company_a_id, company_b_id))
WHERE active = TRUE
DO UPDATE SET strength = EXCLUDED.strength, updated_at = NOW()
RETURNING id, company_a_id, company_b_id, relationship_type, strength,
bidirectional, source, active, created_at, updated_at""",
a_id, b_id, strength,
)
results.append(_row_dict(row))
# Sort by strength descending before returning
results.sort(key=lambda r: r["strength"], reverse=True)
return results
@router.post(
"/companies/{company_id}/competitors/infer",
response_model=List[CompetitorRelationship],
)
async def infer_competitors_endpoint(company_id: str, request: Request):
"""Trigger auto-inference of competitor relationships for a company."""
pool = _get_pool(request)
return await infer_competitors(pool, company_id)
+226
View File
@@ -0,0 +1,226 @@
"""Competitor Relationship management endpoints for the Symbol Registry API."""
import uuid
from datetime import datetime
from typing import Any, List
import asyncpg
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field, field_validator
from services.shared.audit import record_audit_event
router = APIRouter()
# --- Valid values ---
VALID_RELATIONSHIP_TYPES = {"direct_rival", "same_sector", "overlapping_products", "supply_chain_adjacent"}
VALID_SOURCES = {"manual", "inferred"}
# --- Request/Response Models ---
class CompetitorRelationshipCreate(BaseModel):
"""Request body for creating a competitor relationship."""
company_b_id: str
relationship_type: str
strength: float = Field(default=0.5, ge=0, le=1)
bidirectional: bool = True
source: str = "manual"
@field_validator("relationship_type")
@classmethod
def validate_relationship_type(cls, v: str) -> str:
if v not in VALID_RELATIONSHIP_TYPES:
raise ValueError(f"relationship_type must be one of {VALID_RELATIONSHIP_TYPES}")
return v
@field_validator("source")
@classmethod
def validate_source(cls, v: str) -> str:
if v not in VALID_SOURCES:
raise ValueError(f"source must be one of {VALID_SOURCES}")
return v
class CompetitorRelationship(BaseModel):
"""Response model for a competitor relationship."""
id: str
company_a_id: str
company_b_id: str
relationship_type: str
strength: float
bidirectional: bool
source: str
active: bool
created_at: datetime
updated_at: datetime
def _row_dict(row: asyncpg.Record) -> dict[str, Any]:
"""Convert asyncpg Record to dict with UUID→str coercion."""
d = dict(row)
for k, v in d.items():
if isinstance(v, uuid.UUID):
d[k] = str(v)
return d
def _get_pool(request: Request) -> asyncpg.Pool:
"""Get the database pool from the app module."""
from services.symbol_registry.app import pool
return pool
async def _company_exists(pool: asyncpg.Pool, company_id: str) -> bool:
"""Check if a company exists."""
return await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id) is not None
# --- Endpoints ---
@router.post("/companies/{company_id}/competitors", response_model=CompetitorRelationship, status_code=201)
async def create_competitor(company_id: str, body: CompetitorRelationshipCreate, request: Request):
"""Create a competitor relationship for a company."""
pool = _get_pool(request)
# Self-referencing check
if company_id == body.company_b_id:
raise HTTPException(400, "A company cannot be its own competitor")
# Check both companies exist
if not await _company_exists(pool, company_id):
raise HTTPException(404, "Company not found")
if not await _company_exists(pool, body.company_b_id):
raise HTTPException(404, "Competitor company not found")
try:
row = await pool.fetchrow(
"""INSERT INTO competitor_relationships
(company_a_id, company_b_id, relationship_type, strength, bidirectional, source)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, company_a_id, company_b_id, relationship_type, strength,
bidirectional, source, active, created_at, updated_at""",
company_id, body.company_b_id, body.relationship_type,
body.strength, body.bidirectional, body.source,
)
except asyncpg.UniqueViolationError:
raise HTTPException(409, "An active competitor relationship already exists between these companies")
result = _row_dict(row)
await record_audit_event(
pool,
event_type="competitor_relationship.created",
entity_type="competitor_relationship",
entity_id=result["id"],
data={
"company_a_id": company_id,
"company_b_id": body.company_b_id,
"relationship_type": body.relationship_type,
"strength": body.strength,
"bidirectional": body.bidirectional,
"source": body.source,
},
actor="operator",
)
return result
@router.get("/companies/{company_id}/competitors", response_model=List[CompetitorRelationship])
async def list_competitors(company_id: str, request: Request):
"""List active competitor relationships for a company, ordered by strength descending."""
pool = _get_pool(request)
if not await _company_exists(pool, company_id):
raise HTTPException(404, "Company not found")
rows = await pool.fetch(
"""SELECT id, company_a_id, company_b_id, relationship_type, strength,
bidirectional, source, active, created_at, updated_at
FROM competitor_relationships
WHERE (company_a_id = $1 OR company_b_id = $1) AND active = TRUE
ORDER BY strength DESC""",
company_id,
)
return [_row_dict(r) for r in rows]
@router.put("/companies/{company_id}/competitors/{relationship_id}", response_model=CompetitorRelationship)
async def update_competitor(company_id: str, relationship_id: str, body: CompetitorRelationshipCreate, request: Request):
"""Update a competitor relationship with audit event recording previous state."""
pool = _get_pool(request)
# Fetch existing relationship
existing = await pool.fetchrow(
"""SELECT id, company_a_id, company_b_id, relationship_type, strength,
bidirectional, source, active, created_at, updated_at
FROM competitor_relationships
WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2)""",
relationship_id, company_id,
)
if not existing:
raise HTTPException(404, "Competitor relationship not found")
previous_state = _row_dict(existing)
row = await pool.fetchrow(
"""UPDATE competitor_relationships
SET relationship_type = $2, strength = $3, bidirectional = $4, source = $5, updated_at = NOW()
WHERE id = $1
RETURNING id, company_a_id, company_b_id, relationship_type, strength,
bidirectional, source, active, created_at, updated_at""",
relationship_id, body.relationship_type, body.strength, body.bidirectional, body.source,
)
result = _row_dict(row)
await record_audit_event(
pool,
event_type="competitor_relationship.updated",
entity_type="competitor_relationship",
entity_id=result["id"],
data={
"previous_state": {
"relationship_type": previous_state["relationship_type"],
"strength": previous_state["strength"],
"bidirectional": previous_state["bidirectional"],
"source": previous_state["source"],
},
"new_state": {
"relationship_type": body.relationship_type,
"strength": body.strength,
"bidirectional": body.bidirectional,
"source": body.source,
},
},
actor="operator",
)
return result
@router.delete("/companies/{company_id}/competitors/{relationship_id}", status_code=200)
async def delete_competitor(company_id: str, relationship_id: str, request: Request):
"""Soft-delete a competitor relationship (set active=False), preserve row."""
pool = _get_pool(request)
row = await pool.fetchrow(
"""UPDATE competitor_relationships
SET active = FALSE, updated_at = NOW()
WHERE id = $1 AND (company_a_id = $2 OR company_b_id = $2) AND active = TRUE
RETURNING id""",
relationship_id, company_id,
)
if not row:
raise HTTPException(404, "Active competitor relationship not found")
await record_audit_event(
pool,
event_type="competitor_relationship.deleted",
entity_type="competitor_relationship",
entity_id=str(row["id"]),
data={"company_id": company_id, "soft_deleted": True},
actor="operator",
)
return {"status": "deleted", "id": str(row["id"])}
+183
View File
@@ -0,0 +1,183 @@
"""Exposure Profile management endpoints for the Symbol Registry API."""
import json
import uuid
from datetime import datetime
from typing import Any, List
import asyncpg
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field, field_validator
router = APIRouter()
# --- Valid values ---
VALID_MARKET_POSITION_TIERS = {"global_leader", "multinational", "regional", "domestic"}
VALID_SOURCES = {"manual", "inferred"}
# --- Request/Response Models ---
class ExposureProfileCreate(BaseModel):
"""Request body for creating/updating an exposure profile."""
geographic_revenue_mix: dict[str, float] = Field(default_factory=dict)
supply_chain_regions: List[str] = Field(default_factory=list)
key_input_commodities: List[str] = Field(default_factory=list)
regulatory_jurisdictions: List[str] = Field(default_factory=list)
market_position_tier: str = "regional"
export_dependency_pct: float = 0.0
source: str = "manual"
confidence: float = 1.0
@field_validator("market_position_tier")
@classmethod
def validate_tier(cls, v: str) -> str:
if v not in VALID_MARKET_POSITION_TIERS:
raise ValueError(f"market_position_tier must be one of {VALID_MARKET_POSITION_TIERS}")
return v
@field_validator("source")
@classmethod
def validate_source(cls, v: str) -> str:
if v not in VALID_SOURCES:
raise ValueError(f"source must be one of {VALID_SOURCES}")
return v
@field_validator("export_dependency_pct", "confidence")
@classmethod
def validate_pct(cls, v: float) -> float:
if not 0.0 <= v <= 1.0:
raise ValueError("Value must be between 0.0 and 1.0")
return v
class ExposureProfileResponse(BaseModel):
"""Response model for an exposure profile."""
id: str
company_id: str
geographic_revenue_mix: dict[str, float]
supply_chain_regions: List[str]
key_input_commodities: List[str]
regulatory_jurisdictions: List[str]
market_position_tier: str
export_dependency_pct: float
source: str
confidence: float
version: int
active: bool
created_at: datetime
updated_at: datetime
def _row_to_profile(row: asyncpg.Record) -> dict[str, Any]:
"""Convert an asyncpg Record to a profile response dict."""
d = dict(row)
for k, v in d.items():
if isinstance(v, uuid.UUID):
d[k] = str(v)
# geographic_revenue_mix is stored as JSONB string, parse if needed
if isinstance(d.get("geographic_revenue_mix"), str):
d["geographic_revenue_mix"] = json.loads(d["geographic_revenue_mix"])
return d
def _get_pool(request: Request) -> asyncpg.Pool:
"""Get the database pool from the app module."""
from services.symbol_registry.app import pool
return pool
# --- Endpoints ---
@router.get("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
async def get_exposure_profile(company_id: str, request: Request):
"""Get the current active exposure profile for a company."""
pool = _get_pool(request)
row = await pool.fetchrow(
"""SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
key_input_commodities, regulatory_jurisdictions, market_position_tier,
export_dependency_pct, source, confidence, version, active,
created_at, updated_at
FROM exposure_profiles
WHERE company_id = $1 AND active = TRUE
ORDER BY version DESC
LIMIT 1""",
company_id,
)
if not row:
raise HTTPException(404, "No active exposure profile found for this company")
return _row_to_profile(row)
@router.put("/companies/{company_id}/exposure", response_model=ExposureProfileResponse)
async def upsert_exposure_profile(company_id: str, body: ExposureProfileCreate, request: Request):
"""Create or update an exposure profile. Archives the previous active version."""
pool = _get_pool(request)
# Verify company exists
exists = await pool.fetchval("SELECT 1 FROM companies WHERE id = $1", company_id)
if not exists:
raise HTTPException(404, "Company not found")
async with pool.acquire() as conn:
async with conn.transaction():
# Fetch current active profile to get latest version
current = await conn.fetchrow(
"""SELECT version FROM exposure_profiles
WHERE company_id = $1 AND active = TRUE
ORDER BY version DESC LIMIT 1""",
company_id,
)
if current:
new_version = current["version"] + 1
# Archive the current active profile
await conn.execute(
"""UPDATE exposure_profiles
SET active = FALSE, updated_at = NOW()
WHERE company_id = $1 AND active = TRUE""",
company_id,
)
else:
new_version = 1
# Insert new profile
row = await conn.fetchrow(
"""INSERT INTO exposure_profiles
(company_id, geographic_revenue_mix, supply_chain_regions,
key_input_commodities, regulatory_jurisdictions, market_position_tier,
export_dependency_pct, source, confidence, version, active)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, TRUE)
RETURNING id, company_id, geographic_revenue_mix, supply_chain_regions,
key_input_commodities, regulatory_jurisdictions, market_position_tier,
export_dependency_pct, source, confidence, version, active,
created_at, updated_at""",
company_id,
json.dumps(body.geographic_revenue_mix),
body.supply_chain_regions,
body.key_input_commodities,
body.regulatory_jurisdictions,
body.market_position_tier,
body.export_dependency_pct,
body.source,
body.confidence,
new_version,
)
return _row_to_profile(row)
@router.get("/companies/{company_id}/exposure/history", response_model=List[ExposureProfileResponse])
async def get_exposure_history(company_id: str, request: Request):
"""Get all exposure profile versions for a company, ordered by version descending."""
pool = _get_pool(request)
rows = await pool.fetch(
"""SELECT id, company_id, geographic_revenue_mix, supply_chain_regions,
key_input_commodities, regulatory_jurisdictions, market_position_tier,
export_dependency_pct, source, confidence, version, active,
created_at, updated_at
FROM exposure_profiles
WHERE company_id = $1
ORDER BY version DESC""",
company_id,
)
return [_row_to_profile(r) for r in rows]