"""Tests for extraction prompt templates.""" import json from services.extractor.prompts import ( EXTRACTION_JSON_SCHEMA, PROMPT_VERSION, SCHEMA_VERSION, SYSTEM_PROMPT, build_extraction_prompt, get_json_schema, get_prompt_metadata, ) from services.shared.schemas import CatalystType, DocumentType, Sentiment def test_build_extraction_prompt_basic(): """Prompt contains system and user keys with document text embedded.""" result = build_extraction_prompt( document_text="Apple reported record Q4 earnings.", document_type=DocumentType.ARTICLE, ) assert "system" in result assert "user" in result assert "Apple reported record Q4 earnings." in result["user"] assert "DOCUMENT TEXT" in result["user"] def test_system_prompt_has_anti_hallucination_rules(): """System prompt includes key anti-hallucination instructions.""" assert "ONLY a single JSON object" in SYSTEM_PROMPT assert "No markdown fences" in SYSTEM_PROMPT assert "evidence_spans" in SYSTEM_PROMPT or "short" in SYSTEM_PROMPT assert "Use \"other\" for catalyst_type if unsure" in SYSTEM_PROMPT assert "required" in SYSTEM_PROMPT def test_build_prompt_includes_json_schema(): """User prompt embeds field instructions for structured output.""" result = build_extraction_prompt(document_text="test", document_type=DocumentType.ARTICLE) # The user prompt includes field-level instructions instead of the raw JSON schema assert "summary" in result["user"] assert "companies" in result["user"] assert "evidence_spans" in result["user"] def test_build_prompt_with_known_tickers(): """Known tickers are included as hints but with a warning not to force-include them.""" result = build_extraction_prompt( document_text="Some text", document_type=DocumentType.ARTICLE, known_tickers=["AAPL", "MSFT"], ) assert "AAPL" in result["user"] assert "MSFT" in result["user"] assert "Do NOT invent tickers not in the list above" in result["user"] def test_build_prompt_without_tickers(): """When no tickers are provided, no ticker hint appears.""" result = build_extraction_prompt(document_text="Some text", document_type=DocumentType.ARTICLE) assert "may be referenced" not in result["user"] def test_build_prompt_document_type_guidance(): """Each document type gets specific guidance in the prompt.""" for dtype in DocumentType: result = build_extraction_prompt(document_text="text", document_type=dtype) assert "Document type:" in result["user"] def test_build_prompt_filing_guidance(): """Filing documents get SEC-specific guidance.""" result = build_extraction_prompt(document_text="text", document_type=DocumentType.FILING) assert "regulatory filing" in result["user"] def test_build_prompt_transcript_guidance(): """Transcript documents get earnings-call-specific guidance.""" result = build_extraction_prompt(document_text="text", document_type=DocumentType.TRANSCRIPT) assert "forward-looking" in result["user"] def test_build_prompt_with_document_id(): """Document ID is included in the prompt when provided.""" result = build_extraction_prompt( document_text="text", document_type=DocumentType.ARTICLE, document_id="abc-123", ) assert "abc-123" in result["user"] def test_get_prompt_metadata(): """Metadata returns current prompt and schema versions.""" meta = get_prompt_metadata() assert meta["prompt_version"] == PROMPT_VERSION assert meta["schema_version"] == SCHEMA_VERSION def test_get_json_schema_is_valid(): """JSON schema has required top-level structure.""" schema = get_json_schema() assert schema["type"] == "object" assert "summary" in schema["required"] assert "companies" in schema["required"] assert "confidence" in schema["required"] def test_json_schema_enum_values_match_pydantic(): """Schema enum values match the Pydantic enum definitions.""" company_props = EXTRACTION_JSON_SCHEMA["properties"]["companies"]["items"]["properties"] assert set(company_props["sentiment"]["enum"]) == {s.value for s in Sentiment} assert set(company_props["catalyst_type"]["enum"]) == {c.value for c in CatalystType} def test_json_schema_is_serializable(): """Schema can be serialized to JSON without errors.""" serialized = json.dumps(EXTRACTION_JSON_SCHEMA) parsed = json.loads(serialized) assert parsed["type"] == "object"