"""Unit tests for the interpolation engine. Tests core scoring functions: overlap computation, resilience modifiers, macro impact scoring, default profile building, and direction determination. """ from __future__ import annotations import math from datetime import datetime, timezone import pytest from services.aggregation.interpolation import ( MacroImpactRecord, apply_resilience_modifier, build_default_profile, compute_commodity_overlap, compute_geographic_overlap, compute_macro_impact, compute_macro_impact_with_sector, compute_supply_chain_overlap, ) from services.extractor.event_classifier import GlobalEvent from services.shared.schemas import ExposureProfileSchema, MarketPositionTier # --------------------------------------------------------------------------- # compute_geographic_overlap # --------------------------------------------------------------------------- class TestComputeGeographicOverlap: def test_full_overlap(self): result = compute_geographic_overlap( ["US", "CN"], {"US": 0.6, "CN": 0.4}, ) assert math.isclose(result, 1.0, abs_tol=1e-6) def test_partial_overlap(self): result = compute_geographic_overlap( ["US"], {"US": 0.6, "CN": 0.4}, ) assert math.isclose(result, 0.6, abs_tol=1e-6) def test_no_overlap(self): result = compute_geographic_overlap( ["JP"], {"US": 0.6, "CN": 0.4}, ) assert result == 0.0 def test_empty_event_regions(self): assert compute_geographic_overlap([], {"US": 0.5}) == 0.0 def test_empty_revenue_mix(self): assert compute_geographic_overlap(["US"], {}) == 0.0 def test_case_insensitive(self): result = compute_geographic_overlap( ["us", "cn"], {"US": 0.6, "CN": 0.4}, ) assert math.isclose(result, 1.0, abs_tol=1e-6) def test_clamped_to_one(self): # Even if revenue mix sums > 1, result is clamped result = compute_geographic_overlap( ["US", "CN"], {"US": 0.7, "CN": 0.6}, ) assert result <= 1.0 # --------------------------------------------------------------------------- # compute_supply_chain_overlap # --------------------------------------------------------------------------- class TestComputeSupplyChainOverlap: def test_full_overlap(self): result = compute_supply_chain_overlap(["US", "CN"], ["US", "CN"]) assert result == 1.0 def test_partial_overlap(self): result = compute_supply_chain_overlap(["US"], ["US", "CN"]) assert math.isclose(result, 0.5, abs_tol=1e-6) def test_no_overlap(self): result = compute_supply_chain_overlap(["JP"], ["US", "CN"]) assert result == 0.0 def test_empty_event_regions(self): assert compute_supply_chain_overlap([], ["US"]) == 0.0 def test_empty_supply_regions(self): assert compute_supply_chain_overlap(["US"], []) == 0.0 def test_case_insensitive(self): result = compute_supply_chain_overlap(["us"], ["US", "CN"]) assert math.isclose(result, 0.5, abs_tol=1e-6) # --------------------------------------------------------------------------- # compute_commodity_overlap # --------------------------------------------------------------------------- class TestComputeCommodityOverlap: def test_full_overlap(self): result = compute_commodity_overlap( ["crude_oil", "natural_gas"], ["crude_oil", "natural_gas"], ) assert result == 1.0 def test_partial_overlap(self): result = compute_commodity_overlap( ["crude_oil"], ["crude_oil", "natural_gas"], ) assert math.isclose(result, 0.5, abs_tol=1e-6) def test_no_overlap(self): result = compute_commodity_overlap(["gold"], ["crude_oil"]) assert result == 0.0 def test_empty_event_commodities(self): assert compute_commodity_overlap([], ["crude_oil"]) == 0.0 def test_empty_company_commodities(self): assert compute_commodity_overlap(["crude_oil"], []) == 0.0 # --------------------------------------------------------------------------- # apply_resilience_modifier # --------------------------------------------------------------------------- class TestApplyResilienceModifier: def test_global_leader_dampens(self): result = apply_resilience_modifier(0.5, "global_leader", True) assert math.isclose(result, 0.35, abs_tol=1e-6) def test_domestic_amplifies(self): result = apply_resilience_modifier(0.5, "domestic", True) assert math.isclose(result, 0.6, abs_tol=1e-6) def test_regional_no_change(self): result = apply_resilience_modifier(0.5, "regional", True) assert math.isclose(result, 0.5, abs_tol=1e-6) def test_no_modifier_for_domestic_event(self): result = apply_resilience_modifier(0.5, "global_leader", False) assert math.isclose(result, 0.5, abs_tol=1e-6) def test_clamped_to_one(self): result = apply_resilience_modifier(0.9, "domestic", True) assert result <= 1.0 def test_clamped_to_zero(self): result = apply_resilience_modifier(0.0, "domestic", True) assert result >= 0.0 def test_tier_ordering_for_international(self): """global_leader <= multinational <= regional <= domestic.""" raw = 0.5 gl = apply_resilience_modifier(raw, "global_leader", True) mn = apply_resilience_modifier(raw, "multinational", True) rg = apply_resilience_modifier(raw, "regional", True) dm = apply_resilience_modifier(raw, "domestic", True) assert gl <= mn <= rg <= dm # --------------------------------------------------------------------------- # compute_macro_impact — zero overlap # --------------------------------------------------------------------------- class TestComputeMacroImpactZeroOverlap: def test_zero_overlap_returns_zero_score(self): event = GlobalEvent( event_id="evt-1", event_types=["supply_disruption"], severity="critical", affected_regions=["JP"], affected_sectors=["Energy"], affected_commodities=["gold"], confidence=0.9, ) profile = ExposureProfileSchema( company_id="comp-1", geographic_revenue_mix={"US": 1.0}, supply_chain_regions=["US"], key_input_commodities=["crude_oil"], market_position_tier=MarketPositionTier.REGIONAL, ) record = compute_macro_impact(event, profile) assert record.macro_impact_score == 0.0 assert record.contributing_factors == [] assert record.confidence == 0.0 # --------------------------------------------------------------------------- # compute_macro_impact — basic scoring # --------------------------------------------------------------------------- class TestComputeMacroImpactScoring: def test_score_in_bounds(self): event = GlobalEvent( event_id="evt-2", event_types=["supply_disruption"], severity="critical", affected_regions=["US"], affected_sectors=["Energy"], affected_commodities=["crude_oil"], confidence=0.9, ) profile = ExposureProfileSchema( company_id="comp-2", geographic_revenue_mix={"US": 0.8}, supply_chain_regions=["US"], key_input_commodities=["crude_oil"], market_position_tier=MarketPositionTier.REGIONAL, ) record = compute_macro_impact(event, profile) assert 0.0 <= record.macro_impact_score <= 1.0 assert record.macro_impact_score > 0.0 assert len(record.contributing_factors) > 0 def test_higher_severity_higher_score(self): """Critical severity should produce >= score than low severity.""" profile = ExposureProfileSchema( company_id="comp-3", geographic_revenue_mix={"US": 0.5}, supply_chain_regions=["US"], key_input_commodities=["crude_oil"], market_position_tier=MarketPositionTier.REGIONAL, ) event_low = GlobalEvent( event_id="evt-low", event_types=["supply_disruption"], severity="low", affected_regions=["US"], affected_commodities=["crude_oil"], confidence=0.9, ) event_critical = GlobalEvent( event_id="evt-crit", event_types=["supply_disruption"], severity="critical", affected_regions=["US"], affected_commodities=["crude_oil"], confidence=0.9, ) low_record = compute_macro_impact(event_low, profile) crit_record = compute_macro_impact(event_critical, profile) assert crit_record.macro_impact_score >= low_record.macro_impact_score # --------------------------------------------------------------------------- # Mixed direction # --------------------------------------------------------------------------- class TestMixedDirection: def test_mixed_when_both_positive_and_negative(self): """demand_shift (positive) + supply_disruption (negative) → mixed.""" event = GlobalEvent( event_id="evt-mix", event_types=["demand_shift", "supply_disruption"], severity="high", affected_regions=["US"], affected_commodities=["crude_oil"], confidence=0.8, ) profile = ExposureProfileSchema( company_id="comp-mix", geographic_revenue_mix={"US": 0.5}, supply_chain_regions=["US"], key_input_commodities=["crude_oil"], market_position_tier=MarketPositionTier.REGIONAL, ) record = compute_macro_impact(event, profile) assert record.impact_direction == "mixed" # Both positive and negative factors should be in contributing_factors factors_str = " ".join(record.contributing_factors) assert "positive_types:" in factors_str assert "negative_types:" in factors_str def test_negative_only(self): event = GlobalEvent( event_id="evt-neg", event_types=["supply_disruption", "cost_increase"], severity="high", affected_regions=["US"], confidence=0.8, ) profile = ExposureProfileSchema( company_id="comp-neg", geographic_revenue_mix={"US": 0.5}, market_position_tier=MarketPositionTier.REGIONAL, ) record = compute_macro_impact(event, profile) assert record.impact_direction == "negative" def test_positive_only(self): event = GlobalEvent( event_id="evt-pos", event_types=["demand_shift"], severity="moderate", affected_regions=["US"], confidence=0.8, ) profile = ExposureProfileSchema( company_id="comp-pos", geographic_revenue_mix={"US": 0.5}, market_position_tier=MarketPositionTier.REGIONAL, ) record = compute_macro_impact(event, profile) assert record.impact_direction == "positive" # --------------------------------------------------------------------------- # compute_macro_impact_with_sector # --------------------------------------------------------------------------- class TestComputeMacroImpactWithSector: def test_sector_match_increases_score(self): event = GlobalEvent( event_id="evt-sec", event_types=["supply_disruption"], severity="high", affected_regions=["US"], affected_sectors=["Energy"], confidence=0.9, ) profile = ExposureProfileSchema( company_id="comp-sec", geographic_revenue_mix={"US": 0.5}, market_position_tier=MarketPositionTier.REGIONAL, ) without_sector = compute_macro_impact_with_sector(event, profile, "") with_sector = compute_macro_impact_with_sector(event, profile, "Energy") assert with_sector.macro_impact_score >= without_sector.macro_impact_score def test_sector_no_match(self): event = GlobalEvent( event_id="evt-sec2", event_types=["supply_disruption"], severity="high", affected_regions=["US"], affected_sectors=["Energy"], confidence=0.9, ) profile = ExposureProfileSchema( company_id="comp-sec2", geographic_revenue_mix={"US": 0.5}, market_position_tier=MarketPositionTier.REGIONAL, ) record = compute_macro_impact_with_sector(event, profile, "Financials") # No sector match, but still has geo overlap assert record.macro_impact_score > 0.0 factors_str = " ".join(record.contributing_factors) assert "sector_match" not in factors_str # --------------------------------------------------------------------------- # build_default_profile # --------------------------------------------------------------------------- class TestBuildDefaultProfile: @pytest.mark.parametrize("cap,expected_tier", [ ("large_cap", "global_leader"), ("mid_cap", "multinational"), ("small_cap", "regional"), ("micro_cap", "domestic"), ]) def test_market_cap_to_tier_mapping(self, cap, expected_tier): profile = build_default_profile("Energy", "Oil & Gas", cap) tier_val = profile.market_position_tier if isinstance(tier_val, MarketPositionTier): tier_val = tier_val.value assert tier_val == expected_tier def test_has_non_empty_geo_mix(self): profile = build_default_profile("Energy", "Oil & Gas", "large_cap") assert len(profile.geographic_revenue_mix) > 0 def test_source_is_inferred(self): profile = build_default_profile("Energy", "Oil & Gas", "mid_cap") assert profile.source == "inferred" def test_unknown_sector_uses_default_geo(self): profile = build_default_profile("UnknownSector", "Unknown", "small_cap") assert len(profile.geographic_revenue_mix) > 0 def test_energy_sector_has_commodities(self): profile = build_default_profile("Energy", "Oil & Gas", "large_cap") assert len(profile.key_input_commodities) > 0 assert "crude_oil" in profile.key_input_commodities # --------------------------------------------------------------------------- # MacroImpactRecord dataclass # --------------------------------------------------------------------------- class TestMacroImpactRecord: def test_defaults(self): record = MacroImpactRecord() assert record.event_id == "" assert record.macro_impact_score == 0.0 assert record.impact_direction == "neutral" assert record.contributing_factors == [] assert record.confidence == 0.5 assert record.computed_at is not None # --------------------------------------------------------------------------- # Low-confidence event exclusion (Requirements: 10.1) # --------------------------------------------------------------------------- from services.aggregation.interpolation import ( filter_low_confidence_events, apply_accelerated_decay, compute_standard_recency_decay, DEFAULT_CONFIDENCE_THRESHOLD, ACCELERATED_DECAY_MULTIPLIER, ) class TestFilterLowConfidenceEvents: def test_excludes_below_threshold(self): events = [ GlobalEvent(event_id="e1", confidence=0.3), GlobalEvent(event_id="e2", confidence=0.5), GlobalEvent(event_id="e3", confidence=0.1), ] result = filter_low_confidence_events(events, confidence_threshold=0.4) assert len(result) == 1 assert result[0].event_id == "e2" def test_includes_at_threshold(self): events = [GlobalEvent(event_id="e1", confidence=0.4)] result = filter_low_confidence_events(events, confidence_threshold=0.4) assert len(result) == 1 def test_empty_list(self): assert filter_low_confidence_events([], confidence_threshold=0.4) == [] def test_all_pass(self): events = [ GlobalEvent(event_id="e1", confidence=0.8), GlobalEvent(event_id="e2", confidence=0.9), ] result = filter_low_confidence_events(events, confidence_threshold=0.4) assert len(result) == 2 def test_all_excluded(self): events = [ GlobalEvent(event_id="e1", confidence=0.1), GlobalEvent(event_id="e2", confidence=0.2), ] result = filter_low_confidence_events(events, confidence_threshold=0.4) assert len(result) == 0 def test_default_threshold(self): assert DEFAULT_CONFIDENCE_THRESHOLD == 0.4 # --------------------------------------------------------------------------- # Accelerated decay for stale short-term events (Requirements: 10.2) # --------------------------------------------------------------------------- class TestAcceleratedDecay: def test_standard_decay_for_non_short_term(self): standard = compute_standard_recency_decay(72.0) accelerated = apply_accelerated_decay(72.0, "medium_term") assert accelerated == standard def test_standard_decay_for_young_short_term(self): """Short-term events within 48h get standard decay.""" standard = compute_standard_recency_decay(24.0) accelerated = apply_accelerated_decay(24.0, "short_term") assert accelerated == standard def test_accelerated_for_stale_short_term(self): """Short-term events older than 48h get accelerated decay.""" age = 72.0 standard = compute_standard_recency_decay(age) accelerated = apply_accelerated_decay(age, "short_term") assert accelerated < standard def test_accelerated_decay_multiplier(self): age = 72.0 standard = compute_standard_recency_decay(age) accelerated = apply_accelerated_decay(age, "short_term") assert abs(accelerated - standard * ACCELERATED_DECAY_MULTIPLIER) < 1e-9 def test_long_term_no_acceleration(self): standard = compute_standard_recency_decay(100.0) result = apply_accelerated_decay(100.0, "long_term") assert result == standard def test_zero_age(self): result = apply_accelerated_decay(0.0, "short_term") assert result == 1.0 def test_standard_decay_positive(self): result = compute_standard_recency_decay(168.0) assert 0.0 < result < 1.0