feat: competitive intelligence & historical pattern matching layer
This commit is contained in:
+234
-4
@@ -9,13 +9,21 @@ import asyncpg
|
||||
import redis.asyncio as aioredis
|
||||
from minio import Minio
|
||||
|
||||
from services.aggregation.interpolation import (
|
||||
build_default_profile,
|
||||
compute_macro_impact_with_sector,
|
||||
filter_low_confidence_events,
|
||||
persist_macro_impact_records,
|
||||
)
|
||||
from services.extractor.client import OllamaClient
|
||||
from services.extractor.event_classifier import classify_global_event
|
||||
from services.extractor.worker import persist_extraction
|
||||
from services.shared.config import load_config
|
||||
from services.shared.logging import inject_trace_context, setup_logging
|
||||
from services.shared.redis_keys import (
|
||||
QUEUE_AGGREGATION,
|
||||
QUEUE_EXTRACTION,
|
||||
QUEUE_MACRO_CLASSIFICATION,
|
||||
queue_key,
|
||||
)
|
||||
|
||||
@@ -28,6 +36,198 @@ async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]:
|
||||
return {row["ticker"]: str(row["id"]) for row in rows}
|
||||
|
||||
|
||||
async def _fetch_document_type(pool: asyncpg.Pool, document_id: str) -> str | None:
    """Look up the ``document_type`` column for a single ``documents`` row.

    Returns ``None`` when the query yields no row for ``document_id``.
    """
    record = await pool.fetchrow(
        "SELECT document_type FROM documents WHERE id = $1::uuid",
        document_id,
    )
    if not record:
        return None
    return record["document_type"]
|
||||
|
||||
|
||||
async def _fetch_company_info(pool: asyncpg.Pool) -> list[dict]:
    """Return every active company as a plain ``dict``.

    Each dict holds the ``id``, ``ticker``, ``sector``, ``industry`` and
    ``market_cap_bucket`` columns selected from ``companies``; callers use
    them for exposure profile loading and interpolation.
    """
    records = await pool.fetch(
        """SELECT id, ticker, sector, industry, market_cap_bucket
           FROM companies WHERE active = TRUE"""
    )
    # Convert each fetched row into an ordinary dict so callers can use .get().
    return list(map(dict, records))
|
||||
|
||||
|
||||
async def _load_exposure_profile(
    pool: asyncpg.Pool,
    company_id: str,
    sector: str,
    industry: str,
    market_cap_bucket: str,
):
    """Load the exposure profile for a company, falling back to a default.

    The highest-``version`` active row of ``exposure_profiles`` for
    ``company_id`` is used when one exists; otherwise a profile is built by
    ``build_default_profile`` from the sector, industry and market-cap bucket.

    Requirements: 4.1

    NOTE(review): the intended precedence is "manual > inferred > default",
    but the query orders by ``version`` only and does not look at ``source``.
    Confirm that versioning alone guarantees the manual-over-inferred order.

    Args:
        pool: Connection pool used for the ``exposure_profiles`` lookup.
        company_id: Identifier of the company whose profile is wanted.
        sector: Company sector; an empty string is used when falsy.
        industry: Company industry; an empty string is used when falsy.
        market_cap_bucket: Market-cap bucket; ``"small_cap"`` when falsy.

    Returns:
        An ``ExposureProfileSchema`` built from the stored row, or the object
        returned by ``build_default_profile`` with ``company_id`` set.
    """
    # Imported at call time rather than at module import time.
    from services.shared.schemas import ExposureProfileSchema, MarketPositionTier

    # Try a stored (manual or inferred) profile first.
    row = await pool.fetchrow(
        """SELECT company_id, geographic_revenue_mix, supply_chain_regions,
                  key_input_commodities, regulatory_jurisdictions, market_position_tier,
                  export_dependency_pct, source, confidence, version
           FROM exposure_profiles
           WHERE company_id = $1 AND active = TRUE
           ORDER BY version DESC LIMIT 1""",
        company_id,
    )

    if not row:
        # No stored profile: fall back to the default profile.
        profile = build_default_profile(sector or "", industry or "", market_cap_bucket or "small_cap")
        profile.company_id = str(company_id)
        return profile

    # The revenue mix may arrive as a JSON-encoded string; decode it if so.
    geo_mix = row["geographic_revenue_mix"]
    if isinstance(geo_mix, str):
        geo_mix = json.loads(geo_mix)

    # Values the enum rejects with ValueError degrade to REGIONAL.
    try:
        tier = MarketPositionTier(row["market_position_tier"])
    except ValueError:
        tier = MarketPositionTier.REGIONAL

    # Only a missing (NULL) confidence defaults to 1.0. The previous
    # ``confidence or 1.0`` also rewrote a stored 0.0 to 1.0, promoting a
    # zero-confidence profile to full confidence.
    stored_confidence = row["confidence"]
    confidence = 1.0 if stored_confidence is None else float(stored_confidence)

    return ExposureProfileSchema(
        company_id=str(row["company_id"]),
        geographic_revenue_mix=geo_mix or {},
        supply_chain_regions=list(row["supply_chain_regions"] or []),
        key_input_commodities=list(row["key_input_commodities"] or []),
        regulatory_jurisdictions=list(row["regulatory_jurisdictions"] or []),
        market_position_tier=tier,
        export_dependency_pct=float(row["export_dependency_pct"] or 0.0),
        source=row["source"] or "manual",
        confidence=confidence,
        version=row["version"] or 1,
    )
|
||||
|
||||
|
||||
async def _compute_and_persist_macro_impacts(
    pool: asyncpg.Pool,
    event,
    companies: list[dict],
    confidence_threshold: float = 0.4,
) -> list[str]:
    """Score one macro event against every company and persist the hits.

    The event is first passed through ``filter_low_confidence_events``; if it
    is filtered out nothing is computed. Otherwise a macro impact record is
    computed per company, and only records whose ``macro_impact_score`` is
    greater than 0.0 are persisted.

    Requirements: 4.1, 4.5

    Returns:
        The tickers of the persisted records (one entry per record), or an
        empty list when the event was filtered out or nothing scored above 0.
    """
    # Confidence gate: bail out early when the filter drops the event.
    if not filter_low_confidence_events([event], confidence_threshold):
        logger.info("Event %s excluded: confidence %.3f below threshold %.3f",
                    event.event_id, event.confidence, confidence_threshold)
        return []

    impacted = []
    for info in companies:
        cid = str(info["id"])
        symbol = info["ticker"]
        sector_name = info.get("sector") or ""

        exposure = await _load_exposure_profile(
            pool,
            cid,
            sector_name,
            info.get("industry") or "",
            info.get("market_cap_bucket") or "small_cap",
        )

        impact = compute_macro_impact_with_sector(event, exposure, company_sector=sector_name)
        impact.ticker = symbol
        impact.company_id = cid

        if impact.macro_impact_score > 0.0:
            impacted.append(impact)

    if not impacted:
        return []

    persisted_ids = await persist_macro_impact_records(pool, impacted)
    logger.info(
        "Persisted %d macro impact records for event %s",
        len(persisted_ids), event.event_id,
    )
    return [impact.ticker for impact in impacted]
|
||||
|
||||
|
||||
# Track consecutive macro classification failures for alerting (Requirement 10.4)
|
||||
_macro_consecutive_failures = 0
|
||||
_MACRO_FAILURE_ALERT_THRESHOLD = 3
|
||||
|
||||
|
||||
def _warn_if_macro_failures_sustained() -> None:
    """Log the operator alert once the failure streak reaches the threshold."""
    if _macro_consecutive_failures >= _MACRO_FAILURE_ALERT_THRESHOLD:
        logger.critical(
            "ALERT: Sustained macro classification failures (%d consecutive). "
            "Continuing with company-only signals. Operator action required.",
            _macro_consecutive_failures,
        )


async def _process_macro_classification(
    *,
    pool: asyncpg.Pool,
    minio_client: Minio,
    ollama: OllamaClient,
    redis_client: aioredis.Redis,
    document_id: str,
    text: str,
    company_id_map: dict[str, str],
    confidence_threshold: float = 0.4,
) -> None:
    """Classify a macro_event document, score its impact, and queue aggregation.

    Steps, all inside one ``try``:
      1. ``classify_global_event`` turns the document text into an event.
      2. The module-level consecutive-failure counter is reset to 0.
      3. Macro impacts are computed and persisted for all active companies.
      4. One aggregation job per distinct affected ticker is pushed onto the
         aggregation queue.

    Any exception raised in those steps is logged and swallowed. The failure
    counter is incremented, and once it reaches
    ``_MACRO_FAILURE_ALERT_THRESHOLD`` a critical alert is logged.

    ``company_id_map`` is accepted but not read by this function.

    Requirements: 2.1, 2.2, 2.3, 4.1, 4.5, 10.4
    """
    global _macro_consecutive_failures
    aggregation_queue = queue_key(QUEUE_AGGREGATION)

    try:
        event = await classify_global_event(
            normalized_text=text,
            document_id=document_id,
            ollama_client=ollama,
            pool=pool,
            minio_client=minio_client,
        )
        logger.info(
            "Classified macro event %s for doc %s: severity=%s types=%s",
            event.event_id, document_id, event.severity, event.event_types,
        )

        # A successful classification ends any failure streak.
        _macro_consecutive_failures = 0

        # Score the event against every tracked company.
        tracked_companies = await _fetch_company_info(pool)
        affected_tickers = await _compute_and_persist_macro_impacts(
            pool, event, tracked_companies, confidence_threshold,
        )

        # Queue one aggregation job per distinct ticker, in first-seen order.
        unique_tickers = list(dict.fromkeys(affected_tickers))
        for ticker in unique_tickers:
            payload = json.dumps(inject_trace_context({
                "ticker": ticker,
                "macro_event_id": event.event_id,
            }))
            await redis_client.rpush(aggregation_queue, payload)

        logger.info(
            "Enqueued aggregation jobs for %d affected tickers after macro event %s",
            len(unique_tickers), event.event_id,
        )

    except ValueError as exc:
        _macro_consecutive_failures += 1
        logger.error("Macro event classification failed for doc %s: %s", document_id, exc)
        _warn_if_macro_failures_sustained()
    except Exception:
        _macro_consecutive_failures += 1
        logger.exception("Unexpected error classifying macro event for doc %s", document_id)
        _warn_if_macro_failures_sustained()
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
config = load_config()
|
||||
setup_logging("extractor", level=config.log_level, json_output=config.json_logs)
|
||||
@@ -42,8 +242,10 @@ async def main() -> None:
|
||||
ollama = OllamaClient(config.ollama)
|
||||
redis_client = aioredis.from_url(config.redis.url)
|
||||
queue = queue_key(QUEUE_EXTRACTION)
|
||||
macro_queue = queue_key(QUEUE_MACRO_CLASSIFICATION)
|
||||
agg_queue = queue_key(QUEUE_AGGREGATION)
|
||||
logger.info("Extractor worker started, polling %s", queue)
|
||||
confidence_threshold = config.macro.macro_confidence_threshold
|
||||
logger.info("Extractor worker started, polling %s and %s", queue, macro_queue)
|
||||
|
||||
# Pre-load company ID map (refreshed periodically)
|
||||
company_id_map = await _build_company_id_map(pool)
|
||||
@@ -51,7 +253,13 @@ async def main() -> None:
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw = await redis_client.lpop(queue)
|
||||
# Check macro classification queue first (priority)
|
||||
raw = await redis_client.lpop(macro_queue)
|
||||
is_macro_job = raw is not None
|
||||
|
||||
if raw is None:
|
||||
raw = await redis_client.lpop(queue)
|
||||
|
||||
if raw is None:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
@@ -80,13 +288,35 @@ async def main() -> None:
|
||||
except Exception as e:
|
||||
logger.warning("Could not fetch normalized text for doc %s: %s", document_id, e)
|
||||
|
||||
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
|
||||
|
||||
# Refresh company map every 100 jobs
|
||||
refresh_counter += 1
|
||||
if refresh_counter % 100 == 0:
|
||||
company_id_map = await _build_company_id_map(pool)
|
||||
|
||||
# Route macro_event documents to event classification (Requirement 2.1)
|
||||
doc_type = None
|
||||
if is_macro_job:
|
||||
doc_type = "macro_event"
|
||||
else:
|
||||
doc_type = await _fetch_document_type(pool, document_id)
|
||||
|
||||
if doc_type == "macro_event":
|
||||
logger.info("Routing macro_event doc %s to event classifier", document_id)
|
||||
await _process_macro_classification(
|
||||
pool=pool,
|
||||
minio_client=minio_client,
|
||||
ollama=ollama,
|
||||
redis_client=redis_client,
|
||||
document_id=document_id,
|
||||
text=text,
|
||||
company_id_map=company_id_map,
|
||||
confidence_threshold=confidence_threshold,
|
||||
)
|
||||
continue
|
||||
|
||||
# Standard extraction pipeline for non-macro documents
|
||||
logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
|
||||
|
||||
try:
|
||||
# Pass all tracked tickers so the model can identify any mentioned companies
|
||||
all_tickers = list(company_id_map.keys()) if company_id_map else ([ticker] if ticker else None)
|
||||
|
||||
Reference in New Issue
Block a user