"""Unit tests for exposure profile auto-inference. Requirements: 9.1, 9.2, 9.3 """ from __future__ import annotations from services.extractor.exposure_inference import ( _compute_inference_confidence, _estimate_revenue_mix, _extract_commodities_from_text, _extract_regions_from_text, infer_exposure_profile, ) from services.shared.schemas import ( CatalystType, CompanyImpact, DocumentIntelligence, DocumentType, MarketPositionTier, Sentiment, ) # --------------------------------------------------------------------------- # Helper builders # --------------------------------------------------------------------------- def _make_filing( summary: str = "", key_facts: list[str] | None = None, macro_themes: list[str] | None = None, doc_type: str = "filing", ) -> DocumentIntelligence: companies = [] if key_facts: companies.append(CompanyImpact( ticker="TEST", company_name="Test Corp", relevance=0.8, sentiment=Sentiment.NEUTRAL, impact_score=0.5, impact_horizon="medium_term", catalyst_type=CatalystType.EARNINGS, key_facts=key_facts, )) return DocumentIntelligence( document_type=DocumentType(doc_type), summary=summary, companies=companies, macro_themes=macro_themes or [], confidence=0.7, ) # --------------------------------------------------------------------------- # Region extraction # --------------------------------------------------------------------------- class TestExtractRegions: def test_extracts_country_names(self): regions = _extract_regions_from_text("Revenue from China and Japan grew 15%") assert "CN" in regions assert "JP" in regions def test_extracts_region_codes(self): regions = _extract_regions_from_text("US operations expanded into EU markets") assert "US" in regions assert "EU" in regions def test_empty_text(self): assert _extract_regions_from_text("") == {} def test_no_regions(self): assert _extract_regions_from_text("quarterly earnings increased") == {} # --------------------------------------------------------------------------- # Commodity extraction # 
class TestExtractCommodities:
    """Tests for ``_extract_commodities_from_text``."""

    def test_extracts_commodities(self):
        commodities = _extract_commodities_from_text(
            "Rising crude oil and copper prices impacted margins"
        )
        assert "crude_oil" in commodities
        assert "copper" in commodities

    def test_semiconductor_variants(self):
        commodities = _extract_commodities_from_text("semiconductor shortage continues")
        assert "semiconductors" in commodities

    def test_empty_text(self):
        assert _extract_commodities_from_text("") == {}


# ---------------------------------------------------------------------------
# Revenue mix estimation
# ---------------------------------------------------------------------------
class TestEstimateRevenueMix:
    """Tests for ``_estimate_revenue_mix``."""

    def test_normalizes_to_one(self):
        mix = _estimate_revenue_mix({"US": 3, "CN": 1, "JP": 1})
        total = sum(mix.values())
        # Loose tolerance: the estimator may round individual shares.
        assert abs(total - 1.0) < 0.01

    def test_empty_counts(self):
        assert _estimate_revenue_mix({}) == {}

    def test_single_region(self):
        mix = _estimate_revenue_mix({"US": 5})
        assert mix == {"US": 1.0}


# ---------------------------------------------------------------------------
# Confidence scoring
# ---------------------------------------------------------------------------
class TestComputeInferenceConfidence:
    """Tests for ``_compute_inference_confidence``."""

    def test_high_data_high_confidence(self):
        conf = _compute_inference_confidence(5, 5, 3, 25)
        assert conf > 0.5

    def test_low_data_low_confidence(self):
        conf = _compute_inference_confidence(1, 1, 0, 2)
        assert conf < 0.5

    def test_bounds(self):
        conf = _compute_inference_confidence(0, 0, 0, 0)
        assert 0.0 <= conf <= 1.0
        conf = _compute_inference_confidence(100, 100, 100, 1000)
        assert 0.0 <= conf <= 1.0


# ---------------------------------------------------------------------------
# Full inference
# ---------------------------------------------------------------------------
class TestInferExposureProfile:
    """End-to-end tests for ``infer_exposure_profile``."""

    def test_infers_from_filings_with_geo_data(self):
        filings = [
            _make_filing(
                summary="Revenue from United States was 60%, China 25%, and Japan 15%.",
                key_facts=["US revenue grew 10%", "China operations expanded"],
            ),
        ]
        profile = infer_exposure_profile(
            filings, "Information Technology", "Software", "large_cap"
        )
        assert profile.source == "inferred"
        assert 0.0 <= profile.confidence <= 1.0
        assert len(profile.geographic_revenue_mix) > 0
        assert "US" in profile.geographic_revenue_mix

    def test_infers_commodities(self):
        filings = [
            _make_filing(
                summary="Crude oil and natural gas prices affected our cost structure.",
            ),
        ]
        profile = infer_exposure_profile(filings, "Energy", "Oil & Gas", "mid_cap")
        assert profile.source == "inferred"
        assert "crude_oil" in profile.key_input_commodities

    def test_fallback_when_no_filings(self):
        profile = infer_exposure_profile([], "Energy", "Oil & Gas", "large_cap")
        assert profile.source == "inferred"
        assert len(profile.geographic_revenue_mix) > 0

    def test_fallback_when_no_geo_or_commodity_data(self):
        filings = [
            _make_filing(summary="Quarterly earnings were strong."),
        ]
        profile = infer_exposure_profile(filings, "Financials", "Banking", "mid_cap")
        # Should fall back to default since no geo/commodity data found
        assert profile.source == "inferred"
        assert len(profile.geographic_revenue_mix) > 0

    def test_non_filing_documents_ignored(self):
        docs = [
            _make_filing(
                summary="Revenue from China was 50%",
                doc_type="article",
            ),
        ]
        # Article type should be filtered out, falling back to default
        profile = infer_exposure_profile(docs, "Energy", "Oil & Gas", "small_cap")
        assert profile.source == "inferred"
        # ``source`` alone cannot prove the article was ignored. If the article
        # were used, its China mention would shape the mix. Once it is filtered
        # out, the result must match the no-filings fallback for the same inputs.
        baseline = infer_exposure_profile([], "Energy", "Oil & Gas", "small_cap")
        assert profile.geographic_revenue_mix == baseline.geographic_revenue_mix

    def test_market_cap_tier_mapping(self):
        filings = [
            _make_filing(summary="US and European operations"),
        ]
        profile = infer_exposure_profile(filings, "Industrials", "Machinery", "large_cap")
        tier = profile.market_position_tier
        # The field may hold either the enum member or its raw string value.
        if isinstance(tier, MarketPositionTier):
            tier = tier.value
        assert tier == "global_leader"

    def test_confidence_in_bounds(self):
        filings = [
            _make_filing(summary="Revenue from US, China, Japan, Germany, and India"),
        ]
        profile = infer_exposure_profile(
            filings, "Information Technology", "Software", "mid_cap"
        )
        assert 0.0 <= profile.confidence <= 1.0