feat: competitive intelligence & historical pattern matching layer
This commit is contained in:
@@ -0,0 +1,510 @@
|
||||
"""Unit tests for the interpolation engine.
|
||||
|
||||
Tests core scoring functions: overlap computation, resilience modifiers,
|
||||
macro impact scoring, default profile building, and direction determination.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from services.aggregation.interpolation import (
|
||||
MacroImpactRecord,
|
||||
apply_resilience_modifier,
|
||||
build_default_profile,
|
||||
compute_commodity_overlap,
|
||||
compute_geographic_overlap,
|
||||
compute_macro_impact,
|
||||
compute_macro_impact_with_sector,
|
||||
compute_supply_chain_overlap,
|
||||
)
|
||||
from services.extractor.event_classifier import GlobalEvent
|
||||
from services.shared.schemas import ExposureProfileSchema, MarketPositionTier
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_geographic_overlap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeGeographicOverlap:
    """Checks for ``compute_geographic_overlap``.

    Every case passes a list of event region codes together with a mapping
    of region code to revenue share and inspects the returned overlap.
    """

    def test_full_overlap(self):
        """Event regions covering every revenue region are expected to give 1.0."""
        revenue_mix = {"US": 0.6, "CN": 0.4}
        overlap = compute_geographic_overlap(["US", "CN"], revenue_mix)
        assert math.isclose(overlap, 1.0, abs_tol=1e-6)

    def test_partial_overlap(self):
        """Only the US share (0.6) is expected when the event touches just the US."""
        revenue_mix = {"US": 0.6, "CN": 0.4}
        overlap = compute_geographic_overlap(["US"], revenue_mix)
        assert math.isclose(overlap, 0.6, abs_tol=1e-6)

    def test_no_overlap(self):
        """An event region absent from the revenue mix is expected to give 0.0."""
        revenue_mix = {"US": 0.6, "CN": 0.4}
        overlap = compute_geographic_overlap(["JP"], revenue_mix)
        assert overlap == 0.0

    def test_empty_event_regions(self):
        """No event regions at all is expected to give 0.0."""
        overlap = compute_geographic_overlap([], {"US": 0.5})
        assert overlap == 0.0

    def test_empty_revenue_mix(self):
        """An empty revenue mix is expected to give 0.0."""
        overlap = compute_geographic_overlap(["US"], {})
        assert overlap == 0.0

    def test_case_insensitive(self):
        """Lower-case event codes are expected to match upper-case mix keys."""
        revenue_mix = {"US": 0.6, "CN": 0.4}
        overlap = compute_geographic_overlap(["us", "cn"], revenue_mix)
        assert math.isclose(overlap, 1.0, abs_tol=1e-6)

    def test_clamped_to_one(self):
        """The overlap must not exceed 1.0 even when the shares sum past 1."""
        oversized_mix = {"US": 0.7, "CN": 0.6}  # shares add up to 1.3
        overlap = compute_geographic_overlap(["US", "CN"], oversized_mix)
        assert overlap <= 1.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_supply_chain_overlap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeSupplyChainOverlap:
    """Checks for ``compute_supply_chain_overlap``.

    The first argument is the list of event regions, the second the list of
    supply-chain regions of the company.
    """

    def test_full_overlap(self):
        """Identical region lists are expected to give exactly 1.0."""
        supply_regions = ["US", "CN"]
        overlap = compute_supply_chain_overlap(["US", "CN"], supply_regions)
        assert overlap == 1.0

    def test_partial_overlap(self):
        """One of two supply regions affected is expected to give 0.5."""
        supply_regions = ["US", "CN"]
        overlap = compute_supply_chain_overlap(["US"], supply_regions)
        assert math.isclose(overlap, 0.5, abs_tol=1e-6)

    def test_no_overlap(self):
        """Disjoint region lists are expected to give 0.0."""
        supply_regions = ["US", "CN"]
        overlap = compute_supply_chain_overlap(["JP"], supply_regions)
        assert overlap == 0.0

    def test_empty_event_regions(self):
        """No event regions is expected to give 0.0."""
        overlap = compute_supply_chain_overlap([], ["US"])
        assert overlap == 0.0

    def test_empty_supply_regions(self):
        """No supply-chain regions is expected to give 0.0."""
        overlap = compute_supply_chain_overlap(["US"], [])
        assert overlap == 0.0

    def test_case_insensitive(self):
        """A lower-case event code is expected to match an upper-case supply region."""
        supply_regions = ["US", "CN"]
        overlap = compute_supply_chain_overlap(["us"], supply_regions)
        assert math.isclose(overlap, 0.5, abs_tol=1e-6)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_commodity_overlap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeCommodityOverlap:
    """Checks for ``compute_commodity_overlap``.

    The first argument is the list of commodities named by the event, the
    second the list of the company's input commodities.
    """

    def test_full_overlap(self):
        """Identical commodity lists are expected to give exactly 1.0."""
        company_inputs = ["crude_oil", "natural_gas"]
        overlap = compute_commodity_overlap(
            ["crude_oil", "natural_gas"], company_inputs
        )
        assert overlap == 1.0

    def test_partial_overlap(self):
        """One of two company commodities affected is expected to give 0.5."""
        company_inputs = ["crude_oil", "natural_gas"]
        overlap = compute_commodity_overlap(["crude_oil"], company_inputs)
        assert math.isclose(overlap, 0.5, abs_tol=1e-6)

    def test_no_overlap(self):
        """Disjoint commodity lists are expected to give 0.0."""
        overlap = compute_commodity_overlap(["gold"], ["crude_oil"])
        assert overlap == 0.0

    def test_empty_event_commodities(self):
        """No event commodities is expected to give 0.0."""
        overlap = compute_commodity_overlap([], ["crude_oil"])
        assert overlap == 0.0

    def test_empty_company_commodities(self):
        """No company commodities is expected to give 0.0."""
        overlap = compute_commodity_overlap(["crude_oil"], [])
        assert overlap == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# apply_resilience_modifier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApplyResilienceModifier:
    """Checks for ``apply_resilience_modifier``.

    Calls take a raw score, a market-position tier name and a boolean flag.
    Judging by the test names the flag marks an international event
    (NOTE(review): confirm against the function's signature).
    """

    def test_global_leader_dampens(self):
        """A global leader's 0.5 raw score is expected to drop to 0.35."""
        adjusted = apply_resilience_modifier(0.5, "global_leader", True)
        assert math.isclose(adjusted, 0.35, abs_tol=1e-6)

    def test_domestic_amplifies(self):
        """A domestic company's 0.5 raw score is expected to rise to 0.6."""
        adjusted = apply_resilience_modifier(0.5, "domestic", True)
        assert math.isclose(adjusted, 0.6, abs_tol=1e-6)

    def test_regional_no_change(self):
        """A regional company's raw score is expected to pass through unchanged."""
        adjusted = apply_resilience_modifier(0.5, "regional", True)
        assert math.isclose(adjusted, 0.5, abs_tol=1e-6)

    def test_no_modifier_for_domestic_event(self):
        """With the flag off, even a global leader keeps its raw score."""
        adjusted = apply_resilience_modifier(0.5, "global_leader", False)
        assert math.isclose(adjusted, 0.5, abs_tol=1e-6)

    def test_clamped_to_one(self):
        """Amplifying a high raw score must not push the result above 1.0."""
        adjusted = apply_resilience_modifier(0.9, "domestic", True)
        assert adjusted <= 1.0

    def test_clamped_to_zero(self):
        """A zero raw score must not produce a negative result."""
        adjusted = apply_resilience_modifier(0.0, "domestic", True)
        assert adjusted >= 0.0

    def test_tier_ordering_for_international(self):
        """global_leader <= multinational <= regional <= domestic."""
        base_score = 0.5
        tiers_low_to_high = ["global_leader", "multinational", "regional", "domestic"]
        adjusted = [
            apply_resilience_modifier(base_score, tier, True)
            for tier in tiers_low_to_high
        ]
        # Each tier's score must be no greater than the next tier's score.
        assert all(lower <= upper for lower, upper in zip(adjusted, adjusted[1:]))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_macro_impact — zero overlap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeMacroImpactZeroOverlap:
    """``compute_macro_impact`` behaviour when event and profile share nothing."""

    def test_zero_overlap_returns_zero_score(self):
        """A JP/gold event against a US/crude-oil profile yields an empty record.

        The score and confidence are expected to be 0.0 and the list of
        contributing factors empty.
        """
        us_only_profile = ExposureProfileSchema(
            company_id="comp-1",
            geographic_revenue_mix={"US": 1.0},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        unrelated_event = GlobalEvent(
            event_id="evt-1",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["JP"],
            affected_sectors=["Energy"],
            affected_commodities=["gold"],
            confidence=0.9,
        )

        outcome = compute_macro_impact(unrelated_event, us_only_profile)

        assert outcome.macro_impact_score == 0.0
        assert outcome.contributing_factors == []
        assert outcome.confidence == 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_macro_impact — basic scoring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeMacroImpactScoring:
    """Basic scoring checks for ``compute_macro_impact`` with overlapping inputs."""

    def test_score_in_bounds(self):
        """An overlapping event yields a score in (0, 1] and at least one factor."""
        exposure = ExposureProfileSchema(
            company_id="comp-2",
            geographic_revenue_mix={"US": 0.8},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        disruption = GlobalEvent(
            event_id="evt-2",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )

        outcome = compute_macro_impact(disruption, exposure)

        assert 0.0 <= outcome.macro_impact_score <= 1.0
        assert outcome.macro_impact_score > 0.0
        assert len(outcome.contributing_factors) > 0

    def test_higher_severity_higher_score(self):
        """Critical severity should produce >= score than low severity."""
        exposure = ExposureProfileSchema(
            company_id="comp-3",
            geographic_revenue_mix={"US": 0.5},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )

        def disruption(event_id, severity):
            # The two events differ only in their id and severity.
            return GlobalEvent(
                event_id=event_id,
                event_types=["supply_disruption"],
                severity=severity,
                affected_regions=["US"],
                affected_commodities=["crude_oil"],
                confidence=0.9,
            )

        mild_event = disruption("evt-low", "low")
        severe_event = disruption("evt-crit", "critical")

        mild_outcome = compute_macro_impact(mild_event, exposure)
        severe_outcome = compute_macro_impact(severe_event, exposure)

        assert severe_outcome.macro_impact_score >= mild_outcome.macro_impact_score
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mixed direction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMixedDirection:
    """Checks the ``impact_direction`` reported by ``compute_macro_impact``."""

    @staticmethod
    def _profile(company_id, **extra_fields):
        """Build a regional-tier profile with 50% US revenue plus any extra fields."""
        return ExposureProfileSchema(
            company_id=company_id,
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
            **extra_fields,
        )

    def test_mixed_when_both_positive_and_negative(self):
        """demand_shift (positive) + supply_disruption (negative) → mixed."""
        two_sided_event = GlobalEvent(
            event_id="evt-mix",
            event_types=["demand_shift", "supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.8,
        )
        exposure = self._profile(
            "comp-mix",
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
        )

        outcome = compute_macro_impact(two_sided_event, exposure)

        assert outcome.impact_direction == "mixed"
        # Both the positive and the negative type groups must be reported.
        joined_factors = " ".join(outcome.contributing_factors)
        assert "positive_types:" in joined_factors
        assert "negative_types:" in joined_factors

    def test_negative_only(self):
        """supply_disruption + cost_increase is expected to be reported as negative."""
        adverse_event = GlobalEvent(
            event_id="evt-neg",
            event_types=["supply_disruption", "cost_increase"],
            severity="high",
            affected_regions=["US"],
            confidence=0.8,
        )
        exposure = self._profile("comp-neg")

        outcome = compute_macro_impact(adverse_event, exposure)

        assert outcome.impact_direction == "negative"

    def test_positive_only(self):
        """A lone demand_shift is expected to be reported as positive."""
        favourable_event = GlobalEvent(
            event_id="evt-pos",
            event_types=["demand_shift"],
            severity="moderate",
            affected_regions=["US"],
            confidence=0.8,
        )
        exposure = self._profile("comp-pos")

        outcome = compute_macro_impact(favourable_event, exposure)

        assert outcome.impact_direction == "positive"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_macro_impact_with_sector
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComputeMacroImpactWithSector:
    """Checks for ``compute_macro_impact_with_sector``.

    Both tests use a high-severity US supply disruption that names the
    Energy sector, scored against a regional profile with 50% US revenue.
    """

    @staticmethod
    def _energy_disruption(event_id):
        """Build the shared Energy-sector disruption event with the given id."""
        return GlobalEvent(
            event_id=event_id,
            event_types=["supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            confidence=0.9,
        )

    @staticmethod
    def _us_profile(company_id):
        """Build the shared regional-tier profile with the given company id."""
        return ExposureProfileSchema(
            company_id=company_id,
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )

    def test_sector_match_increases_score(self):
        """Passing the matching sector must not lower the score versus an empty sector."""
        disruption = self._energy_disruption("evt-sec")
        exposure = self._us_profile("comp-sec")

        baseline = compute_macro_impact_with_sector(disruption, exposure, "")
        matched = compute_macro_impact_with_sector(disruption, exposure, "Energy")

        assert matched.macro_impact_score >= baseline.macro_impact_score

    def test_sector_no_match(self):
        """A non-matching sector still scores above zero but reports no sector match."""
        disruption = self._energy_disruption("evt-sec2")
        exposure = self._us_profile("comp-sec2")

        outcome = compute_macro_impact_with_sector(disruption, exposure, "Financials")

        # The sector differs, yet the geographic overlap still contributes.
        assert outcome.macro_impact_score > 0.0
        joined_factors = " ".join(outcome.contributing_factors)
        assert "sector_match" not in joined_factors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_default_profile
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildDefaultProfile:
    """Checks for ``build_default_profile``.

    The calls pass three strings; judging by the values they are a sector,
    an industry and a market-cap bucket.
    """

    @pytest.mark.parametrize("cap,expected_tier", [
        ("large_cap", "global_leader"),
        ("mid_cap", "multinational"),
        ("small_cap", "regional"),
        ("micro_cap", "domestic"),
    ])
    def test_market_cap_to_tier_mapping(self, cap, expected_tier):
        """Each market-cap bucket maps to the expected market-position tier."""
        tier = build_default_profile("Energy", "Oil & Gas", cap).market_position_tier
        # The tier may come back as the enum member or as its plain string value.
        actual = tier.value if isinstance(tier, MarketPositionTier) else tier
        assert actual == expected_tier

    def test_has_non_empty_geo_mix(self):
        """A default Energy profile carries at least one revenue region."""
        default_profile = build_default_profile("Energy", "Oil & Gas", "large_cap")
        assert len(default_profile.geographic_revenue_mix) > 0

    def test_source_is_inferred(self):
        """Default profiles are labelled with the source ``"inferred"``."""
        default_profile = build_default_profile("Energy", "Oil & Gas", "mid_cap")
        assert default_profile.source == "inferred"

    def test_unknown_sector_uses_default_geo(self):
        """An unrecognised sector still gets a non-empty revenue mix."""
        default_profile = build_default_profile("UnknownSector", "Unknown", "small_cap")
        assert len(default_profile.geographic_revenue_mix) > 0

    def test_energy_sector_has_commodities(self):
        """The Energy default lists input commodities, crude oil among them."""
        commodities = build_default_profile(
            "Energy", "Oil & Gas", "large_cap"
        ).key_input_commodities
        assert len(commodities) > 0
        assert "crude_oil" in commodities
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MacroImpactRecord dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMacroImpactRecord:
    """Checks the field defaults of ``MacroImpactRecord``."""

    def test_defaults(self):
        """A record built without arguments carries the documented defaults."""
        blank = MacroImpactRecord()

        assert blank.event_id == ""
        assert blank.macro_impact_score == 0.0
        assert blank.impact_direction == "neutral"
        assert blank.contributing_factors == []
        assert blank.confidence == 0.5
        # The timestamp only has to be populated; its value is not checked.
        assert blank.computed_at is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-confidence event exclusion (Requirements: 10.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from services.aggregation.interpolation import (
|
||||
filter_low_confidence_events,
|
||||
apply_accelerated_decay,
|
||||
compute_standard_recency_decay,
|
||||
DEFAULT_CONFIDENCE_THRESHOLD,
|
||||
ACCELERATED_DECAY_MULTIPLIER,
|
||||
)
|
||||
|
||||
|
||||
class TestFilterLowConfidenceEvents:
    """Checks for ``filter_low_confidence_events`` with a 0.4 threshold."""

    def test_excludes_below_threshold(self):
        """Only the event at 0.5 survives; those at 0.3 and 0.1 are dropped."""
        candidates = [
            GlobalEvent(event_id=event_id, confidence=confidence)
            for event_id, confidence in (("e1", 0.3), ("e2", 0.5), ("e3", 0.1))
        ]
        kept = filter_low_confidence_events(candidates, confidence_threshold=0.4)
        assert len(kept) == 1
        assert kept[0].event_id == "e2"

    def test_includes_at_threshold(self):
        """An event exactly at the threshold is kept."""
        candidates = [GlobalEvent(event_id="e1", confidence=0.4)]
        kept = filter_low_confidence_events(candidates, confidence_threshold=0.4)
        assert len(kept) == 1

    def test_empty_list(self):
        """Filtering an empty list returns an empty list."""
        kept = filter_low_confidence_events([], confidence_threshold=0.4)
        assert kept == []

    def test_all_pass(self):
        """Events at 0.8 and 0.9 are both kept."""
        candidates = [
            GlobalEvent(event_id=event_id, confidence=confidence)
            for event_id, confidence in (("e1", 0.8), ("e2", 0.9))
        ]
        kept = filter_low_confidence_events(candidates, confidence_threshold=0.4)
        assert len(kept) == 2

    def test_all_excluded(self):
        """Events at 0.1 and 0.2 are both dropped."""
        candidates = [
            GlobalEvent(event_id=event_id, confidence=confidence)
            for event_id, confidence in (("e1", 0.1), ("e2", 0.2))
        ]
        kept = filter_low_confidence_events(candidates, confidence_threshold=0.4)
        assert len(kept) == 0

    def test_default_threshold(self):
        """The module-level default threshold is pinned at 0.4."""
        assert DEFAULT_CONFIDENCE_THRESHOLD == 0.4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Accelerated decay for stale short-term events (Requirements: 10.2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAcceleratedDecay:
    """Checks ``apply_accelerated_decay`` against ``compute_standard_recency_decay``.

    Both functions receive an event age; the docstrings below treat it as
    hours, matching the 48h wording used by the original tests.
    """

    def test_standard_decay_for_non_short_term(self):
        """A medium-term event at age 72 gets the plain standard decay."""
        age = 72.0
        baseline = compute_standard_recency_decay(age)
        decayed = apply_accelerated_decay(age, "medium_term")
        assert decayed == baseline

    def test_standard_decay_for_young_short_term(self):
        """Short-term events within 48h get standard decay."""
        age = 24.0
        baseline = compute_standard_recency_decay(age)
        decayed = apply_accelerated_decay(age, "short_term")
        assert decayed == baseline

    def test_accelerated_for_stale_short_term(self):
        """Short-term events older than 48h get accelerated decay."""
        stale_age = 72.0
        baseline = compute_standard_recency_decay(stale_age)
        decayed = apply_accelerated_decay(stale_age, "short_term")
        assert decayed < baseline

    def test_accelerated_decay_multiplier(self):
        """The accelerated value equals standard decay times the multiplier."""
        stale_age = 72.0
        baseline = compute_standard_recency_decay(stale_age)
        decayed = apply_accelerated_decay(stale_age, "short_term")
        expected = baseline * ACCELERATED_DECAY_MULTIPLIER
        assert abs(decayed - expected) < 1e-9

    def test_long_term_no_acceleration(self):
        """A long-term event at age 100 gets the plain standard decay."""
        age = 100.0
        baseline = compute_standard_recency_decay(age)
        decayed = apply_accelerated_decay(age, "long_term")
        assert decayed == baseline

    def test_zero_age(self):
        """A brand-new short-term event is expected to decay to exactly 1.0."""
        decayed = apply_accelerated_decay(0.0, "short_term")
        assert decayed == 1.0

    def test_standard_decay_positive(self):
        """Standard decay at age 168 lies strictly between 0 and 1."""
        decayed = compute_standard_recency_decay(168.0)
        assert 0.0 < decayed < 1.0
|
||||
Reference in New Issue
Block a user