phase 17: wire extractor→aggregation→recommendation queue chain, add company_id_map to extractor

This commit is contained in:
Celes Renata
2026-04-12 03:16:27 -07:00
parent 226cc3ff44
commit 012b973bb7
2 changed files with 59 additions and 18 deletions
+20 -7
View File
@@ -6,11 +6,16 @@ import json
import logging import logging
import asyncpg import asyncpg
import redis.asyncio as aioredis
from services.aggregation.worker import aggregate_company from services.aggregation.worker import aggregate_company
from services.shared.config import load_config from services.shared.config import load_config
from services.shared.logging import setup_logging from services.shared.logging import inject_trace_context, setup_logging
from services.shared.redis_keys import QUEUE_AGGREGATION, queue_key from services.shared.redis_keys import (
QUEUE_AGGREGATION,
QUEUE_RECOMMENDATION,
queue_key,
)
logger = logging.getLogger("aggregation_main") logger = logging.getLogger("aggregation_main")
@@ -20,11 +25,9 @@ async def main() -> None:
setup_logging("aggregation", level=config.log_level, json_output=config.json_logs) setup_logging("aggregation", level=config.log_level, json_output=config.json_logs)
pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8) pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8)
import redis.asyncio as aioredis
redis_client = aioredis.from_url(config.redis.url) redis_client = aioredis.from_url(config.redis.url)
queue = queue_key(QUEUE_AGGREGATION) queue = queue_key(QUEUE_AGGREGATION)
rec_queue = queue_key(QUEUE_RECOMMENDATION)
logger.info("Aggregation worker started, polling %s", queue) logger.info("Aggregation worker started, polling %s", queue)
try: try:
@@ -34,8 +37,7 @@ async def main() -> None:
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
payload = raw job = json.loads(raw)
job = json.loads(payload)
ticker = job.get("ticker", "") ticker = job.get("ticker", "")
logger.info("Processing aggregation job for %s", ticker) logger.info("Processing aggregation job for %s", ticker)
@@ -46,6 +48,17 @@ async def main() -> None:
"Aggregation complete for %s: %d windows", "Aggregation complete for %s: %d windows",
ticker, len(summaries), ticker, len(summaries),
) )
# Enqueue recommendation job for each window that produced a trend
for summary in summaries:
if summary.trend_strength > 0:
await redis_client.rpush(
rec_queue,
json.dumps(inject_trace_context({
"ticker": ticker,
"window": summary.window.value,
})),
)
except Exception: except Exception:
logger.exception("Aggregation failed for %s", ticker) logger.exception("Aggregation failed for %s", ticker)
finally: finally:
+39 -11
View File
@@ -2,20 +2,32 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import json
import logging import logging
import asyncpg import asyncpg
import redis.asyncio as aioredis
from minio import Minio from minio import Minio
from services.extractor.client import OllamaClient from services.extractor.client import OllamaClient
from services.extractor.worker import persist_extraction from services.extractor.worker import persist_extraction
from services.shared.config import load_config from services.shared.config import load_config
from services.shared.logging import setup_logging from services.shared.logging import inject_trace_context, setup_logging
from services.shared.redis_keys import QUEUE_EXTRACTION, queue_key from services.shared.redis_keys import (
QUEUE_AGGREGATION,
QUEUE_EXTRACTION,
queue_key,
)
logger = logging.getLogger("extractor_main") logger = logging.getLogger("extractor_main")
async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]:
"""Build a ticker -> company_id mapping from the companies table."""
rows = await pool.fetch("SELECT id, ticker FROM companies WHERE active = TRUE")
return {row["ticker"]: str(row["id"]) for row in rows}
async def main() -> None: async def main() -> None:
config = load_config() config = load_config()
setup_logging("extractor", level=config.log_level, json_output=config.json_logs) setup_logging("extractor", level=config.log_level, json_output=config.json_logs)
@@ -28,15 +40,15 @@ async def main() -> None:
secure=config.minio.secure, secure=config.minio.secure,
) )
ollama = OllamaClient(config.ollama) ollama = OllamaClient(config.ollama)
import json
import redis.asyncio as aioredis
redis_client = aioredis.from_url(config.redis.url) redis_client = aioredis.from_url(config.redis.url)
queue = queue_key(QUEUE_EXTRACTION) queue = queue_key(QUEUE_EXTRACTION)
agg_queue = queue_key(QUEUE_AGGREGATION)
logger.info("Extractor worker started, polling %s", queue) logger.info("Extractor worker started, polling %s", queue)
# Pre-load company ID map (refreshed periodically)
company_id_map = await _build_company_id_map(pool)
refresh_counter = 0
try: try:
while True: while True:
raw = await redis_client.lpop(queue) raw = await redis_client.lpop(queue)
@@ -44,24 +56,40 @@ async def main() -> None:
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
payload = raw job = json.loads(raw)
job = json.loads(payload)
document_id = job.get("document_id", "") document_id = job.get("document_id", "")
ticker = job.get("ticker", "") ticker = job.get("ticker", "")
text = job.get("text", "") or job.get("normalized_text", "") text = job.get("text", "") or job.get("normalized_text", "")
logger.info("Processing extraction job for doc %s / %s", document_id, ticker) logger.info("Processing extraction job for doc %s / %s", document_id, ticker)
# Refresh company map every 100 jobs
refresh_counter += 1
if refresh_counter % 100 == 0:
company_id_map = await _build_company_id_map(pool)
try: try:
extraction_response = await ollama.extract(text) extraction_response = await ollama.extract(
await persist_extraction( text,
document_id=document_id,
known_tickers=[ticker] if ticker else None,
)
result = await persist_extraction(
pool=pool, pool=pool,
minio_client=minio_client, minio_client=minio_client,
document_id=document_id, document_id=document_id,
ticker=ticker, ticker=ticker,
extraction_response=extraction_response, extraction_response=extraction_response,
company_id_map=company_id_map,
document_text_length=len(text), document_text_length=len(text),
) )
# Enqueue aggregation job for the ticker on success
if result.success and ticker:
await redis_client.rpush(
agg_queue,
json.dumps(inject_trace_context({"ticker": ticker})),
)
except Exception: except Exception:
logger.exception("Extraction failed for doc %s", document_id) logger.exception("Extraction failed for doc %s", document_id)
finally: finally: