114 lines
3.6 KiB
Python
114 lines
3.6 KiB
Python
"""Tests for the optional LLM thesis rewriting layer.
|
|
|
|
Tests prompt construction and the rewrite function's fallback behavior.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from services.recommendation.thesis_llm import (
|
|
THESIS_SYSTEM_PROMPT,
|
|
build_thesis_rewrite_prompt,
|
|
rewrite_thesis_with_llm,
|
|
)
|
|
from services.shared.config import OllamaConfig
|
|
from services.shared.schemas import (
|
|
TrendDirection,
|
|
TrendSummary,
|
|
TrendWindow,
|
|
)
|
|
|
|
|
|
def _make_summary(
    ticker: str = "AAPL",
    direction: TrendDirection = TrendDirection.BULLISH,
    strength: float = 0.5,
    confidence: float = 0.65,
    contradiction: float = 0.1,
    catalysts: list[str] | None = None,
    risks: list[str] | None = None,
) -> TrendSummary:
    """Build a seven-day company ``TrendSummary`` fixture for the tests below.

    The defaults line up with the facts stated in ``DETERMINISTIC_THESIS``
    (AAPL, bullish, strength 0.50, confidence 0.65, "earnings",
    "regulatory scrutiny", 2 supporting / 0 opposing evidence documents).

    Args:
        ticker: Stored as the summary's ``entity_id``.
        direction: Value for ``trend_direction``.
        strength: Value for ``trend_strength``.
        confidence: Value for ``confidence``.
        contradiction: Value for ``contradiction_score``.
        catalysts: Value for ``dominant_catalysts``. ``None`` selects
            ``["earnings"]``; an explicit empty list is kept as-is.
        risks: Value for ``material_risks``. ``None`` selects
            ``["regulatory scrutiny"]``; an explicit empty list is kept as-is.

    Returns:
        A ``TrendSummary`` with supporting evidence ``["doc1", "doc2"]`` and
        no opposing evidence.
    """
    return TrendSummary(
        entity_type="company",
        entity_id=ticker,
        window=TrendWindow.SEVEN_DAY,
        trend_direction=direction,
        trend_strength=strength,
        confidence=confidence,
        top_supporting_evidence=["doc1", "doc2"],
        top_opposing_evidence=[],
        # Compare against None rather than truthiness: ``[] or default`` would
        # silently replace a deliberately empty list with the default.
        dominant_catalysts=["earnings"] if catalysts is None else catalysts,
        material_risks=["regulatory scrutiny"] if risks is None else risks,
        contradiction_score=contradiction,
    )
|
|
|
|
|
|
# Reference thesis text fed to the prompt builder and expected back from the
# fallback path. Its facts match the defaults of ``_make_summary()``.
DETERMINISTIC_THESIS = "".join(
    (
        "AAPL shows a bullish trend over the 7d window with strength 0.50 ",
        "and confidence 0.65. Dominant catalysts: earnings. ",
        "Key risks: regulatory scrutiny. ",
        "Based on 2 supporting and 0 opposing evidence documents. ",
        "Recommendation: BUY (paper eligible).",
    )
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Prompt construction
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_prompt_contains_deterministic_thesis():
    """The user message embeds the deterministic thesis verbatim."""
    user_message = build_thesis_rewrite_prompt(
        DETERMINISTIC_THESIS, _make_summary()
    )["user"]
    assert DETERMINISTIC_THESIS in user_message
|
|
|
|
|
|
def test_prompt_system_is_thesis_system_prompt():
    """The system message is exactly the module-level THESIS_SYSTEM_PROMPT."""
    prompt_pair = build_thesis_rewrite_prompt(DETERMINISTIC_THESIS, _make_summary())
    system_message = prompt_pair["system"]
    assert system_message == THESIS_SYSTEM_PROMPT
|
|
|
|
|
|
def test_prompt_includes_ticker_context():
    """The summary's ticker reaches the user message.

    The thesis text mentions AAPL, so using MSFT here proves the ticker comes
    from the summary context rather than from the thesis string.
    """
    msft_summary = _make_summary(ticker="MSFT")
    built = build_thesis_rewrite_prompt(DETERMINISTIC_THESIS, msft_summary)
    assert "MSFT" in built["user"]
|
|
|
|
|
|
def test_prompt_includes_catalysts():
    """Every dominant catalyst on the summary appears in the user message.

    Neither catalyst occurs in DETERMINISTIC_THESIS, so a match can only come
    from the summary context.
    """
    summary = _make_summary(catalysts=["product", "m_and_a"])
    prompts = build_thesis_rewrite_prompt(DETERMINISTIC_THESIS, summary)
    # Check both entries, not just the first, so a builder that drops or
    # truncates later catalysts is caught.
    assert "product" in prompts["user"]
    assert "m_and_a" in prompts["user"]
|
|
|
|
|
|
def test_prompt_includes_risks():
    """Material risks from the summary are rendered into the user message."""
    risk_text = "supply chain disruption"
    built = build_thesis_rewrite_prompt(
        DETERMINISTIC_THESIS, _make_summary(risks=[risk_text])
    )
    assert risk_text in built["user"]
|
|
|
|
|
|
def test_prompt_includes_trend_metrics():
    """Strength, confidence and contradiction values reach the user message.

    The values differ from the 0.50 / 0.65 in DETERMINISTIC_THESIS, so each
    match must come from the summary context.
    """
    user_message = build_thesis_rewrite_prompt(
        DETERMINISTIC_THESIS,
        _make_summary(strength=0.72, confidence=0.88, contradiction=0.15),
    )["user"]
    for rendered_metric in ("0.72", "0.88", "0.15"):
        assert rendered_metric in user_message
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Fallback behavior: an LLM failure returns the deterministic thesis
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_rewrite_falls_back_on_connection_error():
    """When Ollama is unreachable, the deterministic thesis is returned."""
    summary = _make_summary()
    # Use a syntactically valid, normally-closed loopback port so the failure
    # is a refused connection. A port above 65535 (e.g. 99999) is not a valid
    # TCP port and would typically be rejected while parsing the URL. That
    # would exercise a URL-validation error rather than the connection-error
    # path this test is named for.
    config = OllamaConfig(base_url="http://127.0.0.1:1", timeout=2)
    result = await rewrite_thesis_with_llm(
        deterministic_thesis=DETERMINISTIC_THESIS,
        summary=summary,
        config=config,
    )
    assert result == DETERMINISTIC_THESIS
|