diff --git a/services/extractor/client.py b/services/extractor/client.py
index 7d4d344..3d7737c 100644
--- a/services/extractor/client.py
+++ b/services/extractor/client.py
@@ -117,7 +117,9 @@ class OllamaClient:
         self._max_delay = config.retry_max_delay
         self._backoff_multiplier = config.retry_backoff_multiplier
         self._owns_client = http_client is None
-        self._http = http_client or httpx.AsyncClient(timeout=config.timeout)
+        self._http = http_client or httpx.AsyncClient(
+            timeout=httpx.Timeout(config.timeout, read=config.timeout),
+        )
 
     async def close(self) -> None:
         """Close the underlying HTTP client if we own it."""
@@ -208,7 +210,13 @@ class OllamaClient:
         json_schema: dict[str, object],
         document_text: str = "",
     ) -> ExtractionAttempt:
-        """Make a single call to the Ollama /api/chat endpoint."""
+        """Make a streaming call to Ollama with early-termination guardrails.
+
+        Aborts the stream if:
+        - Total generated tokens exceed ``max_tokens``
+        - No new chunk arrives within ``stall_timeout`` seconds
+        - Repetition loop detected in the last ``loop_window`` tokens
+        """
         attempt = ExtractionAttempt(model=self._config.model)
         start = time.monotonic()
 
@@ -219,19 +227,25 @@ class OllamaClient:
                 {"role": "user", "content": prompts["user"]},
             ],
             "format": json_schema,
-            "stream": False,
+            "stream": True,
             "think": False,
         }
         url = f"{self._config.base_url}/api/chat"
         logger.info(
-            "Ollama POST %s model=%s input_chars=%d",
+            "Ollama POST %s model=%s input_chars=%d (streaming)",
             url,
             self._config.model,
             len(prompts.get("user", "")),
         )
         try:
-            resp = await self._http.post(url, json=payload)
+            req = self._http.build_request("POST", url, json=payload)
+            resp = await self._http.send(req, stream=True)
+            if resp.is_error:
+                # A streamed response has no body loaded and stays open when
+                # raise_for_status() raises. Buffer it first so error
+                # handling can read the body and the connection is released.
+                _ = await resp.aread()
             _ = resp.raise_for_status()
         except httpx.TimeoutException:
             attempt.error = "timeout"
             attempt.duration_ms = int((time.monotonic() - start) * 1000)
@@ -246,18 +260,87 @@ class OllamaClient:
             attempt.duration_ms = int((time.monotonic() - start) * 1000)
             return attempt
 
+        # Stream and accumulate with guardrails
+        import asyncio  # function-local: needed only for the stall timeout
+
+        chunks: list[str] = []
+        token_count = 0
+        abort_reason: str | None = None
+        lines = resp.aiter_lines()
+
+        try:
+            while True:
+                # Guard: stall detection. Bound the wait for the next line;
+                # a check made after a chunk arrives can never see a stall.
+                try:
+                    line = await asyncio.wait_for(
+                        lines.__anext__(),
+                        timeout=self._config.stall_timeout,
+                    )
+                except StopAsyncIteration:
+                    break
+                except asyncio.TimeoutError:
+                    abort_reason = "stall_timeout"
+                    break
+
+                if not line:
+                    continue
+                try:
+                    frame = json.loads(line)
+                except json.JSONDecodeError:
+                    continue
+                if not isinstance(frame, dict):
+                    continue
+
+                # Ollama reports mid-stream failures as an ``error`` frame.
+                if frame.get("error"):
+                    abort_reason = f"ollama_error ({frame['error']})"
+                    break
+
+                # Read content before honouring ``done`` so text carried on
+                # the final frame is not dropped.
+                done = bool(frame.get("done"))
+                msg = frame.get("message")
+                token = msg.get("content", "") if isinstance(msg, dict) else ""
+                if not isinstance(token, str) or not token:
+                    if done:
+                        break
+                    continue
+
+                chunks.append(token)
+                token_count += 1
+
+                # Guard: max tokens
+                if token_count > self._config.max_tokens:
+                    abort_reason = f"max_tokens_exceeded ({token_count})"
+                    break
+
+                # Guard: repetition loop detection
+                if token_count >= self._config.loop_window:
+                    window = chunks[-self._config.loop_window:]
+                    unique_ratio = len(set(window)) / len(window)
+                    if unique_ratio < self._config.loop_threshold:
+                        abort_reason = f"repetition_loop (unique_ratio={unique_ratio:.2f})"
+                        break
+
+                if done:
+                    break
+        except httpx.ReadTimeout:
+            abort_reason = "read_timeout"
+        finally:
+            await resp.aclose()
+
         attempt.duration_ms = int((time.monotonic() - start) * 1000)
 
-        # Parse the Ollama response envelope
-        try:
-            body: dict[str, object] = resp.json()
-        except json.JSONDecodeError:
-            attempt.error = "invalid_response_json"
-            attempt.raw_output = resp.text
+        if abort_reason:
+            logger.warning(
+                "Stream aborted after %d tokens: %s", token_count, abort_reason,
+            )
+            attempt.error = abort_reason
+            attempt.raw_output = "".join(chunks)
             return attempt
 
-        msg = body.get("message")
-        content: str = msg.get("content", "") if isinstance(msg, dict) else ""
+        content = "".join(chunks)
         attempt.raw_output = content
 
         if not content:
diff --git a/services/extractor/prompts.py b/services/extractor/prompts.py
index c4ca28a..0303f2a 100644
--- a/services/extractor/prompts.py
+++ b/services/extractor/prompts.py
@@ -114,6 +114,8 @@ Fill these fields:
 
 For each company entry fill: ticker, company_name, relevance (0-1), sentiment, impact_score (0-1), impact_horizon, catalyst_type, key_facts (list), risks (list), evidence_spans (verbatim quotes from text).
 
+catalyst_type MUST be exactly one of: earnings, product, legal, macro, supply_chain, m_and_a, rating_change, other. Use "other" if none of the specific categories fit.
+
 --- DOCUMENT TEXT ---
 {document_text}
 --- END DOCUMENT TEXT ---"""
diff --git a/services/extractor/schemas.py b/services/extractor/schemas.py
index 547f2c6..dd21068 100644
--- a/services/extractor/schemas.py
+++ b/services/extractor/schemas.py
@@ -194,6 +194,39 @@ def validate_extraction(
 # Normalize model output before validation
 # ---------------------------------------------------------------------------
 
+_CATALYST_ALIASES: dict[str, str] = {
+    "strategic pivot": "other",
+    "strategic": "other",
+    "restructuring": "other",
+    "partnership": "other",
+    "acquisition": "m_and_a",
+    "merger": "m_and_a",
+    "buyout": "m_and_a",
+    "lawsuit": "legal",
+    "regulation": "legal",
+    "regulatory": "legal",
+    "upgrade": "rating_change",
+    "downgrade": "rating_change",
+    "price target": "rating_change",
+    "inflation": "macro",
+    "interest rate": "macro",
+    "interest rates": "macro",
+    "tariff": "macro",
+    "tariffs": "macro",
+    "launch": "product",
+    "product launch": "product",
+    "revenue": "earnings",
+    "profit": "earnings",
+    "guidance": "earnings",
+    "supply": "supply_chain",
+    "shortage": "supply_chain",
+}
+
+_VALID_CATALYSTS = frozenset({
+    "earnings", "product", "legal", "macro",
+    "supply_chain", "m_and_a", "rating_change", "other",
+})
+
 _HORIZON_MAP: dict[str, str] = {
     "long-term": "90d_plus",
     "long": "90d_plus",
@@ -233,6 +266,14 @@ def _normalize_extraction_data(data: dict[str, Any]) -> dict[str, Any]:
                 mapped = _HORIZON_MAP.get(horizon.lower().strip())
                 if mapped:
                     comp["impact_horizon"] = mapped
+            # Canonicalize catalyst_type: normalize case/whitespace first so
+            # values such as "Earnings " are stored in canonical form, then
+            # map known aliases; anything unrecognized becomes "other".
+            cat = comp.get("catalyst_type", "")
+            if isinstance(cat, str):
+                key = cat.lower().strip()
+                if key not in _VALID_CATALYSTS:
+                    key = _CATALYST_ALIASES.get(key, "other")
+                comp["catalyst_type"] = key
 
     return data
 
diff --git a/services/shared/config.py b/services/shared/config.py
index d052ade..7f7c9b9 100644
--- a/services/shared/config.py
+++ b/services/shared/config.py
@@ -47,6 +47,10 @@ class OllamaConfig:
     retry_base_delay: float = 1.0
     retry_max_delay: float = 10.0
     retry_backoff_multiplier: float = 2.0
+    max_tokens: int = 4096  # abort the stream after this many content chunks
+    stall_timeout: float = 30.0  # max seconds to wait for the next stream line
+    loop_window: int = 64  # trailing content chunks inspected for repetition
+    loop_threshold: float = 0.5  # abort when unique/total in the window is below this
 
 
 @dataclass