# Source: stonks-oracle/tests/test_interpolation.py (Python, 511 lines, 18 KiB)
"""Unit tests for the interpolation engine.
Tests core scoring functions: overlap computation, resilience modifiers,
macro impact scoring, default profile building, and direction determination.
"""
from __future__ import annotations
import math
from datetime import datetime, timezone
import pytest
from services.aggregation.interpolation import (
MacroImpactRecord,
apply_resilience_modifier,
build_default_profile,
compute_commodity_overlap,
compute_geographic_overlap,
compute_macro_impact,
compute_macro_impact_with_sector,
compute_supply_chain_overlap,
)
from services.extractor.event_classifier import GlobalEvent
from services.shared.schemas import ExposureProfileSchema, MarketPositionTier
# ---------------------------------------------------------------------------
# compute_geographic_overlap
# ---------------------------------------------------------------------------
class TestComputeGeographicOverlap:
    """Exercise ``compute_geographic_overlap`` with region lists and revenue mixes."""

    def test_full_overlap(self):
        """Event regions cover every entry of the revenue mix; expect 1.0."""
        revenue_mix = {"US": 0.6, "CN": 0.4}
        overlap = compute_geographic_overlap(["US", "CN"], revenue_mix)
        assert math.isclose(overlap, 1.0, abs_tol=1e-6)

    def test_partial_overlap(self):
        """Only the US entry (0.6) is named by the event; expect 0.6."""
        revenue_mix = {"US": 0.6, "CN": 0.4}
        overlap = compute_geographic_overlap(["US"], revenue_mix)
        assert math.isclose(overlap, 0.6, abs_tol=1e-6)

    def test_no_overlap(self):
        """An event region absent from the revenue mix yields exactly 0.0."""
        overlap = compute_geographic_overlap(["JP"], {"US": 0.6, "CN": 0.4})
        assert overlap == 0.0

    def test_empty_event_regions(self):
        """No event regions at all yields 0.0."""
        overlap = compute_geographic_overlap([], {"US": 0.5})
        assert overlap == 0.0

    def test_empty_revenue_mix(self):
        """An empty revenue mix yields 0.0."""
        overlap = compute_geographic_overlap(["US"], {})
        assert overlap == 0.0

    def test_case_insensitive(self):
        """Lower-case event regions still match upper-case revenue keys."""
        lowercase_regions = ["us", "cn"]
        overlap = compute_geographic_overlap(
            lowercase_regions, {"US": 0.6, "CN": 0.4},
        )
        assert math.isclose(overlap, 1.0, abs_tol=1e-6)

    def test_clamped_to_one(self):
        """The result never exceeds 1.0."""
        # The shares below add up to more than 1; the overlap must still
        # stay at or below 1.
        oversized_mix = {"US": 0.7, "CN": 0.6}
        overlap = compute_geographic_overlap(["US", "CN"], oversized_mix)
        assert overlap <= 1.0
# ---------------------------------------------------------------------------
# compute_supply_chain_overlap
# ---------------------------------------------------------------------------
class TestComputeSupplyChainOverlap:
    """Exercise ``compute_supply_chain_overlap`` with event and supply region lists."""

    def test_full_overlap(self):
        """Identical region lists give exactly 1.0."""
        overlap = compute_supply_chain_overlap(["US", "CN"], ["US", "CN"])
        assert overlap == 1.0

    def test_partial_overlap(self):
        """One of two supply regions is affected; expect 0.5."""
        overlap = compute_supply_chain_overlap(["US"], ["US", "CN"])
        assert math.isclose(overlap, 0.5, abs_tol=1e-6)

    def test_no_overlap(self):
        """Disjoint region lists give exactly 0.0."""
        overlap = compute_supply_chain_overlap(["JP"], ["US", "CN"])
        assert overlap == 0.0

    def test_empty_event_regions(self):
        """No event regions gives 0.0."""
        overlap = compute_supply_chain_overlap([], ["US"])
        assert overlap == 0.0

    def test_empty_supply_regions(self):
        """No supply-chain regions gives 0.0."""
        overlap = compute_supply_chain_overlap(["US"], [])
        assert overlap == 0.0

    def test_case_insensitive(self):
        """A lower-case event region matches an upper-case supply region."""
        overlap = compute_supply_chain_overlap(["us"], ["US", "CN"])
        assert math.isclose(overlap, 0.5, abs_tol=1e-6)
# ---------------------------------------------------------------------------
# compute_commodity_overlap
# ---------------------------------------------------------------------------
class TestComputeCommodityOverlap:
    """Exercise ``compute_commodity_overlap`` with event and company commodity lists."""

    def test_full_overlap(self):
        """Identical commodity lists give exactly 1.0."""
        event_commodities = ["crude_oil", "natural_gas"]
        company_commodities = ["crude_oil", "natural_gas"]
        overlap = compute_commodity_overlap(event_commodities, company_commodities)
        assert overlap == 1.0

    def test_partial_overlap(self):
        """One of two company commodities is affected; expect 0.5."""
        company_commodities = ["crude_oil", "natural_gas"]
        overlap = compute_commodity_overlap(["crude_oil"], company_commodities)
        assert math.isclose(overlap, 0.5, abs_tol=1e-6)

    def test_no_overlap(self):
        """Disjoint commodity lists give exactly 0.0."""
        overlap = compute_commodity_overlap(["gold"], ["crude_oil"])
        assert overlap == 0.0

    def test_empty_event_commodities(self):
        """No event commodities gives 0.0."""
        overlap = compute_commodity_overlap([], ["crude_oil"])
        assert overlap == 0.0

    def test_empty_company_commodities(self):
        """No company commodities gives 0.0."""
        overlap = compute_commodity_overlap(["crude_oil"], [])
        assert overlap == 0.0
# ---------------------------------------------------------------------------
# apply_resilience_modifier
# ---------------------------------------------------------------------------
class TestApplyResilienceModifier:
    """Exercise ``apply_resilience_modifier`` across market-position tiers.

    The third positional argument is passed as ``True`` in most cases and as
    ``False`` in ``test_no_modifier_for_domestic_event``; presumably it flags
    whether the event is international -- confirm against the implementation.
    """

    def test_global_leader_dampens(self):
        """A raw score of 0.5 becomes 0.35 for a global leader."""
        adjusted = apply_resilience_modifier(0.5, "global_leader", True)
        assert math.isclose(adjusted, 0.35, abs_tol=1e-6)

    def test_domestic_amplifies(self):
        """A raw score of 0.5 becomes 0.6 for a domestic company."""
        adjusted = apply_resilience_modifier(0.5, "domestic", True)
        assert math.isclose(adjusted, 0.6, abs_tol=1e-6)

    def test_regional_no_change(self):
        """The regional tier leaves the raw score untouched."""
        adjusted = apply_resilience_modifier(0.5, "regional", True)
        assert math.isclose(adjusted, 0.5, abs_tol=1e-6)

    def test_no_modifier_for_domestic_event(self):
        """With the flag set to False the raw score is returned unchanged."""
        adjusted = apply_resilience_modifier(0.5, "global_leader", False)
        assert math.isclose(adjusted, 0.5, abs_tol=1e-6)

    def test_clamped_to_one(self):
        """The adjusted score never exceeds 1.0."""
        adjusted = apply_resilience_modifier(0.9, "domestic", True)
        assert adjusted <= 1.0

    def test_clamped_to_zero(self):
        """The adjusted score never drops below 0.0."""
        adjusted = apply_resilience_modifier(0.0, "domestic", True)
        assert adjusted >= 0.0

    def test_tier_ordering_for_international(self):
        """global_leader <= multinational <= regional <= domestic."""
        raw_score = 0.5
        tiers = ("global_leader", "multinational", "regional", "domestic")
        # Evaluated in tier order, then unpacked for a readable chained assert.
        leader, multinational, regional, domestic = (
            apply_resilience_modifier(raw_score, tier, True) for tier in tiers
        )
        assert leader <= multinational <= regional <= domestic
# ---------------------------------------------------------------------------
# compute_macro_impact — zero overlap
# ---------------------------------------------------------------------------
class TestComputeMacroImpactZeroOverlap:
    """``compute_macro_impact`` when the event shares nothing with the profile."""

    def test_zero_overlap_returns_zero_score(self):
        """Disjoint regions and commodities give a zeroed-out record."""
        # The event touches JP / gold; the company is exposed to US / crude_oil.
        unrelated_event = GlobalEvent(
            event_id="evt-1",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["JP"],
            affected_sectors=["Energy"],
            affected_commodities=["gold"],
            confidence=0.9,
        )
        us_only_profile = ExposureProfileSchema(
            company_id="comp-1",
            geographic_revenue_mix={"US": 1.0},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )

        impact = compute_macro_impact(unrelated_event, us_only_profile)

        assert impact.macro_impact_score == 0.0
        assert impact.contributing_factors == []
        assert impact.confidence == 0.0
# ---------------------------------------------------------------------------
# compute_macro_impact — basic scoring
# ---------------------------------------------------------------------------
class TestComputeMacroImpactScoring:
    """Basic scoring properties of ``compute_macro_impact``."""

    def test_score_in_bounds(self):
        """An overlapping event gives a score in (0, 1] plus at least one factor."""
        overlapping_event = GlobalEvent(
            event_id="evt-2",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        exposed_profile = ExposureProfileSchema(
            company_id="comp-2",
            geographic_revenue_mix={"US": 0.8},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )

        impact = compute_macro_impact(overlapping_event, exposed_profile)

        assert 0.0 <= impact.macro_impact_score <= 1.0
        assert impact.macro_impact_score > 0.0
        assert len(impact.contributing_factors) > 0

    def test_higher_severity_higher_score(self):
        """Critical severity should produce >= score than low severity."""
        shared_profile = ExposureProfileSchema(
            company_id="comp-3",
            geographic_revenue_mix={"US": 0.5},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )

        def make_event(event_id, severity):
            # The two events differ only in id and severity; fresh lists are
            # built on every call so the events share no mutable state.
            return GlobalEvent(
                event_id=event_id,
                event_types=["supply_disruption"],
                severity=severity,
                affected_regions=["US"],
                affected_commodities=["crude_oil"],
                confidence=0.9,
            )

        low_event = make_event("evt-low", "low")
        critical_event = make_event("evt-crit", "critical")

        low_impact = compute_macro_impact(low_event, shared_profile)
        critical_impact = compute_macro_impact(critical_event, shared_profile)

        assert critical_impact.macro_impact_score >= low_impact.macro_impact_score
# ---------------------------------------------------------------------------
# Mixed direction
# ---------------------------------------------------------------------------
class TestMixedDirection:
    """``impact_direction`` reported by ``compute_macro_impact`` per event-type mix."""

    def test_mixed_when_both_positive_and_negative(self):
        """demand_shift (positive) + supply_disruption (negative) → mixed."""
        two_sided_event = GlobalEvent(
            event_id="evt-mix",
            event_types=["demand_shift", "supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.8,
        )
        exposed_profile = ExposureProfileSchema(
            company_id="comp-mix",
            geographic_revenue_mix={"US": 0.5},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )

        impact = compute_macro_impact(two_sided_event, exposed_profile)

        assert impact.impact_direction == "mixed"
        # Both the positive and the negative marker must show up among the
        # contributing factors.
        joined_factors = " ".join(impact.contributing_factors)
        for marker in ("positive_types:", "negative_types:"):
            assert marker in joined_factors

    def test_negative_only(self):
        """supply_disruption + cost_increase is reported as negative."""
        negative_event = GlobalEvent(
            event_id="evt-neg",
            event_types=["supply_disruption", "cost_increase"],
            severity="high",
            affected_regions=["US"],
            confidence=0.8,
        )
        exposed_profile = ExposureProfileSchema(
            company_id="comp-neg",
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )

        impact = compute_macro_impact(negative_event, exposed_profile)

        assert impact.impact_direction == "negative"

    def test_positive_only(self):
        """A lone demand_shift is reported as positive."""
        positive_event = GlobalEvent(
            event_id="evt-pos",
            event_types=["demand_shift"],
            severity="moderate",
            affected_regions=["US"],
            confidence=0.8,
        )
        exposed_profile = ExposureProfileSchema(
            company_id="comp-pos",
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )

        impact = compute_macro_impact(positive_event, exposed_profile)

        assert impact.impact_direction == "positive"
# ---------------------------------------------------------------------------
# compute_macro_impact_with_sector
# ---------------------------------------------------------------------------
class TestComputeMacroImpactWithSector:
    """Sector-aware scoring through ``compute_macro_impact_with_sector``."""

    def test_sector_match_increases_score(self):
        """A matching sector never scores lower than an empty sector string."""
        energy_event = GlobalEvent(
            event_id="evt-sec",
            event_types=["supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            confidence=0.9,
        )
        us_profile = ExposureProfileSchema(
            company_id="comp-sec",
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )

        baseline = compute_macro_impact_with_sector(energy_event, us_profile, "")
        matched = compute_macro_impact_with_sector(energy_event, us_profile, "Energy")

        assert matched.macro_impact_score >= baseline.macro_impact_score

    def test_sector_no_match(self):
        """A non-matching sector keeps a positive score but adds no sector factor."""
        energy_event = GlobalEvent(
            event_id="evt-sec2",
            event_types=["supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            confidence=0.9,
        )
        us_profile = ExposureProfileSchema(
            company_id="comp-sec2",
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )

        impact = compute_macro_impact_with_sector(
            energy_event, us_profile, "Financials",
        )

        # The sector differs, yet the US geographic overlap remains.
        assert impact.macro_impact_score > 0.0
        joined_factors = " ".join(impact.contributing_factors)
        assert "sector_match" not in joined_factors
# ---------------------------------------------------------------------------
# build_default_profile
# ---------------------------------------------------------------------------
class TestBuildDefaultProfile:
    """Properties of the profile returned by ``build_default_profile``."""

    @pytest.mark.parametrize("cap,expected_tier", [
        ("large_cap", "global_leader"),
        ("mid_cap", "multinational"),
        ("small_cap", "regional"),
        ("micro_cap", "domestic"),
    ])
    def test_market_cap_to_tier_mapping(self, cap, expected_tier):
        """Each market-cap bucket maps to its expected market-position tier."""
        default_profile = build_default_profile("Energy", "Oil & Gas", cap)
        tier = default_profile.market_position_tier
        # The tier may come back as the enum or as something else (e.g. its
        # raw value); unwrap the enum so one comparison covers both.
        actual_tier = tier.value if isinstance(tier, MarketPositionTier) else tier
        assert actual_tier == expected_tier

    def test_has_non_empty_geo_mix(self):
        """The default profile carries at least one geographic revenue entry."""
        default_profile = build_default_profile("Energy", "Oil & Gas", "large_cap")
        assert len(default_profile.geographic_revenue_mix) > 0

    def test_source_is_inferred(self):
        """The default profile is tagged with source ``"inferred"``."""
        default_profile = build_default_profile("Energy", "Oil & Gas", "mid_cap")
        assert default_profile.source == "inferred"

    def test_unknown_sector_uses_default_geo(self):
        """An unrecognised sector still gets a non-empty geographic mix."""
        default_profile = build_default_profile(
            "UnknownSector", "Unknown", "small_cap",
        )
        assert len(default_profile.geographic_revenue_mix) > 0

    def test_energy_sector_has_commodities(self):
        """The Energy default profile lists commodities, including crude_oil."""
        default_profile = build_default_profile("Energy", "Oil & Gas", "large_cap")
        commodities = default_profile.key_input_commodities
        assert len(commodities) > 0
        assert "crude_oil" in commodities
# ---------------------------------------------------------------------------
# MacroImpactRecord dataclass
# ---------------------------------------------------------------------------
class TestMacroImpactRecord:
    """Default field values of ``MacroImpactRecord``."""

    def test_defaults(self):
        """A record built with no arguments exposes the expected defaults."""
        blank = MacroImpactRecord()

        assert blank.event_id == ""
        assert blank.macro_impact_score == 0.0
        assert blank.impact_direction == "neutral"
        assert blank.contributing_factors == []
        assert blank.confidence == 0.5
        # Only presence is checked here; the concrete value is not asserted.
        assert blank.computed_at is not None
# ---------------------------------------------------------------------------
# Low-confidence event exclusion (Requirements: 10.1)
# ---------------------------------------------------------------------------
from services.aggregation.interpolation import (
filter_low_confidence_events,
apply_accelerated_decay,
compute_standard_recency_decay,
DEFAULT_CONFIDENCE_THRESHOLD,
ACCELERATED_DECAY_MULTIPLIER,
)
class TestFilterLowConfidenceEvents:
    """Threshold behaviour of ``filter_low_confidence_events``."""

    def test_excludes_below_threshold(self):
        """Only the event at or above 0.4 survives."""
        confidences = (("e1", 0.3), ("e2", 0.5), ("e3", 0.1))
        candidates = [
            GlobalEvent(event_id=event_id, confidence=confidence)
            for event_id, confidence in confidences
        ]
        kept = filter_low_confidence_events(candidates, confidence_threshold=0.4)
        assert len(kept) == 1
        assert kept[0].event_id == "e2"

    def test_includes_at_threshold(self):
        """An event exactly at the threshold is kept."""
        candidates = [GlobalEvent(event_id="e1", confidence=0.4)]
        kept = filter_low_confidence_events(candidates, confidence_threshold=0.4)
        assert len(kept) == 1

    def test_empty_list(self):
        """An empty input yields an empty list."""
        kept = filter_low_confidence_events([], confidence_threshold=0.4)
        assert kept == []

    def test_all_pass(self):
        """Every event above the threshold is kept."""
        candidates = [
            GlobalEvent(event_id=event_id, confidence=confidence)
            for event_id, confidence in (("e1", 0.8), ("e2", 0.9))
        ]
        kept = filter_low_confidence_events(candidates, confidence_threshold=0.4)
        assert len(kept) == 2

    def test_all_excluded(self):
        """Every event below the threshold is dropped."""
        candidates = [
            GlobalEvent(event_id=event_id, confidence=confidence)
            for event_id, confidence in (("e1", 0.1), ("e2", 0.2))
        ]
        kept = filter_low_confidence_events(candidates, confidence_threshold=0.4)
        assert len(kept) == 0

    def test_default_threshold(self):
        """The module-level default threshold is 0.4."""
        assert DEFAULT_CONFIDENCE_THRESHOLD == 0.4
# ---------------------------------------------------------------------------
# Accelerated decay for stale short-term events (Requirements: 10.2)
# ---------------------------------------------------------------------------
class TestAcceleratedDecay:
    """Compare ``apply_accelerated_decay`` with ``compute_standard_recency_decay``.

    Both functions are called with an age (per the test docstrings, hours)
    and, for the accelerated variant, a time-horizon label.
    """

    def test_standard_decay_for_non_short_term(self):
        """A medium-term event gets plain standard decay."""
        standard = compute_standard_recency_decay(72.0)
        accelerated = apply_accelerated_decay(72.0, "medium_term")
        assert accelerated == standard

    def test_standard_decay_for_young_short_term(self):
        """Short-term events within 48h get standard decay."""
        standard = compute_standard_recency_decay(24.0)
        accelerated = apply_accelerated_decay(24.0, "short_term")
        assert accelerated == standard

    def test_accelerated_for_stale_short_term(self):
        """Short-term events older than 48h get accelerated decay."""
        age = 72.0
        standard = compute_standard_recency_decay(age)
        accelerated = apply_accelerated_decay(age, "short_term")
        assert accelerated < standard

    def test_accelerated_decay_multiplier(self):
        """Accelerated decay equals standard decay times the multiplier."""
        age = 72.0
        standard = compute_standard_recency_decay(age)
        accelerated = apply_accelerated_decay(age, "short_term")
        # math.isclose matches the tolerance checks used elsewhere in this
        # module instead of a hand-rolled abs() comparison.
        assert math.isclose(
            accelerated, standard * ACCELERATED_DECAY_MULTIPLIER, abs_tol=1e-9,
        )

    def test_long_term_no_acceleration(self):
        """A long-term event gets plain standard decay."""
        standard = compute_standard_recency_decay(100.0)
        result = apply_accelerated_decay(100.0, "long_term")
        assert result == standard

    def test_zero_age(self):
        """A brand-new short-term event has a decay factor of exactly 1.0."""
        result = apply_accelerated_decay(0.0, "short_term")
        assert result == 1.0

    def test_standard_decay_positive(self):
        """Standard decay at age 168.0 lies strictly between 0 and 1."""
        result = compute_standard_recency_decay(168.0)
        assert 0.0 < result < 1.0