Files
stonks-oracle/tests/test_extractor_prompts.py
T

121 lines
4.4 KiB
Python

"""Tests for extraction prompt templates."""
import json
from services.extractor.prompts import (
EXTRACTION_JSON_SCHEMA,
PROMPT_VERSION,
SCHEMA_VERSION,
SYSTEM_PROMPT,
build_extraction_prompt,
get_json_schema,
get_prompt_metadata,
)
from services.shared.schemas import CatalystType, DocumentType, Sentiment
def test_build_extraction_prompt_basic():
    """Check the basic shape of a prompt built for an article.

    The returned mapping must contain the "system" and "user" keys, and the
    "user" entry must contain both the document text that was passed in and
    the "DOCUMENT TEXT" marker.
    """
    document_text = "Apple reported record Q4 earnings."
    prompt = build_extraction_prompt(
        document_text=document_text,
        document_type=DocumentType.ARTICLE,
    )

    for key in ("system", "user"):
        assert key in prompt

    user_prompt = prompt["user"]
    assert document_text in user_prompt
    assert "DOCUMENT TEXT" in user_prompt
def test_system_prompt_has_anti_hallucination_rules():
    """Check that SYSTEM_PROMPT contains each anti-hallucination phrase."""
    required_phrases = (
        "NEVER fabricate",
        "NEVER infer",
        "verbatim quotes",
        "ONLY extract information explicitly stated",
        "insufficient_content",
    )
    # Phrases are checked in order, so the first missing one fails the test.
    for phrase in required_phrases:
        assert phrase in SYSTEM_PROMPT
def test_build_prompt_includes_json_schema():
    """Check that quoted schema field names appear in the user prompt."""
    prompt = build_extraction_prompt(
        document_type=DocumentType.ARTICLE,
        document_text="test",
    )
    user_prompt = prompt["user"]

    # The field names are searched for with their surrounding double quotes,
    # i.e. in the form they take once the schema is serialized as JSON.
    for quoted_field in ('"summary"', '"companies"', '"evidence_spans"'):
        assert quoted_field in user_prompt
def test_build_prompt_with_known_tickers():
    """Check that known tickers and the accompanying caveat reach the prompt.

    Both ticker symbols passed via ``known_tickers`` must appear in the user
    prompt, together with the "Do NOT include a ticker just because" warning.
    """
    prompt = build_extraction_prompt(
        known_tickers=["AAPL", "MSFT"],
        document_type=DocumentType.ARTICLE,
        document_text="Some text",
    )
    user_prompt = prompt["user"]

    for symbol in ("AAPL", "MSFT"):
        assert symbol in user_prompt
    assert "Do NOT include a ticker just because" in user_prompt
def test_build_prompt_without_tickers():
    """Check that the ticker hint is absent when no tickers are passed."""
    prompt = build_extraction_prompt(
        document_type=DocumentType.ARTICLE,
        document_text="Some text",
    )
    user_prompt = prompt["user"]
    assert "may be referenced" not in user_prompt
def test_build_prompt_document_type_guidance():
    """Check that every DocumentType yields a "Document type:" label.

    The prompt is built once per member of ``DocumentType`` and the user
    prompt is checked for the "Document type:" label. The content of any
    type-specific guidance is not inspected here.
    """
    for dtype in DocumentType:
        result = build_extraction_prompt(document_text="text", document_type=dtype)
        # Name the member in the failure message; otherwise a failure inside
        # the loop does not reveal which document type broke.
        assert "Document type:" in result["user"], (
            f"user prompt for {dtype!r} is missing the 'Document type:' label"
        )
def test_build_prompt_filing_guidance():
    """Check that a FILING prompt mentions "regulatory filing"."""
    prompt = build_extraction_prompt(
        document_type=DocumentType.FILING,
        document_text="text",
    )
    user_prompt = prompt["user"]
    assert "regulatory filing" in user_prompt
def test_build_prompt_transcript_guidance():
    """Check that a TRANSCRIPT prompt mentions "forward-looking"."""
    prompt = build_extraction_prompt(
        document_type=DocumentType.TRANSCRIPT,
        document_text="text",
    )
    user_prompt = prompt["user"]
    assert "forward-looking" in user_prompt
def test_build_prompt_with_document_id():
    """Check that a supplied document ID shows up in the user prompt."""
    doc_id = "abc-123"
    prompt = build_extraction_prompt(
        document_id=doc_id,
        document_type=DocumentType.ARTICLE,
        document_text="text",
    )
    assert doc_id in prompt["user"]
def test_get_prompt_metadata():
    """Check that the metadata reports PROMPT_VERSION and SCHEMA_VERSION."""
    metadata = get_prompt_metadata()
    expected_versions = {
        "prompt_version": PROMPT_VERSION,
        "schema_version": SCHEMA_VERSION,
    }
    for field, version in expected_versions.items():
        assert metadata[field] == version
def test_get_json_schema_is_valid():
    """Check the top-level type and required fields of the JSON schema."""
    schema = get_json_schema()
    assert schema["type"] == "object"

    required_fields = schema["required"]
    for field in ("summary", "companies", "confidence"):
        assert field in required_fields
def test_json_schema_enum_values_match_pydantic():
    """Check that the schema's enum lists agree with the Python enums.

    The ``sentiment`` and ``catalyst_type`` enum lists on the per-company
    item schema are compared, as sets, with the values of the ``Sentiment``
    and ``CatalystType`` members.
    """
    company_item_schema = EXTRACTION_JSON_SCHEMA["properties"]["companies"]["items"]
    props = company_item_schema["properties"]

    expected_sentiments = {member.value for member in Sentiment}
    expected_catalysts = {member.value for member in CatalystType}

    # Set comparison: order and duplicates in the schema lists are ignored.
    assert set(props["sentiment"]["enum"]) == expected_sentiments
    assert set(props["catalyst_type"]["enum"]) == expected_catalysts
def test_json_schema_is_serializable():
    """Check that the schema survives a JSON dump/load round trip."""
    round_tripped = json.loads(json.dumps(EXTRACTION_JSON_SCHEMA))
    assert round_tripped["type"] == "object"