diff --git a/services/aggregation/main.py b/services/aggregation/main.py index 522862b..7821b85 100644 --- a/services/aggregation/main.py +++ b/services/aggregation/main.py @@ -6,11 +6,16 @@ import json import logging import asyncpg +import redis.asyncio as aioredis from services.aggregation.worker import aggregate_company from services.shared.config import load_config -from services.shared.logging import setup_logging -from services.shared.redis_keys import QUEUE_AGGREGATION, queue_key +from services.shared.logging import inject_trace_context, setup_logging +from services.shared.redis_keys import ( + QUEUE_AGGREGATION, + QUEUE_RECOMMENDATION, + queue_key, +) logger = logging.getLogger("aggregation_main") @@ -20,11 +25,9 @@ async def main() -> None: setup_logging("aggregation", level=config.log_level, json_output=config.json_logs) pool = await asyncpg.create_pool(dsn=config.postgres.dsn, min_size=2, max_size=8) - - import redis.asyncio as aioredis - redis_client = aioredis.from_url(config.redis.url) queue = queue_key(QUEUE_AGGREGATION) + rec_queue = queue_key(QUEUE_RECOMMENDATION) logger.info("Aggregation worker started, polling %s", queue) try: @@ -34,8 +37,7 @@ async def main() -> None: await asyncio.sleep(1) continue - payload = raw - job = json.loads(payload) + job = json.loads(raw) ticker = job.get("ticker", "") logger.info("Processing aggregation job for %s", ticker) @@ -46,6 +48,17 @@ async def main() -> None: "Aggregation complete for %s: %d windows", ticker, len(summaries), ) + + # Enqueue recommendation job for each window that produced a trend + for summary in summaries: + if summary.trend_strength > 0: + await redis_client.rpush( + rec_queue, + json.dumps(inject_trace_context({ + "ticker": ticker, + "window": summary.window.value, + })), + ) except Exception: logger.exception("Aggregation failed for %s", ticker) finally: diff --git a/services/extractor/main.py b/services/extractor/main.py index eb2d1b3..a7e25c7 100644 --- a/services/extractor/main.py +++ b/services/extractor/main.py @@ -2,20 +2,32 @@ from __future__ import annotations import asyncio +import json import logging import asyncpg +import redis.asyncio as aioredis from minio import Minio from services.extractor.client import OllamaClient from services.extractor.worker import persist_extraction from services.shared.config import load_config -from services.shared.logging import setup_logging -from services.shared.redis_keys import QUEUE_EXTRACTION, queue_key +from services.shared.logging import inject_trace_context, setup_logging +from services.shared.redis_keys import ( + QUEUE_AGGREGATION, + QUEUE_EXTRACTION, + queue_key, +) logger = logging.getLogger("extractor_main") +async def _build_company_id_map(pool: asyncpg.Pool) -> dict[str, str]: + """Build a ticker -> company_id mapping from the companies table.""" + rows = await pool.fetch("SELECT id, ticker FROM companies WHERE active = TRUE") + return {row["ticker"]: str(row["id"]) for row in rows} + + async def main() -> None: config = load_config() setup_logging("extractor", level=config.log_level, json_output=config.json_logs) @@ -28,15 +40,15 @@ async def main() -> None: secure=config.minio.secure, ) ollama = OllamaClient(config.ollama) - - import json - - import redis.asyncio as aioredis - redis_client = aioredis.from_url(config.redis.url) queue = queue_key(QUEUE_EXTRACTION) + agg_queue = queue_key(QUEUE_AGGREGATION) logger.info("Extractor worker started, polling %s", queue) + # Pre-load company ID map (refreshed periodically) + company_id_map = await _build_company_id_map(pool) + refresh_counter = 0 + try: while True: raw = await redis_client.lpop(queue) @@ -44,24 +56,40 @@ async def main() -> None: await asyncio.sleep(1) continue - payload = raw - job = json.loads(payload) + job = json.loads(raw) document_id = job.get("document_id", "") ticker = job.get("ticker", "") text = job.get("text", "") or job.get("normalized_text", "") logger.info("Processing extraction job for doc %s / %s", document_id, ticker) + # Refresh company map every 100 jobs + refresh_counter += 1 + if refresh_counter % 100 == 0: + company_id_map = await _build_company_id_map(pool) + try: - extraction_response = await ollama.extract(text) - await persist_extraction( + extraction_response = await ollama.extract( + text, + document_id=document_id, + known_tickers=[ticker] if ticker else None, + ) + result = await persist_extraction( pool=pool, minio_client=minio_client, document_id=document_id, ticker=ticker, extraction_response=extraction_response, + company_id_map=company_id_map, document_text_length=len(text), ) + + # Enqueue aggregation job for the ticker on success + if result.success and ticker: + await redis_client.rpush( + agg_queue, + json.dumps(inject_trace_context({"ticker": ticker})), + ) except Exception: logger.exception("Extraction failed for doc %s", document_id) finally: