511 lines
18 KiB
Python
511 lines
18 KiB
Python
"""Unit tests for the interpolation engine.
|
|
|
|
Tests core scoring functions: overlap computation, resilience modifiers,
macro impact scoring, default profile building, direction determination,
low-confidence event filtering, and accelerated recency decay.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
from datetime import datetime, timezone
|
|
|
|
import pytest
|
|
|
|
from services.aggregation.interpolation import (
|
|
MacroImpactRecord,
|
|
apply_resilience_modifier,
|
|
build_default_profile,
|
|
compute_commodity_overlap,
|
|
compute_geographic_overlap,
|
|
compute_macro_impact,
|
|
compute_macro_impact_with_sector,
|
|
compute_supply_chain_overlap,
|
|
)
|
|
from services.extractor.event_classifier import GlobalEvent
|
|
from services.shared.schemas import ExposureProfileSchema, MarketPositionTier
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# compute_geographic_overlap
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestComputeGeographicOverlap:
    """compute_geographic_overlap: revenue-weighted share of the mix hit by an event."""

    def test_full_overlap(self):
        result = compute_geographic_overlap(
            ["US", "CN"], {"US": 0.6, "CN": 0.4},
        )
        assert math.isclose(result, 1.0, abs_tol=1e-6)

    def test_partial_overlap(self):
        # Only the US slice (0.6) of the revenue mix is affected.
        result = compute_geographic_overlap(
            ["US"], {"US": 0.6, "CN": 0.4},
        )
        assert math.isclose(result, 0.6, abs_tol=1e-6)

    def test_no_overlap(self):
        result = compute_geographic_overlap(
            ["JP"], {"US": 0.6, "CN": 0.4},
        )
        assert result == 0.0

    def test_empty_event_regions(self):
        assert compute_geographic_overlap([], {"US": 0.5}) == 0.0

    def test_empty_revenue_mix(self):
        assert compute_geographic_overlap(["US"], {}) == 0.0

    def test_case_insensitive(self):
        result = compute_geographic_overlap(
            ["us", "cn"], {"US": 0.6, "CN": 0.4},
        )
        assert math.isclose(result, 1.0, abs_tol=1e-6)

    def test_clamped_to_one(self):
        # The revenue mix sums to 1.3, so the raw overlap exceeds 1. Pin the
        # clamped value exactly: a bare ``<= 1.0`` check would also accept a
        # broken implementation that returned 0.0.
        result = compute_geographic_overlap(
            ["US", "CN"], {"US": 0.7, "CN": 0.6},
        )
        assert math.isclose(result, 1.0, abs_tol=1e-6)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# compute_supply_chain_overlap
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestComputeSupplyChainOverlap:
    """compute_supply_chain_overlap: fraction of supply regions hit by an event."""

    def test_full_overlap(self):
        supply_regions = ["US", "CN"]
        overlap = compute_supply_chain_overlap(["US", "CN"], supply_regions)
        assert overlap == 1.0

    def test_partial_overlap(self):
        # One of the two supply regions is affected.
        overlap = compute_supply_chain_overlap(["US"], ["US", "CN"])
        assert math.isclose(overlap, 0.5, abs_tol=1e-6)

    def test_no_overlap(self):
        overlap = compute_supply_chain_overlap(["JP"], ["US", "CN"])
        assert overlap == 0.0

    def test_empty_event_regions(self):
        overlap = compute_supply_chain_overlap([], ["US"])
        assert overlap == 0.0

    def test_empty_supply_regions(self):
        overlap = compute_supply_chain_overlap(["US"], [])
        assert overlap == 0.0

    def test_case_insensitive(self):
        # Lower-case event region must still match the upper-case supply region.
        overlap = compute_supply_chain_overlap(["us"], ["US", "CN"])
        assert math.isclose(overlap, 0.5, abs_tol=1e-6)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# compute_commodity_overlap
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestComputeCommodityOverlap:
    """compute_commodity_overlap: share of company input commodities named by an event."""

    def test_full_overlap(self):
        inputs = ["crude_oil", "natural_gas"]
        score = compute_commodity_overlap(
            ["crude_oil", "natural_gas"], inputs,
        )
        assert score == 1.0

    def test_partial_overlap(self):
        # One of the two company inputs is named by the event.
        score = compute_commodity_overlap(
            ["crude_oil"], ["crude_oil", "natural_gas"],
        )
        assert math.isclose(score, 0.5, abs_tol=1e-6)

    def test_no_overlap(self):
        score = compute_commodity_overlap(["gold"], ["crude_oil"])
        assert score == 0.0

    def test_empty_event_commodities(self):
        score = compute_commodity_overlap([], ["crude_oil"])
        assert score == 0.0

    def test_empty_company_commodities(self):
        score = compute_commodity_overlap(["crude_oil"], [])
        assert score == 0.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# apply_resilience_modifier
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestApplyResilienceModifier:
    """apply_resilience_modifier: tier-based dampening/amplification of a raw score.

    The modifier is applied only when the third argument is True; the tests
    below pass False for a domestic event and expect the raw score back.
    """

    def test_global_leader_dampens(self):
        result = apply_resilience_modifier(0.5, "global_leader", True)
        assert math.isclose(result, 0.35, abs_tol=1e-6)

    def test_domestic_amplifies(self):
        result = apply_resilience_modifier(0.5, "domestic", True)
        assert math.isclose(result, 0.6, abs_tol=1e-6)

    def test_regional_no_change(self):
        result = apply_resilience_modifier(0.5, "regional", True)
        assert math.isclose(result, 0.5, abs_tol=1e-6)

    def test_no_modifier_for_domestic_event(self):
        result = apply_resilience_modifier(0.5, "global_leader", False)
        assert math.isclose(result, 0.5, abs_tol=1e-6)

    def test_clamped_to_one(self):
        # Amplifying 0.9 for a domestic company would overshoot 1.0; the
        # result must saturate at exactly 1.0 (a bare ``<= 1.0`` check would
        # also pass for an implementation that returned 0.0).
        result = apply_resilience_modifier(0.9, "domestic", True)
        assert math.isclose(result, 1.0, abs_tol=1e-6)

    def test_clamped_to_zero(self):
        # A zero score must never be pushed negative. The amplifying tier
        # alone cannot exercise the lower clamp, so also check the dampening
        # tier, which must leave a zero score at exactly 0.
        assert apply_resilience_modifier(0.0, "domestic", True) >= 0.0
        result = apply_resilience_modifier(0.0, "global_leader", True)
        assert math.isclose(result, 0.0, abs_tol=1e-6)

    def test_tier_ordering_for_international(self):
        """global_leader <= multinational <= regional <= domestic."""
        raw = 0.5
        gl = apply_resilience_modifier(raw, "global_leader", True)
        mn = apply_resilience_modifier(raw, "multinational", True)
        rg = apply_resilience_modifier(raw, "regional", True)
        dm = apply_resilience_modifier(raw, "domestic", True)
        assert gl <= mn <= rg <= dm
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# compute_macro_impact — zero overlap
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestComputeMacroImpactZeroOverlap:
    """An event with no geographic, supply-chain or commodity overlap scores zero."""

    def test_zero_overlap_returns_zero_score(self):
        # Company exposure: US revenue, US supply chain, crude oil inputs.
        profile = ExposureProfileSchema(
            company_id="comp-1",
            geographic_revenue_mix={"US": 1.0},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        # Event footprint: JP region and gold only -- disjoint from the above.
        event = GlobalEvent(
            event_id="evt-1",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["JP"],
            affected_sectors=["Energy"],
            affected_commodities=["gold"],
            confidence=0.9,
        )

        record = compute_macro_impact(event, profile)

        observed = (
            record.macro_impact_score,
            record.contributing_factors,
            record.confidence,
        )
        assert observed == (0.0, [], 0.0)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# compute_macro_impact — basic scoring
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestComputeMacroImpactScoring:
    """Basic bounds and monotonicity of compute_macro_impact scores."""

    def test_score_in_bounds(self):
        exposure = ExposureProfileSchema(
            company_id="comp-2",
            geographic_revenue_mix={"US": 0.8},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        shock = GlobalEvent(
            event_id="evt-2",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )

        outcome = compute_macro_impact(shock, exposure)
        score = outcome.macro_impact_score

        assert 0.0 <= score <= 1.0
        assert score > 0.0
        assert len(outcome.contributing_factors) > 0

    def test_higher_severity_higher_score(self):
        """A critical event must score at least as high as an identical low-severity one."""
        exposure = ExposureProfileSchema(
            company_id="comp-3",
            geographic_revenue_mix={"US": 0.5},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )

        def make_event(event_id, severity):
            # Only the id and severity differ between the two events.
            return GlobalEvent(
                event_id=event_id,
                event_types=["supply_disruption"],
                severity=severity,
                affected_regions=["US"],
                affected_commodities=["crude_oil"],
                confidence=0.9,
            )

        low_event = make_event("evt-low", "low")
        critical_event = make_event("evt-crit", "critical")

        low_score = compute_macro_impact(low_event, exposure).macro_impact_score
        critical_score = compute_macro_impact(
            critical_event, exposure,
        ).macro_impact_score

        assert critical_score >= low_score
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Mixed direction
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMixedDirection:
    """impact_direction is derived from the polarity of the event's types."""

    @staticmethod
    def _us_profile(company_id, **extra):
        # Shared baseline exposure: half of revenue in the US, regional tier.
        return ExposureProfileSchema(
            company_id=company_id,
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
            **extra,
        )

    def test_mixed_when_both_positive_and_negative(self):
        """demand_shift (positive) + supply_disruption (negative) → mixed."""
        profile = self._us_profile(
            "comp-mix",
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
        )
        event = GlobalEvent(
            event_id="evt-mix",
            event_types=["demand_shift", "supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.8,
        )

        record = compute_macro_impact(event, profile)

        assert record.impact_direction == "mixed"
        # Both polarities must be reported among the contributing factors.
        factors = record.contributing_factors
        assert any("positive_types:" in factor for factor in factors)
        assert any("negative_types:" in factor for factor in factors)

    def test_negative_only(self):
        event = GlobalEvent(
            event_id="evt-neg",
            event_types=["supply_disruption", "cost_increase"],
            severity="high",
            affected_regions=["US"],
            confidence=0.8,
        )
        direction = compute_macro_impact(
            event, self._us_profile("comp-neg"),
        ).impact_direction
        assert direction == "negative"

    def test_positive_only(self):
        event = GlobalEvent(
            event_id="evt-pos",
            event_types=["demand_shift"],
            severity="moderate",
            affected_regions=["US"],
            confidence=0.8,
        )
        direction = compute_macro_impact(
            event, self._us_profile("comp-pos"),
        ).impact_direction
        assert direction == "positive"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# compute_macro_impact_with_sector
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestComputeMacroImpactWithSector:
    """Sector-aware variant of compute_macro_impact."""

    @staticmethod
    def _energy_event(event_id):
        # High-severity US supply disruption tagged with the Energy sector.
        return GlobalEvent(
            event_id=event_id,
            event_types=["supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            confidence=0.9,
        )

    @staticmethod
    def _us_profile(company_id):
        return ExposureProfileSchema(
            company_id=company_id,
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )

    def test_sector_match_increases_score(self):
        event = self._energy_event("evt-sec")
        profile = self._us_profile("comp-sec")

        baseline = compute_macro_impact_with_sector(event, profile, "")
        matched = compute_macro_impact_with_sector(event, profile, "Energy")

        assert matched.macro_impact_score >= baseline.macro_impact_score

    def test_sector_no_match(self):
        event = self._energy_event("evt-sec2")
        profile = self._us_profile("comp-sec2")

        outcome = compute_macro_impact_with_sector(event, profile, "Financials")

        # The sector differs, but geographic overlap alone keeps the score positive.
        assert outcome.macro_impact_score > 0.0
        assert not any(
            "sector_match" in factor for factor in outcome.contributing_factors
        )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# build_default_profile
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBuildDefaultProfile:
    """build_default_profile infers an exposure profile from sector, industry and cap bucket."""

    @pytest.mark.parametrize(("cap", "expected_tier"), [
        ("large_cap", "global_leader"),
        ("mid_cap", "multinational"),
        ("small_cap", "regional"),
        ("micro_cap", "domestic"),
    ])
    def test_market_cap_to_tier_mapping(self, cap, expected_tier):
        raw_tier = build_default_profile(
            "Energy", "Oil & Gas", cap,
        ).market_position_tier
        # The schema may hold either the enum member or its plain string value.
        tier_name = (
            raw_tier.value if isinstance(raw_tier, MarketPositionTier) else raw_tier
        )
        assert tier_name == expected_tier

    def test_has_non_empty_geo_mix(self):
        geo_mix = build_default_profile(
            "Energy", "Oil & Gas", "large_cap",
        ).geographic_revenue_mix
        assert len(geo_mix) > 0

    def test_source_is_inferred(self):
        generated = build_default_profile("Energy", "Oil & Gas", "mid_cap")
        assert generated.source == "inferred"

    def test_unknown_sector_uses_default_geo(self):
        # An unrecognised sector must still get some geographic mix.
        geo_mix = build_default_profile(
            "UnknownSector", "Unknown", "small_cap",
        ).geographic_revenue_mix
        assert len(geo_mix) > 0

    def test_energy_sector_has_commodities(self):
        commodities = build_default_profile(
            "Energy", "Oil & Gas", "large_cap",
        ).key_input_commodities
        assert len(commodities) > 0
        assert "crude_oil" in commodities
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MacroImpactRecord dataclass
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMacroImpactRecord:
    """Default field values of a freshly constructed MacroImpactRecord."""

    def test_defaults(self):
        record = MacroImpactRecord()
        expected_defaults = {
            "event_id": "",
            "macro_impact_score": 0.0,
            "impact_direction": "neutral",
            "contributing_factors": [],
            "confidence": 0.5,
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(record, field_name) == expected, field_name
        # Only presence is checked for the construction timestamp.
        assert record.computed_at is not None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Low-confidence event exclusion (Requirements: 10.1)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
from services.aggregation.interpolation import (
|
|
filter_low_confidence_events,
|
|
apply_accelerated_decay,
|
|
compute_standard_recency_decay,
|
|
DEFAULT_CONFIDENCE_THRESHOLD,
|
|
ACCELERATED_DECAY_MULTIPLIER,
|
|
)
|
|
|
|
|
|
class TestFilterLowConfidenceEvents:
    """filter_low_confidence_events keeps events whose confidence is >= the threshold."""

    @staticmethod
    def _make_events(*specs):
        """Build GlobalEvents from (event_id, confidence) pairs."""
        return [GlobalEvent(event_id=eid, confidence=conf) for eid, conf in specs]

    def test_excludes_below_threshold(self):
        events = self._make_events(("e1", 0.3), ("e2", 0.5), ("e3", 0.1))
        kept = filter_low_confidence_events(events, confidence_threshold=0.4)
        assert len(kept) == 1
        assert kept[0].event_id == "e2"

    def test_includes_at_threshold(self):
        # The threshold itself is inclusive.
        events = self._make_events(("e1", 0.4))
        kept = filter_low_confidence_events(events, confidence_threshold=0.4)
        assert len(kept) == 1

    def test_empty_list(self):
        kept = filter_low_confidence_events([], confidence_threshold=0.4)
        assert kept == []

    def test_all_pass(self):
        events = self._make_events(("e1", 0.8), ("e2", 0.9))
        kept = filter_low_confidence_events(events, confidence_threshold=0.4)
        assert len(kept) == 2

    def test_all_excluded(self):
        events = self._make_events(("e1", 0.1), ("e2", 0.2))
        kept = filter_low_confidence_events(events, confidence_threshold=0.4)
        assert len(kept) == 0

    def test_default_threshold(self):
        assert DEFAULT_CONFIDENCE_THRESHOLD == 0.4
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Accelerated decay for stale short-term events (Requirements: 10.2)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAcceleratedDecay:
    """apply_accelerated_decay versus compute_standard_recency_decay."""

    def test_standard_decay_for_non_short_term(self):
        age_hours = 72.0
        baseline = compute_standard_recency_decay(age_hours)
        assert baseline == apply_accelerated_decay(age_hours, "medium_term")

    def test_standard_decay_for_young_short_term(self):
        """Short-term events within 48h get standard decay."""
        age_hours = 24.0
        baseline = compute_standard_recency_decay(age_hours)
        assert baseline == apply_accelerated_decay(age_hours, "short_term")

    def test_accelerated_for_stale_short_term(self):
        """Short-term events older than 48h get accelerated decay."""
        age_hours = 72.0
        baseline = compute_standard_recency_decay(age_hours)
        boosted = apply_accelerated_decay(age_hours, "short_term")
        assert baseline > boosted

    def test_accelerated_decay_multiplier(self):
        age_hours = 72.0
        expected = compute_standard_recency_decay(age_hours) * ACCELERATED_DECAY_MULTIPLIER
        boosted = apply_accelerated_decay(age_hours, "short_term")
        assert abs(boosted - expected) < 1e-9

    def test_long_term_no_acceleration(self):
        age_hours = 100.0
        baseline = compute_standard_recency_decay(age_hours)
        assert apply_accelerated_decay(age_hours, "long_term") == baseline

    def test_zero_age(self):
        # A brand-new event has not decayed at all.
        assert apply_accelerated_decay(0.0, "short_term") == 1.0

    def test_standard_decay_positive(self):
        one_week = 168.0
        decay = compute_standard_recency_decay(one_week)
        assert 0.0 < decay < 1.0
|