"""Unit tests for the interpolation engine. Tests core scoring functions: overlap computation, resilience modifiers, macro impact scoring, default profile building, and direction determination. """ from __future__ import annotations import math import pytest from services.aggregation.interpolation import ( MacroImpactRecord, apply_resilience_modifier, build_default_profile, compute_commodity_overlap, compute_geographic_overlap, compute_macro_impact, compute_macro_impact_with_sector, compute_supply_chain_overlap, ) from services.extractor.event_classifier import GlobalEvent from services.shared.schemas import ExposureProfileSchema, MarketPositionTier # --------------------------------------------------------------------------- # compute_geographic_overlap # --------------------------------------------------------------------------- class TestComputeGeographicOverlap: def test_full_overlap(self): result = compute_geographic_overlap( ["US", "CN"], {"US": 0.6, "CN": 0.4}, ) assert math.isclose(result, 1.0, abs_tol=1e-6) def test_partial_overlap(self): result = compute_geographic_overlap( ["US"], {"US": 0.6, "CN": 0.4}, ) assert math.isclose(result, 0.6, abs_tol=1e-6) def test_no_overlap(self): result = compute_geographic_overlap( ["JP"], {"US": 0.6, "CN": 0.4}, ) assert result == 0.0 def test_empty_event_regions(self): assert compute_geographic_overlap([], {"US": 0.5}) == 0.0 def test_empty_revenue_mix(self): assert compute_geographic_overlap(["US"], {}) == 0.0 def test_case_insensitive(self): result = compute_geographic_overlap( ["us", "cn"], {"US": 0.6, "CN": 0.4}, ) assert math.isclose(result, 1.0, abs_tol=1e-6) def test_clamped_to_one(self): # Even if revenue mix sums > 1, result is clamped result = compute_geographic_overlap( ["US", "CN"], {"US": 0.7, "CN": 0.6}, ) assert result <= 1.0 # --------------------------------------------------------------------------- # compute_supply_chain_overlap # 
# ---------------------------------------------------------------------------
# compute_supply_chain_overlap
# ---------------------------------------------------------------------------


class TestComputeSupplyChainOverlap:
    """Tests for compute_supply_chain_overlap(event_regions, supply_regions)."""

    def test_full_overlap(self):
        result = compute_supply_chain_overlap(["US", "CN"], ["US", "CN"])
        assert result == 1.0

    def test_partial_overlap(self):
        result = compute_supply_chain_overlap(["US"], ["US", "CN"])
        assert math.isclose(result, 0.5, abs_tol=1e-6)

    def test_no_overlap(self):
        result = compute_supply_chain_overlap(["JP"], ["US", "CN"])
        assert result == 0.0

    def test_empty_event_regions(self):
        assert compute_supply_chain_overlap([], ["US"]) == 0.0

    def test_empty_supply_regions(self):
        assert compute_supply_chain_overlap(["US"], []) == 0.0

    def test_case_insensitive(self):
        result = compute_supply_chain_overlap(["us"], ["US", "CN"])
        assert math.isclose(result, 0.5, abs_tol=1e-6)


# ---------------------------------------------------------------------------
# compute_commodity_overlap
# ---------------------------------------------------------------------------


class TestComputeCommodityOverlap:
    """Tests for compute_commodity_overlap(event_commodities, company_commodities)."""

    def test_full_overlap(self):
        result = compute_commodity_overlap(
            ["crude_oil", "natural_gas"],
            ["crude_oil", "natural_gas"],
        )
        assert result == 1.0

    def test_partial_overlap(self):
        result = compute_commodity_overlap(
            ["crude_oil"],
            ["crude_oil", "natural_gas"],
        )
        assert math.isclose(result, 0.5, abs_tol=1e-6)

    def test_no_overlap(self):
        result = compute_commodity_overlap(["gold"], ["crude_oil"])
        assert result == 0.0

    def test_empty_event_commodities(self):
        assert compute_commodity_overlap([], ["crude_oil"]) == 0.0

    def test_empty_company_commodities(self):
        assert compute_commodity_overlap(["crude_oil"], []) == 0.0


# ---------------------------------------------------------------------------
# apply_resilience_modifier
# ---------------------------------------------------------------------------


class TestApplyResilienceModifier:
    """Tests for apply_resilience_modifier(raw_score, tier, is_international)."""

    def test_global_leader_dampens(self):
        result = apply_resilience_modifier(0.5, "global_leader", True)
        assert math.isclose(result, 0.35, abs_tol=1e-6)

    def test_domestic_amplifies(self):
        result = apply_resilience_modifier(0.5, "domestic", True)
        assert math.isclose(result, 0.6, abs_tol=1e-6)

    def test_regional_no_change(self):
        result = apply_resilience_modifier(0.5, "regional", True)
        assert math.isclose(result, 0.5, abs_tol=1e-6)

    def test_no_modifier_for_domestic_event(self):
        result = apply_resilience_modifier(0.5, "global_leader", False)
        assert math.isclose(result, 0.5, abs_tol=1e-6)

    def test_clamped_to_one(self):
        result = apply_resilience_modifier(0.9, "domestic", True)
        assert result <= 1.0
        # The domestic tier amplifies (see test_domestic_amplifies), so the
        # result must not drop below the raw score either; without this lower
        # bound a degenerate 0.0 would satisfy the test.
        assert result >= 0.9

    def test_clamped_to_zero(self):
        # An amplifying tier can never push a non-negative score below zero,
        # so the domestic case alone does not exercise a lower bound.  The
        # dampening global_leader tier applied to 0.0 must stay at exactly 0.
        dampened = apply_resilience_modifier(0.0, "global_leader", True)
        assert dampened >= 0.0
        assert math.isclose(dampened, 0.0, abs_tol=1e-6)

        amplified = apply_resilience_modifier(0.0, "domestic", True)
        assert amplified >= 0.0

    def test_tier_ordering_for_international(self):
        """global_leader <= multinational <= regional <= domestic."""
        raw = 0.5
        gl = apply_resilience_modifier(raw, "global_leader", True)
        mn = apply_resilience_modifier(raw, "multinational", True)
        rg = apply_resilience_modifier(raw, "regional", True)
        dm = apply_resilience_modifier(raw, "domestic", True)
        assert gl <= mn <= rg <= dm


# ---------------------------------------------------------------------------
# compute_macro_impact — zero overlap
# ---------------------------------------------------------------------------


class TestComputeMacroImpactZeroOverlap:
    """compute_macro_impact with no geographic/supply/commodity overlap."""

    def test_zero_overlap_returns_zero_score(self):
        event = GlobalEvent(
            event_id="evt-1",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["JP"],
            affected_sectors=["Energy"],
            affected_commodities=["gold"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-1",
            geographic_revenue_mix={"US": 1.0},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        record = compute_macro_impact(event, profile)
        assert record.macro_impact_score == 0.0
        assert record.contributing_factors == []
        assert record.confidence == 0.0
# ---------------------------------------------------------------------------
# compute_macro_impact — basic scoring
# ---------------------------------------------------------------------------


class TestComputeMacroImpactScoring:
    """Basic scoring behaviour of compute_macro_impact."""

    def test_score_in_bounds(self):
        event = GlobalEvent(
            event_id="evt-2",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-2",
            geographic_revenue_mix={"US": 0.8},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        record = compute_macro_impact(event, profile)
        assert 0.0 <= record.macro_impact_score <= 1.0
        assert record.macro_impact_score > 0.0
        assert len(record.contributing_factors) > 0

    def test_higher_severity_higher_score(self):
        """Critical severity should produce >= score than low severity."""
        profile = ExposureProfileSchema(
            company_id="comp-3",
            geographic_revenue_mix={"US": 0.5},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        event_low = GlobalEvent(
            event_id="evt-low",
            event_types=["supply_disruption"],
            severity="low",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        event_critical = GlobalEvent(
            event_id="evt-crit",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        low_record = compute_macro_impact(event_low, profile)
        crit_record = compute_macro_impact(event_critical, profile)
        assert crit_record.macro_impact_score >= low_record.macro_impact_score


# ---------------------------------------------------------------------------
# Mixed direction
# ---------------------------------------------------------------------------


class TestMixedDirection:
    """Impact direction derived from the event types of a GlobalEvent."""

    def test_mixed_when_both_positive_and_negative(self):
        """demand_shift (positive) + supply_disruption (negative) → mixed."""
        event = GlobalEvent(
            event_id="evt-mix",
            event_types=["demand_shift", "supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.8,
        )
        profile = ExposureProfileSchema(
            company_id="comp-mix",
            geographic_revenue_mix={"US": 0.5},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        record = compute_macro_impact(event, profile)
        assert record.impact_direction == "mixed"
        # Both positive and negative factors should be in contributing_factors
        factors_str = " ".join(record.contributing_factors)
        assert "positive_types:" in factors_str
        assert "negative_types:" in factors_str

    def test_negative_only(self):
        event = GlobalEvent(
            event_id="evt-neg",
            event_types=["supply_disruption", "cost_increase"],
            severity="high",
            affected_regions=["US"],
            confidence=0.8,
        )
        profile = ExposureProfileSchema(
            company_id="comp-neg",
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        record = compute_macro_impact(event, profile)
        assert record.impact_direction == "negative"

    def test_positive_only(self):
        event = GlobalEvent(
            event_id="evt-pos",
            event_types=["demand_shift"],
            severity="moderate",
            affected_regions=["US"],
            confidence=0.8,
        )
        profile = ExposureProfileSchema(
            company_id="comp-pos",
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        record = compute_macro_impact(event, profile)
        assert record.impact_direction == "positive"


# ---------------------------------------------------------------------------
# compute_macro_impact_with_sector
# ---------------------------------------------------------------------------


class TestComputeMacroImpactWithSector:
    """Sector matching in compute_macro_impact_with_sector."""

    def test_sector_match_increases_score(self):
        event = GlobalEvent(
            event_id="evt-sec",
            event_types=["supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-sec",
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        without_sector = compute_macro_impact_with_sector(event, profile, "")
        with_sector = compute_macro_impact_with_sector(event, profile, "Energy")
        non_matching = compute_macro_impact_with_sector(event, profile, "Financials")
        assert with_sector.macro_impact_score >= without_sector.macro_impact_score
        # ``>=`` alone also passes if sector matching has no effect at all.
        # Compare against a sector known not to match (see
        # test_sector_no_match) to pin a strict increase.
        assert with_sector.macro_impact_score > non_matching.macro_impact_score

    def test_sector_no_match(self):
        event = GlobalEvent(
            event_id="evt-sec2",
            event_types=["supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-sec2",
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        record = compute_macro_impact_with_sector(event, profile, "Financials")
        # No sector match, but still has geo overlap
        assert record.macro_impact_score > 0.0
        factors_str = " ".join(record.contributing_factors)
        assert "sector_match" not in factors_str


# ---------------------------------------------------------------------------
# build_default_profile
# ---------------------------------------------------------------------------


class TestBuildDefaultProfile:
    """Profiles inferred by build_default_profile(sector, industry, market_cap)."""

    @pytest.mark.parametrize("cap,expected_tier", [
        ("large_cap", "global_leader"),
        ("mid_cap", "multinational"),
        ("small_cap", "regional"),
        ("micro_cap", "domestic"),
    ])
    def test_market_cap_to_tier_mapping(self, cap, expected_tier):
        profile = build_default_profile("Energy", "Oil & Gas", cap)
        tier_val = profile.market_position_tier
        # The tier may come back as the enum or as its raw string value.
        if isinstance(tier_val, MarketPositionTier):
            tier_val = tier_val.value
        assert tier_val == expected_tier

    def test_has_non_empty_geo_mix(self):
        profile = build_default_profile("Energy", "Oil & Gas", "large_cap")
        assert len(profile.geographic_revenue_mix) > 0

    def test_source_is_inferred(self):
        profile = build_default_profile("Energy", "Oil & Gas", "mid_cap")
        assert profile.source == "inferred"

    def test_unknown_sector_uses_default_geo(self):
        profile = build_default_profile("UnknownSector", "Unknown", "small_cap")
        assert len(profile.geographic_revenue_mix) > 0

    def test_energy_sector_has_commodities(self):
        profile = build_default_profile("Energy", "Oil & Gas", "large_cap")
        assert len(profile.key_input_commodities) > 0
        assert "crude_oil" in profile.key_input_commodities


# ---------------------------------------------------------------------------
# MacroImpactRecord dataclass
# ---------------------------------------------------------------------------


class TestMacroImpactRecord:
    """Default field values of MacroImpactRecord."""

    def test_defaults(self):
        record = MacroImpactRecord()
        assert record.event_id == ""
        assert record.macro_impact_score == 0.0
        assert record.impact_direction == "neutral"
        assert record.contributing_factors == []
        assert record.confidence == 0.5
        assert record.computed_at is not None


# ---------------------------------------------------------------------------
# Low-confidence event exclusion (Requirements: 10.1)
# ---------------------------------------------------------------------------

from services.aggregation.interpolation import (
    ACCELERATED_DECAY_MULTIPLIER,
    DEFAULT_CONFIDENCE_THRESHOLD,
    apply_accelerated_decay,
    compute_standard_recency_decay,
    filter_low_confidence_events,
)


class TestFilterLowConfidenceEvents:
    """filter_low_confidence_events keeps events with confidence >= threshold."""

    def test_excludes_below_threshold(self):
        events = [
            GlobalEvent(event_id="e1", confidence=0.3),
            GlobalEvent(event_id="e2", confidence=0.5),
            GlobalEvent(event_id="e3", confidence=0.1),
        ]
        result = filter_low_confidence_events(events, confidence_threshold=0.4)
        assert len(result) == 1
        assert result[0].event_id == "e2"

    def test_includes_at_threshold(self):
        events = [GlobalEvent(event_id="e1", confidence=0.4)]
        result = filter_low_confidence_events(events, confidence_threshold=0.4)
        assert len(result) == 1

    def test_empty_list(self):
        assert filter_low_confidence_events([], confidence_threshold=0.4) == []

    def test_all_pass(self):
        events = [
            GlobalEvent(event_id="e1", confidence=0.8),
            GlobalEvent(event_id="e2", confidence=0.9),
        ]
        result = filter_low_confidence_events(events, confidence_threshold=0.4)
        assert len(result) == 2

    def test_all_excluded(self):
        events = [
            GlobalEvent(event_id="e1", confidence=0.1),
            GlobalEvent(event_id="e2", confidence=0.2),
        ]
        result = filter_low_confidence_events(events, confidence_threshold=0.4)
        assert len(result) == 0

    def test_default_threshold(self):
        assert DEFAULT_CONFIDENCE_THRESHOLD == 0.4


# ---------------------------------------------------------------------------
# Accelerated decay for stale short-term events (Requirements: 10.2)
# ---------------------------------------------------------------------------


class TestAcceleratedDecay:
    """apply_accelerated_decay(age_hours, horizon) vs the standard recency decay."""

    def test_standard_decay_for_non_short_term(self):
        standard = compute_standard_recency_decay(72.0)
        accelerated = apply_accelerated_decay(72.0, "medium_term")
        assert accelerated == standard

    def test_standard_decay_for_young_short_term(self):
        """Short-term events within 48h get standard decay."""
        standard = compute_standard_recency_decay(24.0)
        accelerated = apply_accelerated_decay(24.0, "short_term")
        assert accelerated == standard

    def test_accelerated_for_stale_short_term(self):
        """Short-term events older than 48h get accelerated decay."""
        age = 72.0
        standard = compute_standard_recency_decay(age)
        accelerated = apply_accelerated_decay(age, "short_term")
        assert accelerated < standard

    def test_accelerated_decay_multiplier(self):
        age = 72.0
        standard = compute_standard_recency_decay(age)
        accelerated = apply_accelerated_decay(age, "short_term")
        assert math.isclose(
            accelerated,
            standard * ACCELERATED_DECAY_MULTIPLIER,
            abs_tol=1e-9,
        )

    def test_long_term_no_acceleration(self):
        standard = compute_standard_recency_decay(100.0)
        result = apply_accelerated_decay(100.0, "long_term")
        assert result == standard

    def test_zero_age(self):
        result = apply_accelerated_decay(0.0, "short_term")
        assert result == 1.0

    def test_standard_decay_positive(self):
        result = compute_standard_recency_decay(168.0)
        assert 0.0 < result < 1.0
# ---------------------------------------------------------------------------
# Multiplicative macro exposure formula (Task 10.1, Requirements: 10.1–10.6)
# ---------------------------------------------------------------------------

from services.aggregation.interpolation import (
    _compute_linear_exposure,
    _compute_multiplicative_exposure,
    compute_conditional_macro_modifier,
    integrate_macro_signals,
)


class TestMultiplicativeExposure:
    """Tests for the multiplicative compounding exposure formula."""

    # Per-dimension weights expected by these tests (geo, supply chain,
    # commodity, sector).  Kept in one place so the hand-computed expectations
    # below stay consistent with each other.
    GEO_W = 0.35
    SUPPLY_W = 0.25
    COMMODITY_W = 0.25
    SECTOR_W = 0.15

    def test_zero_overlap_returns_zero(self):
        """All overlaps zero → exposure = 0."""
        assert _compute_multiplicative_exposure(0.0, 0.0, 0.0, 0.0) == 0.0

    def test_max_overlap_approx_0724(self):
        """All overlaps 1.0 → exposure = 1 - Π(1 - w_i) ≈ 0.689.

        The test name refers to the ≈0.724 stated in Requirement 10.4; the
        exact multiplicative formula yields ≈0.689, which is what is asserted.
        """
        result = _compute_multiplicative_exposure(1.0, 1.0, 1.0, 1.0)
        expected = 1.0 - (
            (1 - self.GEO_W)
            * (1 - self.SUPPLY_W)
            * (1 - self.COMMODITY_W)
            * (1 - self.SECTOR_W)
        )
        assert math.isclose(result, expected, abs_tol=1e-6)

    def test_single_dimension_equals_weight(self):
        """Only geo overlap at 1.0 → exposure = geo weight (0.35)."""
        result = _compute_multiplicative_exposure(1.0, 0.0, 0.0, 0.0)
        assert math.isclose(result, self.GEO_W, abs_tol=1e-6)

    def test_multiplicative_differs_from_linear_for_multi_overlap(self):
        """Multiplicative and linear produce different results for multi-dimension overlap."""
        geo, supply, commodity, sector = 0.8, 0.6, 0.5, 0.4
        mult = _compute_multiplicative_exposure(geo, supply, commodity, sector)
        lin = _compute_linear_exposure(geo, supply, commodity, sector)
        # They should produce different values (multiplicative compounds)
        assert mult != lin
        # Both should be positive
        assert mult > 0.0
        assert lin > 0.0

    def test_adding_overlap_increases_score(self):
        """Adding a non-zero overlap in any dimension increases the total."""
        base = _compute_multiplicative_exposure(0.5, 0.0, 0.0, 0.0)
        with_supply = _compute_multiplicative_exposure(0.5, 0.3, 0.0, 0.0)
        assert with_supply > base

    def test_probabilistic_flag_uses_multiplicative(self):
        """compute_macro_impact with probabilistic=True uses multiplicative formula."""
        event = GlobalEvent(
            event_id="evt-mult",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-mult",
            geographic_revenue_mix={"US": 0.8},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        heuristic = compute_macro_impact(event, profile, probabilistic=False)
        probabilistic_result = compute_macro_impact(event, profile, probabilistic=True)
        # Both should produce positive scores
        assert heuristic.macro_impact_score > 0.0
        assert probabilistic_result.macro_impact_score > 0.0
        # They should produce different scores (different formulas)
        assert heuristic.macro_impact_score != probabilistic_result.macro_impact_score

    def test_probabilistic_false_preserves_linear(self):
        """probabilistic=False produces identical results to original behavior."""
        event = GlobalEvent(
            event_id="evt-lin",
            event_types=["supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-lin",
            geographic_revenue_mix={"US": 0.5},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        record = compute_macro_impact(event, profile, probabilistic=False)

        # Manually compute expected linear score
        geo = 0.5  # revenue mix for US
        supply = 1.0  # 1/1 supply regions match
        commodity = 1.0  # crude_oil matches
        sector = 0.0  # no sector passed to compute_macro_impact
        severity = 0.75  # high
        expected_raw = severity * (
            self.GEO_W * geo
            + self.SUPPLY_W * supply
            + self.COMMODITY_W * commodity
            + self.SECTOR_W * sector
        )
        # Single region → no resilience modifier
        assert math.isclose(record.macro_impact_score, expected_raw, abs_tol=1e-4)

    def test_zero_overlap_returns_zero_score_probabilistic(self):
        """Zero overlap still returns zero in probabilistic mode."""
        event = GlobalEvent(
            event_id="evt-zero",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["JP"],
            affected_commodities=["gold"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-zero",
            geographic_revenue_mix={"US": 1.0},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        record = compute_macro_impact(event, profile, probabilistic=True)
        assert record.macro_impact_score == 0.0

    def test_with_sector_probabilistic(self):
        """compute_macro_impact_with_sector supports probabilistic flag."""
        event = GlobalEvent(
            event_id="evt-sec-prob",
            event_types=["supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-sec-prob",
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        heuristic = compute_macro_impact_with_sector(
            event, profile, "Energy", probabilistic=False,
        )
        probabilistic = compute_macro_impact_with_sector(
            event, profile, "Energy", probabilistic=True,
        )
        assert heuristic.macro_impact_score > 0.0
        assert probabilistic.macro_impact_score > 0.0

    def test_severity_preserved_in_probabilistic(self):
        """Severity mapping is preserved in probabilistic mode."""
        profile = ExposureProfileSchema(
            company_id="comp-sev",
            geographic_revenue_mix={"US": 0.5},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        event_low = GlobalEvent(
            event_id="evt-low-p",
            event_types=["supply_disruption"],
            severity="low",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        event_crit = GlobalEvent(
            event_id="evt-crit-p",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        low = compute_macro_impact(event_low, profile, probabilistic=True)
        crit = compute_macro_impact(event_crit, profile, probabilistic=True)
        assert crit.macro_impact_score >= low.macro_impact_score
# ---------------------------------------------------------------------------
# Conditional macro signal integration (Task 10.2, Requirements: 11.1–11.5)
# ---------------------------------------------------------------------------


class TestConditionalMacroModifier:
    """Tests for compute_conditional_macro_modifier."""

    def test_agreeing_directions_amplify(self):
        """Bullish company + positive macro → modifier > 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bullish",
            macro_impact=0.3,
            macro_direction="positive",
        )
        assert modifier > 1.0
        assert math.isclose(modifier, 1.3, abs_tol=1e-6)

    def test_disagreeing_directions_dampen(self):
        """Bullish company + negative macro → modifier < 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bullish",
            macro_impact=0.3,
            macro_direction="negative",
        )
        assert modifier < 1.0
        assert math.isclose(modifier, 0.7, abs_tol=1e-6)

    def test_neutral_company_no_alignment(self):
        """Neutral company direction → modifier = 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="neutral",
            macro_impact=0.5,
            macro_direction="positive",
        )
        assert math.isclose(modifier, 1.0, abs_tol=1e-6)

    def test_neutral_macro_no_alignment(self):
        """Neutral macro direction → modifier = 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bullish",
            macro_impact=0.5,
            macro_direction="neutral",
        )
        assert math.isclose(modifier, 1.0, abs_tol=1e-6)

    def test_clamped_to_max_1_5(self):
        """Large agreeing impact clamped to 1.5."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bearish",
            macro_impact=0.8,
            macro_direction="negative",
        )
        assert modifier <= 1.5
        # By analogy with the 0.3 → 1.3 case above, the unclamped value would
        # be about 1.8.  ``<= 1.5`` alone also passes for an unmodified 1.0,
        # so pin the ceiling itself.
        assert math.isclose(modifier, 1.5, abs_tol=1e-6)

    def test_clamped_to_min_0_5(self):
        """Large disagreeing impact clamped to 0.5."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bearish",
            macro_impact=0.8,
            macro_direction="positive",
        )
        assert modifier >= 0.5
        # By analogy with the 0.3 → 0.7 case above, the unclamped value would
        # be about 0.2.  ``>= 0.5`` alone also passes for an unmodified 1.0,
        # so pin the floor itself.
        assert math.isclose(modifier, 0.5, abs_tol=1e-6)

    def test_zero_macro_impact_no_change(self):
        """Zero macro impact → modifier = 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bullish",
            macro_impact=0.0,
            macro_direction="positive",
        )
        assert math.isclose(modifier, 1.0, abs_tol=1e-6)

    def test_bearish_negative_agree(self):
        """Bearish company + negative macro → they agree → modifier > 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bearish",
            macro_impact=0.2,
            macro_direction="negative",
        )
        assert modifier > 1.0


class TestIntegrateMacroSignals:
    """Tests for integrate_macro_signals."""

    def _make_signal(self, doc_id: str, sentiment: float, impact: float):
        """Helper to create a minimal WeightedSignal-like object."""
        from services.aggregation.scoring import SignalWeight, WeightedSignal

        weight = SignalWeight(
            recency=1.0,
            credibility=0.8,
            novelty_bonus=0.0,
            confidence_gate=1.0,
            market_ctx_multiplier=1.0,
            combined=0.8,
        )
        return WeightedSignal(
            document_id=doc_id,
            weight=weight,
            sentiment_value=sentiment,
            impact_score=impact,
        )

    def _make_macro_impact(self, score: float, direction: str):
        """Helper to create a MacroImpactRecord."""
        return MacroImpactRecord(
            event_id="evt-1",
            company_id="comp-1",
            macro_impact_score=score,
            impact_direction=direction,
        )

    def test_heuristic_mode_concatenates(self):
        """probabilistic=False → simple concatenation."""
        company = [self._make_signal("c1", 0.5, 0.6)]
        macro = [self._make_signal("m1", 0.3, 0.4)]
        merged, modifier = integrate_macro_signals(
            company, macro, "bullish", [], probabilistic=False,
        )
        assert len(merged) == 2
        assert modifier == 1.0

    def test_probabilistic_both_exist_applies_modifier(self):
        """Both company and macro → modifier applied to company signals."""
        company = [self._make_signal("c1", 0.5, 0.6)]
        macro = [self._make_signal("m1", 0.3, 0.4)]
        impacts = [self._make_macro_impact(0.3, "positive")]
        merged, modifier = integrate_macro_signals(
            company, macro, "bullish", impacts,
            ticker="AAPL", probabilistic=True,
        )
        # Modifier should be > 1.0 (agreeing directions)
        assert modifier > 1.0
        # Only company signals returned (modified), not macro
        assert len(merged) == 1
        # Impact score should be scaled by modifier
        assert merged[0].impact_score > 0.6

    def test_probabilistic_macro_only_fallback(self):
        """Only macro signals → additive fallback."""
        macro = [self._make_signal("m1", 0.3, 0.4)]
        impacts = [self._make_macro_impact(0.3, "positive")]
        merged, modifier = integrate_macro_signals(
            [], macro, "neutral", impacts,
            ticker="AAPL", probabilistic=True,
        )
        assert len(merged) == 1
        assert modifier == 1.0

    def test_probabilistic_company_only_no_modifier(self):
        """Only company signals → modifier = 1.0."""
        company = [self._make_signal("c1", 0.5, 0.6)]
        merged, modifier = integrate_macro_signals(
            company, [], "bullish", [],
            ticker="AAPL", probabilistic=True,
        )
        assert len(merged) == 1
        assert modifier == 1.0
        assert merged[0].impact_score == 0.6

    def test_probabilistic_disagreeing_dampens(self):
        """Disagreeing directions → modifier < 1.0, impact reduced."""
        company = [self._make_signal("c1", 0.5, 0.6)]
        macro = [self._make_signal("m1", -0.3, 0.4)]
        impacts = [self._make_macro_impact(0.3, "negative")]
        merged, modifier = integrate_macro_signals(
            company, macro, "bullish", impacts,
            ticker="AAPL", probabilistic=True,
        )
        assert modifier < 1.0
        assert merged[0].impact_score < 0.6