feat: signal math upgrade — probabilistic, regime-aware scoring pipeline
ci/woodpecker/push/test Pipeline was successful
ci/woodpecker/push/build-1 Pipeline was successful
ci/woodpecker/push/build-2 Pipeline was successful
ci/woodpecker/push/build-3 Pipeline was successful
ci/woodpecker/push/finalize Pipeline was successful
Build and Push / lint-and-test (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.adapters.broker_adapter name:broker-adapter]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.aggregation.worker name:aggregation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.extractor.worker name:extractor]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.ingestion.worker name:ingestion]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.lake_publisher.worker name:lake-publisher]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.parser.worker name:parser]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.recommendation.worker name:recommendation]) (push) Has been cancelled
Build and Push / build-services (map[cmd:python -m services.scheduler.app name:scheduler]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.api.app:app --host 0.0.0.0 --port 8000 name:query-api]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.risk.app:app --host 0.0.0.0 --port 8000 name:risk]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.symbol_registry.app:app --host 0.0.0.0 --port 8000 name:symbol-registry]) (push) Has been cancelled
Build and Push / build-services (map[cmd:uvicorn services.trading.app:app --host 0.0.0.0 --port 8000 name:trading-engine]) (push) Has been cancelled
Build and Push / build-dashboard (push) Has been cancelled
Build and Push / build-superset (push) Has been cancelled
Build and Push / integration-test (push) Has been cancelled
Build and Push / beta-gate (push) Has been cancelled
Implement full probabilistic signal processing pipeline gated behind
probabilistic_scoring_enabled feature flag in risk_configs:

- Bayesian log-likelihood accumulator with Beta posterior and entropy
- Regime detector (trend-following, panic, mean-reversion, uncertainty)
- Source accuracy tracker with per-source historical prediction accuracy
- Sigmoid confidence gate replacing binary gate
- Information gain surprise weighting for rare events
- Adaptive recency decay with event-specific half-lives
- Regime multiplier replacing market context multiplier
- Weighted disagreement entropy for contradiction detection
- Multiplicative macro exposure with conditional integration
- Graph-distance attenuated competitive signal propagation
- Exponentially weighted momentum with volatility scaling
- Expected value recommendation gate

All changes backward-compatible: flag=false preserves exact current behavior.
New outputs stored in existing JSONB columns (no schema changes except
source_accuracy table via migration 034).

Tests: 26 property-based tests (14 correctness properties), 99 unit tests,
1789 total tests passing with zero regressions.
@@ -4,6 +4,9 @@ Tests the pure logic functions (no DB required). The async DB functions
are covered by integration tests.
"""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock

import pytest

from services.aggregation.scoring import (
    ScoringConfig,
@@ -21,6 +24,7 @@ from services.aggregation.worker import (
    compute_trend_confidence,
    derive_trend_direction,
    extract_catalysts_and_risks,
    fetch_probabilistic_scoring_enabled,
    rank_evidence,
)
from services.shared.schemas import MarketContext, TrendDirection, TrendWindow
@@ -392,3 +396,92 @@ def test_assemble_trend_with_evidence_empty_signals():
    assert result.supporting_evidence == []
    assert result.opposing_evidence == []
    assert result.summary.trend_direction == TrendDirection.NEUTRAL


# ---------------------------------------------------------------------------
# AggregationConfig: probabilistic_scoring_enabled field
# ---------------------------------------------------------------------------


def test_aggregation_config_probabilistic_default_false():
    """probabilistic_scoring_enabled defaults to False (heuristic pipeline)."""
    cfg = AggregationConfig()
    assert cfg.probabilistic_scoring_enabled is False


def test_aggregation_config_probabilistic_explicit_true():
    """probabilistic_scoring_enabled can be set to True."""
    cfg = AggregationConfig(probabilistic_scoring_enabled=True)
    assert cfg.probabilistic_scoring_enabled is True


# ---------------------------------------------------------------------------
# fetch_probabilistic_scoring_enabled: DB toggle reading
# ---------------------------------------------------------------------------


class _FakeRecord(dict):
    """Minimal dict-like object that mimics an asyncpg Record."""
    pass


@pytest.mark.asyncio
async def test_fetch_probabilistic_enabled_true():
    """Returns True when risk_configs has probabilistic_scoring_enabled='true'."""
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(
        return_value=_FakeRecord({"probabilistic_scoring_enabled": "true"}),
    )
    result = await fetch_probabilistic_scoring_enabled(pool)
    assert result is True


@pytest.mark.asyncio
async def test_fetch_probabilistic_enabled_false():
    """Returns False when risk_configs has probabilistic_scoring_enabled='false'."""
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(
        return_value=_FakeRecord({"probabilistic_scoring_enabled": "false"}),
    )
    result = await fetch_probabilistic_scoring_enabled(pool)
    assert result is False


@pytest.mark.asyncio
async def test_fetch_probabilistic_enabled_missing_key():
    """Returns False when the key is missing from config JSONB (value is None)."""
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(
        return_value=_FakeRecord({"probabilistic_scoring_enabled": None}),
    )
    result = await fetch_probabilistic_scoring_enabled(pool)
    assert result is False


@pytest.mark.asyncio
async def test_fetch_probabilistic_enabled_no_config_row():
    """Returns False when no risk_configs row exists."""
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(return_value=None)
    result = await fetch_probabilistic_scoring_enabled(pool)
    assert result is False


@pytest.mark.asyncio
async def test_fetch_probabilistic_enabled_invalid_value():
    """Returns False when the value is not a valid boolean string."""
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(
        return_value=_FakeRecord({"probabilistic_scoring_enabled": "yes"}),
    )
    result = await fetch_probabilistic_scoring_enabled(pool)
    assert result is False


@pytest.mark.asyncio
async def test_fetch_probabilistic_enabled_db_unreachable():
    """Returns False (fail-safe) when the database query raises an exception."""
    pool = AsyncMock()
    pool.fetchrow = AsyncMock(side_effect=Exception("connection refused"))
    result = await fetch_probabilistic_scoring_enabled(pool)
    assert result is False
@@ -0,0 +1,278 @@
"""Unit tests for Bayesian accumulator (services/aggregation/bayesian.py).

Tests uninformative prior, sigmoid gate values, entropy direction mapping,
and core Bayesian posterior computation.

Requirements: 1.1, 1.2, 1.3, 1.4, 1.5
"""
from __future__ import annotations

import pytest

from services.aggregation.bayesian import (
    PRIOR,
    compute_bayesian_posterior,
    compute_entropy,
)
from services.aggregation.scoring import (
    SignalWeight,
    WeightedSignal,
    sigmoid_gate,
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_signal(
    sentiment: float,
    combined_weight: float = 1.0,
    impact: float = 1.0,
) -> WeightedSignal:
    """Create a minimal WeightedSignal for testing."""
    weight = SignalWeight(
        recency=1.0,
        credibility=1.0,
        novelty_bonus=0.0,
        confidence_gate=1.0,
        market_ctx_multiplier=1.0,
        combined=combined_weight,
    )
    return WeightedSignal(
        document_id="test-doc",
        weight=weight,
        sentiment_value=sentiment,
        impact_score=impact,
    )


# ---------------------------------------------------------------------------
# Uninformative prior (empty signals → P_bull=0.5, α=1, β=1, C=0)
# ---------------------------------------------------------------------------


class TestUninformativePrior:
    """Req 1.5: empty signals return the uninformative prior."""

    def test_prior_p_bull(self):
        assert PRIOR.p_bull == 0.5

    def test_prior_alpha(self):
        assert PRIOR.alpha == 1.0

    def test_prior_beta(self):
        assert PRIOR.beta == 1.0

    def test_prior_confidence(self):
        assert PRIOR.bayesian_confidence == 0.0

    def test_prior_entropy(self):
        assert PRIOR.entropy == 1.0

    def test_prior_signal_count(self):
        assert PRIOR.signal_count == 0

    def test_empty_signals_return_prior(self):
        result = compute_bayesian_posterior([])
        assert result == PRIOR

    def test_all_nan_signals_return_prior(self):
        sig = _make_signal(sentiment=float("nan"), combined_weight=1.0)
        result = compute_bayesian_posterior([sig])
        assert result == PRIOR


# ---------------------------------------------------------------------------
# Sigmoid gate specific values (Req 2.1–2.4)
# ---------------------------------------------------------------------------


class TestSigmoidGateValues:
    """Test specific sigmoid gate values from the design doc."""

    def test_midpoint_gives_half(self):
        """x=0.5 → gate=0.5 (sigmoid midpoint)."""
        assert sigmoid_gate(0.5, steepness=5.0, midpoint=0.5) == pytest.approx(0.5)

    def test_low_confidence_well_below_half(self):
        """x=0.2 → gate well below 0.5 (Req 2.3: below 0.2 → below 0.05).

        With default steepness=5.0, σ(5·(0.2-0.5)) = σ(-1.5) ≈ 0.18.
        The gate is significantly below the midpoint value of 0.5.
        For gate < 0.05, steepness would need to be higher or x lower.
        """
        gate = sigmoid_gate(0.2, steepness=5.0, midpoint=0.5)
        assert gate < 0.5
        # With higher steepness (e.g. 10), x=0.2 gives gate < 0.05
        gate_steep = sigmoid_gate(0.2, steepness=10.0, midpoint=0.5)
        assert gate_steep < 0.05

    def test_high_confidence_well_above_half(self):
        """x=0.8 → gate well above 0.5 (Req 2.4: above 0.8 → above 0.95).

        With default steepness=5.0, σ(5·(0.8-0.5)) = σ(1.5) ≈ 0.82.
        For gate > 0.95, steepness would need to be higher or x higher.
        """
        gate = sigmoid_gate(0.8, steepness=5.0, midpoint=0.5)
        assert gate > 0.5
        # With higher steepness (e.g. 10), x=0.8 gives gate > 0.95
        gate_steep = sigmoid_gate(0.8, steepness=10.0, midpoint=0.5)
        assert gate_steep > 0.95

    def test_zero_confidence(self):
        """x=0.0 → gate very close to 0."""
        gate = sigmoid_gate(0.0, steepness=5.0, midpoint=0.5)
        assert gate < 0.1

    def test_full_confidence(self):
        """x=1.0 → gate very close to 1."""
        gate = sigmoid_gate(1.0, steepness=5.0, midpoint=0.5)
        assert gate > 0.9


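The values quoted in the docstrings above follow from the standard logistic function. The sketch below is a stand-in for `sigmoid_gate`, written only from those docstrings (σ(steepness·(x − midpoint))); the function name `sigmoid_gate_sketch` and the default `midpoint=0.5` are assumptions, and the real implementation in services/aggregation/scoring.py may differ.

```python
import math


def sigmoid_gate_sketch(x: float, steepness: float = 5.0, midpoint: float = 0.5) -> float:
    """Logistic gate: 1 / (1 + exp(-steepness * (x - midpoint)))."""
    return 1.0 / (1.0 + math.exp(-steepness * (x - midpoint)))


if __name__ == "__main__":
    # Compare the default steepness with the steeper setting used in the tests.
    for k in (5.0, 10.0):
        low = sigmoid_gate_sketch(0.2, steepness=k)
        high = sigmoid_gate_sketch(0.8, steepness=k)
        print(f"steepness={k}: gate(0.2)={low:.3f}, gate(0.8)={high:.3f}")
```

With steepness 5 the gate at x=0.2 is σ(−1.5) ≈ 0.18, so the 0.05 and 0.95 bounds of Req 2.3 and 2.4 are only reached at a steepness near 10, where σ(−3) ≈ 0.047.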
# ---------------------------------------------------------------------------
# Entropy direction mapping (Req 9.1–9.5)
# ---------------------------------------------------------------------------


class TestEntropyDirectionMapping:
    """Test entropy computation and the direction mapping rules."""

    def test_entropy_at_half_is_one(self):
        """H(0.5) = 1.0 (maximum entropy)."""
        assert compute_entropy(0.5) == pytest.approx(1.0)

    def test_entropy_at_zero_is_zero(self):
        """H(0.0) = 0.0 (edge case)."""
        assert compute_entropy(0.0) == 0.0

    def test_entropy_at_one_is_zero(self):
        """H(1.0) = 0.0 (edge case)."""
        assert compute_entropy(1.0) == 0.0

    def test_entropy_symmetric(self):
        """H(p) = H(1-p) for all p."""
        assert compute_entropy(0.3) == pytest.approx(compute_entropy(0.7))

    def test_high_entropy_implies_mixed(self):
        """H > 0.9 → direction should be 'mixed'.

        When P_bull ≈ 0.5, entropy is near 1.0 → mixed.
        """
        # P_bull = 0.5 → H = 1.0 > 0.9 → mixed
        h = compute_entropy(0.5)
        assert h > 0.9

    def test_bullish_direction(self):
        """P_bull > 0.65 and H ≤ 0.9 → bullish.

        P_bull = 0.75 → H ≈ 0.811 < 0.9 → bullish.
        """
        p_bull = 0.75
        h = compute_entropy(p_bull)
        assert h <= 0.9
        assert p_bull > 0.65

    def test_bearish_direction(self):
        """P_bull < 0.35 and H ≤ 0.9 → bearish.

        P_bull = 0.2 → H ≈ 0.722 < 0.9 → bearish.
        """
        p_bull = 0.2
        h = compute_entropy(p_bull)
        assert h <= 0.9
        assert p_bull < 0.35

    def test_neutral_direction(self):
        """0.35 ≤ P_bull ≤ 0.65 and H ≤ 0.9 → neutral.

        H ≤ 0.9 requires P_bull ≤ ~0.316 or P_bull ≥ ~0.684, while every
        P_bull in [0.35, 0.65] has H ≥ H(0.35) ≈ 0.934. The neutral zone
        (0.35 to 0.65 with H ≤ 0.9) is therefore empty. This is by design:
        high entropy in the neutral zone forces 'mixed' classification.
        """
        # P_bull = 0.35 (like 0.65) is an endpoint of the neutral zone, where
        # entropy is lowest, and it still has H > 0.9 → mixed, not neutral.
        h_at_035 = compute_entropy(0.35)
        assert h_at_035 > 0.9  # confirms mixed, not neutral


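The direction rules exercised above can be collected into one small function. The sketch below uses binary Shannon entropy in bits and the thresholds quoted in the test docstrings (H > 0.9 for mixed, 0.65 and 0.35 for bullish and bearish). The names `binary_entropy` and `classify_direction` are hypothetical, and the ordering of the checks (entropy first) is an assumption about how the real pipeline resolves overlaps.

```python
import math


def binary_entropy(p: float) -> float:
    """Shannon entropy of a Bernoulli(p) variable in bits; H(0) = H(1) = 0."""
    if p <= 0.0 or p >= 1.0:
        return 0.0
    return -(p * math.log2(p) + (1.0 - p) * math.log2(1.0 - p))


def classify_direction(p_bull: float, mixed_threshold: float = 0.9) -> str:
    """Map P_bull to a label using the thresholds quoted in the tests."""
    if binary_entropy(p_bull) > mixed_threshold:
        return "mixed"  # too uncertain to call a direction
    if p_bull > 0.65:
        return "bullish"
    if p_bull < 0.35:
        return "bearish"
    return "neutral"


if __name__ == "__main__":
    for p in (0.2, 0.35, 0.5, 0.75):
        print(p, round(binary_entropy(p), 3), classify_direction(p))
```

Under these thresholds the `"neutral"` branch is unreachable for any P_bull strictly between 0 and 1, because every value in [0.35, 0.65] already has entropy above 0.9.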
# ---------------------------------------------------------------------------
# Bayesian posterior computation
# ---------------------------------------------------------------------------


class TestBayesianPosterior:
    """Test core Bayesian posterior computation."""

    def test_single_bullish_signal(self):
        """One positive signal shifts P_bull above 0.5."""
        sig = _make_signal(sentiment=1.0, combined_weight=1.0)
        result = compute_bayesian_posterior([sig])
        assert result.p_bull > 0.5
        assert result.alpha > 1.0
        assert result.beta == 1.0  # no bearish weight
        assert result.signal_count == 1

    def test_single_bearish_signal(self):
        """One negative signal shifts P_bull below 0.5."""
        sig = _make_signal(sentiment=-1.0, combined_weight=1.0)
        result = compute_bayesian_posterior([sig])
        assert result.p_bull < 0.5
        assert result.alpha == 1.0  # no bullish weight
        assert result.beta > 1.0
        assert result.signal_count == 1

    def test_balanced_signals_near_prior(self):
        """Equal bullish and bearish signals keep P_bull near 0.5."""
        signals = [
            _make_signal(sentiment=1.0, combined_weight=1.0),
            _make_signal(sentiment=-1.0, combined_weight=1.0),
        ]
        result = compute_bayesian_posterior(signals)
        assert result.p_bull == pytest.approx(0.5, abs=0.01)

    def test_confidence_zero_when_balanced(self):
        """Equal α and β → confidence near 0."""
        signals = [
            _make_signal(sentiment=1.0, combined_weight=1.0),
            _make_signal(sentiment=-1.0, combined_weight=1.0),
        ]
        result = compute_bayesian_posterior(signals)
        # α = 2, β = 2 → C = 1 - 4*2*2/(2+2)^2 = 1 - 16/16 = 0
        assert result.bayesian_confidence == pytest.approx(0.0, abs=0.01)

    def test_confidence_increases_with_agreement(self):
        """More agreeing signals → higher confidence."""
        one_sig = compute_bayesian_posterior([
            _make_signal(sentiment=1.0, combined_weight=1.0),
        ])
        three_sigs = compute_bayesian_posterior([
            _make_signal(sentiment=1.0, combined_weight=1.0),
            _make_signal(sentiment=1.0, combined_weight=1.0),
            _make_signal(sentiment=1.0, combined_weight=1.0),
        ])
        assert three_sigs.bayesian_confidence > one_sig.bayesian_confidence

    def test_nan_weight_signal_skipped(self):
        """Signals with NaN weight are skipped."""
        signals = [
            _make_signal(sentiment=1.0, combined_weight=float("nan")),
            _make_signal(sentiment=1.0, combined_weight=1.0),
        ]
        result = compute_bayesian_posterior(signals)
        assert result.signal_count == 1

    def test_entropy_decreases_with_strong_evidence(self):
        """Strong bullish evidence → low entropy."""
        signals = [
            _make_signal(sentiment=1.0, combined_weight=3.0),
            _make_signal(sentiment=1.0, combined_weight=3.0),
        ]
        result = compute_bayesian_posterior(signals)
        assert result.entropy < 0.5  # strong evidence → low entropy
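The α/β bookkeeping and the confidence formula C = 1 − 4αβ/(α+β)² used in the comments above can be illustrated with a weighted pseudo-count update on a Beta(1, 1) prior. This is a simplified sketch, not the accumulator in services/aggregation/bayesian.py: the commit describes a log-likelihood accumulator, so the real `p_bull` and entropy are likely computed differently, and this sketch is not expected to reproduce every threshold in the tests. The class `BetaPosteriorSketch`, the function `update`, the `(sentiment, weight)` tuple input, and the use of the posterior mean for `p_bull` are all assumptions.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class BetaPosteriorSketch:
    alpha: float = 1.0  # prior pseudo-count for "bullish"
    beta: float = 1.0   # prior pseudo-count for "bearish"
    signal_count: int = 0

    @property
    def p_bull(self) -> float:
        # Posterior mean of Beta(alpha, beta); an assumption for this sketch.
        return self.alpha / (self.alpha + self.beta)

    @property
    def confidence(self) -> float:
        # C = 1 - 4*alpha*beta / (alpha + beta)^2, as in the test comment.
        total = self.alpha + self.beta
        return 1.0 - 4.0 * self.alpha * self.beta / (total * total)


def update(signals) -> BetaPosteriorSketch:
    """signals: iterable of (sentiment, weight) pairs; NaN entries are skipped."""
    alpha, beta, count = 1.0, 1.0, 0
    for sentiment, weight in signals:
        if math.isnan(sentiment) or math.isnan(weight):
            continue
        if sentiment > 0:
            alpha += weight * sentiment       # bullish evidence
        elif sentiment < 0:
            beta += weight * -sentiment       # bearish evidence
        count += 1
    return BetaPosteriorSketch(alpha, beta, count)


if __name__ == "__main__":
    balanced = update([(1.0, 1.0), (-1.0, 1.0)])
    print(balanced.alpha, balanced.beta, balanced.confidence)
```

An empty or all-NaN input leaves the prior untouched, which matches the behaviour the uninformative-prior tests expect from the real function.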
@@ -506,3 +506,360 @@ class TestAcceleratedDecay:
    def test_standard_decay_positive(self):
        result = compute_standard_recency_decay(168.0)
        assert 0.0 < result < 1.0


# ---------------------------------------------------------------------------
# Multiplicative macro exposure formula (Task 10.1, Requirements: 10.1–10.6)
# ---------------------------------------------------------------------------

from services.aggregation.interpolation import (
    _compute_linear_exposure,
    _compute_multiplicative_exposure,
    compute_conditional_macro_modifier,
    integrate_macro_signals,
)


class TestMultiplicativeExposure:
    """Tests for the multiplicative compounding exposure formula."""

    def test_zero_overlap_returns_zero(self):
        """All overlaps zero → exposure = 0."""
        assert _compute_multiplicative_exposure(0.0, 0.0, 0.0, 0.0) == 0.0

    def test_max_overlap_approx_0724(self):
        """All overlaps 1.0 → exposure ≈ 0.689 (from the multiplicative formula)."""
        result = _compute_multiplicative_exposure(1.0, 1.0, 1.0, 1.0)
        expected = 1.0 - (1 - 0.35) * (1 - 0.25) * (1 - 0.25) * (1 - 0.15)
        assert math.isclose(result, expected, abs_tol=1e-6)
        # Requirement 10.4 states ≈0.724 but the exact formula yields ≈0.689
        assert 0.6 < result < 0.8

    def test_single_dimension_equals_weight(self):
        """Only geo overlap at 1.0 → exposure = 0.35."""
        result = _compute_multiplicative_exposure(1.0, 0.0, 0.0, 0.0)
        assert math.isclose(result, 0.35, abs_tol=1e-6)

    def test_multiplicative_differs_from_linear_for_multi_overlap(self):
        """Multiplicative and linear produce different results for multi-dimension overlap."""
        geo, supply, commodity, sector = 0.8, 0.6, 0.5, 0.4
        mult = _compute_multiplicative_exposure(geo, supply, commodity, sector)
        lin = _compute_linear_exposure(geo, supply, commodity, sector)
        # They should produce different values (multiplicative compounds)
        assert mult != lin
        # Both should be positive
        assert mult > 0.0
        assert lin > 0.0

    def test_adding_overlap_increases_score(self):
        """Adding a non-zero overlap in any dimension increases the total."""
        base = _compute_multiplicative_exposure(0.5, 0.0, 0.0, 0.0)
        with_supply = _compute_multiplicative_exposure(0.5, 0.3, 0.0, 0.0)
        assert with_supply > base

    def test_probabilistic_flag_uses_multiplicative(self):
        """compute_macro_impact with probabilistic=True uses multiplicative formula."""
        event = GlobalEvent(
            event_id="evt-mult",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-mult",
            geographic_revenue_mix={"US": 0.8},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        heuristic = compute_macro_impact(event, profile, probabilistic=False)
        probabilistic_result = compute_macro_impact(event, profile, probabilistic=True)
        # Both should produce positive scores
        assert heuristic.macro_impact_score > 0.0
        assert probabilistic_result.macro_impact_score > 0.0
        # They should produce different scores (different formulas)
        assert heuristic.macro_impact_score != probabilistic_result.macro_impact_score

    def test_probabilistic_false_preserves_linear(self):
        """probabilistic=False produces identical results to original behavior."""
        event = GlobalEvent(
            event_id="evt-lin",
            event_types=["supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-lin",
            geographic_revenue_mix={"US": 0.5},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        record = compute_macro_impact(event, profile, probabilistic=False)
        # Manually compute expected linear score
        geo = 0.5  # revenue mix for US
        supply = 1.0  # 1/1 supply regions match
        commodity = 1.0  # crude_oil matches
        severity = 0.75  # high
        expected_raw = severity * (0.35 * geo + 0.25 * supply + 0.25 * commodity + 0.15 * 0.0)
        # Single region → no resilience modifier
        assert math.isclose(record.macro_impact_score, expected_raw, abs_tol=1e-4)

    def test_zero_overlap_returns_zero_score_probabilistic(self):
        """Zero overlap still returns zero in probabilistic mode."""
        event = GlobalEvent(
            event_id="evt-zero",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["JP"],
            affected_commodities=["gold"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-zero",
            geographic_revenue_mix={"US": 1.0},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        record = compute_macro_impact(event, profile, probabilistic=True)
        assert record.macro_impact_score == 0.0

    def test_with_sector_probabilistic(self):
        """compute_macro_impact_with_sector supports probabilistic flag."""
        event = GlobalEvent(
            event_id="evt-sec-prob",
            event_types=["supply_disruption"],
            severity="high",
            affected_regions=["US"],
            affected_sectors=["Energy"],
            confidence=0.9,
        )
        profile = ExposureProfileSchema(
            company_id="comp-sec-prob",
            geographic_revenue_mix={"US": 0.5},
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        heuristic = compute_macro_impact_with_sector(
            event, profile, "Energy", probabilistic=False,
        )
        probabilistic = compute_macro_impact_with_sector(
            event, profile, "Energy", probabilistic=True,
        )
        assert heuristic.macro_impact_score > 0.0
        assert probabilistic.macro_impact_score > 0.0

    def test_severity_preserved_in_probabilistic(self):
        """Severity mapping is preserved in probabilistic mode."""
        profile = ExposureProfileSchema(
            company_id="comp-sev",
            geographic_revenue_mix={"US": 0.5},
            supply_chain_regions=["US"],
            key_input_commodities=["crude_oil"],
            market_position_tier=MarketPositionTier.REGIONAL,
        )
        event_low = GlobalEvent(
            event_id="evt-low-p",
            event_types=["supply_disruption"],
            severity="low",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        event_crit = GlobalEvent(
            event_id="evt-crit-p",
            event_types=["supply_disruption"],
            severity="critical",
            affected_regions=["US"],
            affected_commodities=["crude_oil"],
            confidence=0.9,
        )
        low = compute_macro_impact(event_low, profile, probabilistic=True)
        crit = compute_macro_impact(event_crit, profile, probabilistic=True)
        assert crit.macro_impact_score >= low.macro_impact_score


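The expected value in `test_max_overlap_approx_0724` and the single-dimension case suggest a compounding form 1 − Π(1 − wᵢ·oᵢ), compared against a linear weighted sum Σ wᵢ·oᵢ. The sketch below writes both out with the weights that appear in the tests (0.35, 0.25, 0.25, 0.15). The general formula for partial overlaps is an inference from those two data points, and the function names are hypothetical, so the private helpers in services/aggregation/interpolation.py may differ.

```python
import math

# Dimension weights taken from the expected values in the tests:
# geography, supply chain, commodity, sector.
WEIGHTS = (0.35, 0.25, 0.25, 0.15)


def linear_exposure(*overlaps: float) -> float:
    """Heuristic form: weighted sum of the per-dimension overlaps."""
    return sum(w * o for w, o in zip(WEIGHTS, overlaps))


def multiplicative_exposure(*overlaps: float) -> float:
    """Compounding form: 1 minus the product of per-dimension 'miss' factors."""
    return 1.0 - math.prod(1.0 - w * o for w, o in zip(WEIGHTS, overlaps))


if __name__ == "__main__":
    sample = (0.8, 0.6, 0.5, 0.4)
    print("linear:", round(linear_exposure(*sample), 3))
    print("multiplicative:", round(multiplicative_exposure(*sample), 3))
    print("all ones:", round(multiplicative_exposure(1.0, 1.0, 1.0, 1.0), 3))
```

For the sample overlaps the linear sum is 0.615 while the compounded value is roughly 0.497. With all overlaps at 1.0 the compounded value is about 0.689, which is the figure the test comment contrasts with the 0.724 stated in Requirement 10.4.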
# ---------------------------------------------------------------------------
# Conditional macro signal integration (Task 10.2, Requirements: 11.1–11.5)
# ---------------------------------------------------------------------------


class TestConditionalMacroModifier:
    """Tests for compute_conditional_macro_modifier."""

    def test_agreeing_directions_amplify(self):
        """Bullish company + positive macro → modifier > 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bullish",
            macro_impact=0.3,
            macro_direction="positive",
        )
        assert modifier > 1.0
        assert math.isclose(modifier, 1.3, abs_tol=1e-6)

    def test_disagreeing_directions_dampen(self):
        """Bullish company + negative macro → modifier < 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bullish",
            macro_impact=0.3,
            macro_direction="negative",
        )
        assert modifier < 1.0
        assert math.isclose(modifier, 0.7, abs_tol=1e-6)

    def test_neutral_company_no_alignment(self):
        """Neutral company direction → modifier = 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="neutral",
            macro_impact=0.5,
            macro_direction="positive",
        )
        assert math.isclose(modifier, 1.0, abs_tol=1e-6)

    def test_neutral_macro_no_alignment(self):
        """Neutral macro direction → modifier = 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bullish",
            macro_impact=0.5,
            macro_direction="neutral",
        )
        assert math.isclose(modifier, 1.0, abs_tol=1e-6)

    def test_clamped_to_max_1_5(self):
        """Large agreeing impact clamped to 1.5."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bearish",
            macro_impact=0.8,
            macro_direction="negative",
        )
        assert modifier <= 1.5

    def test_clamped_to_min_0_5(self):
        """Large disagreeing impact clamped to 0.5."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bearish",
            macro_impact=0.8,
            macro_direction="positive",
        )
        assert modifier >= 0.5

    def test_zero_macro_impact_no_change(self):
        """Zero macro impact → modifier = 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bullish",
            macro_impact=0.0,
            macro_direction="positive",
        )
        assert math.isclose(modifier, 1.0, abs_tol=1e-6)

    def test_bearish_negative_agree(self):
        """Bearish company + negative macro → they agree → modifier > 1.0."""
        modifier = compute_conditional_macro_modifier(
            company_strength=0.5,
            company_direction="bearish",
            macro_impact=0.2,
            macro_direction="negative",
        )
        assert modifier > 1.0


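The expected values above (1.3 when directions agree at impact 0.3, 0.7 when they disagree, 1.0 when either side is neutral, and bounds of 0.5 and 1.5) are consistent with a clamped modifier of the form 1 + alignment × macro_impact. The sketch below implements exactly that reading. It is an inference from the tests rather than the source of `compute_conditional_macro_modifier`; the name `conditional_modifier_sketch` is hypothetical, and the `company_strength` argument is left out because the tests do not show how it affects the result.

```python
def conditional_modifier_sketch(
    company_direction: str,
    macro_impact: float,
    macro_direction: str,
) -> float:
    """Amplify when company and macro directions agree, dampen when they clash."""
    company_sign = {"bullish": 1, "bearish": -1}.get(company_direction, 0)
    macro_sign = {"positive": 1, "negative": -1}.get(macro_direction, 0)
    alignment = company_sign * macro_sign  # +1 agree, -1 disagree, 0 neutral
    raw = 1.0 + alignment * macro_impact
    return max(0.5, min(1.5, raw))  # clamp to [0.5, 1.5]


if __name__ == "__main__":
    print(conditional_modifier_sketch("bullish", 0.3, "positive"))
    print(conditional_modifier_sketch("bullish", 0.3, "negative"))
    print(conditional_modifier_sketch("bearish", 0.8, "negative"))
```

Multiplying the two signs gives the agreement test for free: bearish company news and a negative macro event agree just as bullish and positive do.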
class TestIntegrateMacroSignals:
    """Tests for integrate_macro_signals."""

    def _make_signal(self, doc_id: str, sentiment: float, impact: float):
        """Helper to create a minimal WeightedSignal-like object."""
        from services.aggregation.scoring import SignalWeight, WeightedSignal
        weight = SignalWeight(
            recency=1.0,
            credibility=0.8,
            novelty_bonus=0.0,
            confidence_gate=1.0,
            market_ctx_multiplier=1.0,
            combined=0.8,
        )
        return WeightedSignal(
            document_id=doc_id,
            weight=weight,
            sentiment_value=sentiment,
            impact_score=impact,
        )

    def _make_macro_impact(self, score: float, direction: str):
        """Helper to create a MacroImpactRecord."""
        return MacroImpactRecord(
            event_id="evt-1",
            company_id="comp-1",
            macro_impact_score=score,
            impact_direction=direction,
        )

    def test_heuristic_mode_concatenates(self):
        """probabilistic=False → simple concatenation."""
        company = [self._make_signal("c1", 0.5, 0.6)]
        macro = [self._make_signal("m1", 0.3, 0.4)]
        merged, modifier = integrate_macro_signals(
            company, macro, "bullish", [], probabilistic=False,
        )
        assert len(merged) == 2
        assert modifier == 1.0

    def test_probabilistic_both_exist_applies_modifier(self):
        """Both company and macro → modifier applied to company signals."""
        company = [self._make_signal("c1", 0.5, 0.6)]
        macro = [self._make_signal("m1", 0.3, 0.4)]
        impacts = [self._make_macro_impact(0.3, "positive")]
        merged, modifier = integrate_macro_signals(
            company, macro, "bullish", impacts,
            ticker="AAPL", probabilistic=True,
        )
        # Modifier should be > 1.0 (agreeing directions)
        assert modifier > 1.0
        # Only company signals returned (modified), not macro
        assert len(merged) == 1
        # Impact score should be scaled by modifier
        assert merged[0].impact_score > 0.6

    def test_probabilistic_macro_only_fallback(self):
        """Only macro signals → additive fallback."""
        macro = [self._make_signal("m1", 0.3, 0.4)]
        impacts = [self._make_macro_impact(0.3, "positive")]
        merged, modifier = integrate_macro_signals(
            [], macro, "neutral", impacts,
            ticker="AAPL", probabilistic=True,
        )
        assert len(merged) == 1
        assert modifier == 1.0

    def test_probabilistic_company_only_no_modifier(self):
        """Only company signals → modifier = 1.0."""
        company = [self._make_signal("c1", 0.5, 0.6)]
        merged, modifier = integrate_macro_signals(
            company, [], "bullish", [],
            ticker="AAPL", probabilistic=True,
        )
        assert len(merged) == 1
        assert modifier == 1.0
        assert merged[0].impact_score == 0.6

    def test_probabilistic_disagreeing_dampens(self):
        """Disagreeing directions → modifier < 1.0, impact reduced."""
        company = [self._make_signal("c1", 0.5, 0.6)]
        macro = [self._make_signal("m1", -0.3, 0.4)]
        impacts = [self._make_macro_impact(0.3, "negative")]
        merged, modifier = integrate_macro_signals(
            company, macro, "bullish", impacts,
            ticker="AAPL", probabilistic=True,
        )
        assert modifier < 1.0
        assert merged[0].impact_score < 0.6

File diff suppressed because it is too large
@@ -0,0 +1,237 @@
"""Unit tests for regime detector (services/aggregation/regime.py).

Tests specific (R, V_r) → regime classification, threshold adjustments
per regime, and insufficient data fallback to uncertainty.

Requirements: 7.1, 7.2, 7.3, 7.4, 7.5, 7.6, 7.7, 7.9
"""
from __future__ import annotations

import pytest

from services.aggregation.regime import (
    MarketRegime,
    classify_regime,
    compute_ema,
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_uptrend_prices(n: int = 120) -> list[float]:
    """Generate prices with EMA_20 > EMA_100 (uptrend, R=+1)."""
    # Start low, end high: recent prices much higher than old ones
    return [100.0 + i * 0.5 for i in range(n)]


def _make_downtrend_prices(n: int = 120) -> list[float]:
    """Generate prices with EMA_20 < EMA_100 (downtrend, R=-1)."""
    # Start high, end low: recent prices much lower than old ones
    return [200.0 - i * 0.5 for i in range(n)]


def _make_flat_prices(n: int = 120) -> list[float]:
    """Generate flat prices where EMA_20 ≈ EMA_100 (R=0)."""
    return [100.0] * n


def _make_low_vol_returns(n: int = 120) -> list[float]:
    """Generate returns with σ_20 / σ_100 < 1.0 (low recent volatility)."""
    # First 100 returns have higher variance, last 20 have lower variance
    base = [0.02 * ((-1) ** i) for i in range(n - 20)]
    recent = [0.005 * ((-1) ** i) for i in range(20)]
    return base + recent


def _make_high_vol_returns(n: int = 120) -> list[float]:
    """Generate returns with σ_20 / σ_100 > 1.5 (panic volatility)."""
    # First 100 returns have low variance, last 20 have very high variance
    base = [0.005 * ((-1) ** i) for i in range(n - 20)]
    recent = [0.08 * ((-1) ** i) for i in range(20)]
    return base + recent


def _make_moderate_vol_returns(n: int = 120) -> list[float]:
    """Generate returns with V_r between 1.0 and 1.2."""
    # Slightly higher recent volatility
    base = [0.01 * ((-1) ** i) for i in range(n - 20)]
    recent = [0.012 * ((-1) ** i) for i in range(20)]
    return base + recent


# ---------------------------------------------------------------------------
# compute_ema
# ---------------------------------------------------------------------------


class TestComputeEma:
    """Test EMA computation."""

    def test_single_value(self):
        assert compute_ema([100.0], 1) == pytest.approx(100.0)

    def test_constant_values(self):
        """EMA of constant values equals that constant."""
        assert compute_ema([50.0] * 20, 20) == pytest.approx(50.0)

    def test_empty_raises(self):
        with pytest.raises(ValueError):
            compute_ema([], 10)

    def test_zero_period_raises(self):
        with pytest.raises(ValueError):
            compute_ema([1.0, 2.0], 0)

||||
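The four EMA tests above fix only a handful of properties: a single value or a constant series is returned unchanged, and empty input or a non-positive period raises `ValueError`. A minimal sketch that satisfies them is shown below. It is not the code in `services/aggregation/regime.py`; the name `ema_sketch`, the smoothing factor 2/(period+1), and seeding with the first value are all assumptions.

```python
def ema_sketch(values: list[float], period: int) -> float:
    """Return the final EMA value (hypothetical stand-in for compute_ema)."""
    if not values:
        raise ValueError("values must not be empty")
    if period <= 0:
        raise ValueError("period must be positive")
    alpha = 2.0 / (period + 1)  # assumed: conventional EMA smoothing factor
    ema = values[0]             # assumed: seed with the first observation
    for value in values[1:]:
        ema = alpha * value + (1.0 - alpha) * ema
    return ema
```

Under these assumptions a steadily rising series such as the one from `_make_uptrend_prices` gives a 20-period EMA above the 100-period EMA, which is the R=+1 condition the helper docstrings describe.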
# ---------------------------------------------------------------------------
# Regime classification: specific (R, V_r) → expected regime
# ---------------------------------------------------------------------------


class TestRegimeClassification:
    """Test specific (R, V_r) → expected regime classification (Req 7.3)."""

    def test_trend_following_uptrend(self):
        """R=+1, V_r < 1.2 → trend_following."""
        prices = _make_uptrend_prices()
        returns = _make_moderate_vol_returns()
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.TREND_FOLLOWING
        assert result.trend_indicator == 1.0

    def test_trend_following_downtrend(self):
        """R=-1, V_r < 1.2 → trend_following."""
        prices = _make_downtrend_prices()
        returns = _make_moderate_vol_returns()
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.TREND_FOLLOWING
        assert result.trend_indicator == -1.0

    def test_panic_regime(self):
        """V_r > 1.5 → panic (regardless of R)."""
        prices = _make_uptrend_prices()
        returns = _make_high_vol_returns()
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.PANIC

    def test_mean_reversion_regime(self):
        """R=0, V_r < 1.0 → mean_reversion."""
        prices = _make_flat_prices()
        returns = _make_low_vol_returns()
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.MEAN_REVERSION

    def test_uncertainty_regime(self):
        """R=0, V_r between 1.0 and 1.5 → uncertainty."""
        prices = _make_flat_prices()
        # Moderate-vol returns give V_r ≈ 1.1, inside the 1.0 to 1.5 band
        returns = _make_moderate_vol_returns()
        result = classify_regime(prices, returns)
        # Flat prices give R=0; with V_r >= 1.0 this is uncertainty,
        # not mean reversion (which needs V_r < 1.0)
        assert result.regime == MarketRegime.UNCERTAINTY


# ---------------------------------------------------------------------------
# Threshold adjustments per regime (Req 7.4, 7.5, 7.6, 7.7)
# ---------------------------------------------------------------------------


class TestRegimeThresholds:
    """Test threshold adjustments per regime."""

    def test_panic_threshold(self):
        """Panic regime → threshold ±0.10 (Req 7.4)."""
        prices = _make_uptrend_prices()
        returns = _make_high_vol_returns()
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.PANIC
        assert result.bullish_threshold == pytest.approx(0.10)
        assert result.bearish_threshold == pytest.approx(-0.10)

    def test_mean_reversion_threshold(self):
        """Mean-reversion regime → threshold ±0.20 (Req 7.5)."""
        prices = _make_flat_prices()
        returns = _make_low_vol_returns()
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.MEAN_REVERSION
        assert result.bullish_threshold == pytest.approx(0.20)
        assert result.bearish_threshold == pytest.approx(-0.20)

    def test_trend_following_threshold(self):
        """Trend-following regime → threshold ±0.15 (Req 7.6)."""
        prices = _make_uptrend_prices()
        returns = _make_moderate_vol_returns()
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.TREND_FOLLOWING
        assert result.bullish_threshold == pytest.approx(0.15)
        assert result.bearish_threshold == pytest.approx(-0.15)

    def test_uncertainty_contradiction_multiplier(self):
        """Uncertainty regime → contradiction multiplier 0.6 (Req 7.7)."""
        prices = _make_flat_prices()
        returns = _make_moderate_vol_returns()
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.UNCERTAINTY
        assert result.contradiction_penalty_multiplier == pytest.approx(0.6)

    def test_non_uncertainty_contradiction_multiplier(self):
        """Non-uncertainty regimes → contradiction multiplier 0.4."""
        prices = _make_uptrend_prices()
        returns = _make_moderate_vol_returns()
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.TREND_FOLLOWING
        assert result.contradiction_penalty_multiplier == pytest.approx(0.4)

||||
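Taken together, the classification and threshold tests above imply a small decision table over the trend indicator R and the volatility ratio V_r. The sketch below restates that table as code. It is an illustration only: `RegimeSketch` and `classify_from_indicators` are hypothetical names, the real `classify_regime` takes price and return series rather than precomputed indicators, and the handling of R ≠ 0 with 1.2 ≤ V_r ≤ 1.5 is an assumption because no test covers that band.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RegimeSketch:
    """Hypothetical result record mirroring the fields the tests read."""
    regime: str
    bullish_threshold: float
    bearish_threshold: float
    contradiction_penalty_multiplier: float


def classify_from_indicators(trend: float, vol_ratio: float) -> RegimeSketch:
    """Map (R, V_r) to a regime using the cut-offs quoted in the tests."""
    if vol_ratio > 1.5:
        # Panic applies regardless of trend; thresholds are ±0.10.
        return RegimeSketch("panic", 0.10, -0.10, 0.4)
    if trend != 0.0 and vol_ratio < 1.2:
        return RegimeSketch("trend_following", 0.15, -0.15, 0.4)
    if trend == 0.0 and vol_ratio < 1.0:
        return RegimeSketch("mean_reversion", 0.20, -0.20, 0.4)
    # Assumption: every remaining case, including the untested band
    # R != 0 with 1.2 <= V_r <= 1.5, falls back to uncertainty.
    return RegimeSketch("uncertainty", 0.15, -0.15, 0.6)
```

Checking panic first keeps the "regardless of R" behaviour explicit, and the final fallback matches the default values asserted for insufficient data.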
# ---------------------------------------------------------------------------
# Insufficient data fallback to uncertainty (Req 7.9)
# ---------------------------------------------------------------------------


class TestInsufficientDataFallback:
    """Test fallback to uncertainty when data is insufficient."""

    def test_too_few_prices(self):
        """Fewer than 100 closing prices → uncertainty."""
        prices = [100.0] * 50  # only 50 days
        returns = [0.01] * 100
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.UNCERTAINTY

    def test_too_few_returns(self):
        """Fewer than 100 returns → uncertainty."""
        prices = [100.0] * 120
        returns = [0.01] * 50  # only 50 returns
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.UNCERTAINTY

    def test_empty_prices(self):
        """Empty price list → uncertainty."""
        result = classify_regime([], [0.01] * 100)
        assert result.regime == MarketRegime.UNCERTAINTY

    def test_empty_returns(self):
        """Empty return list → uncertainty."""
        result = classify_regime([100.0] * 120, [])
        assert result.regime == MarketRegime.UNCERTAINTY

    def test_zero_sigma_returns_uncertainty(self):
        """All identical returns (σ=0) → uncertainty."""
        prices = _make_uptrend_prices()
        returns = [0.0] * 120  # zero standard deviation
        result = classify_regime(prices, returns)
        assert result.regime == MarketRegime.UNCERTAINTY

    def test_default_uncertainty_values(self):
        """Default uncertainty has standard threshold values."""
        result = classify_regime([], [])
        assert result.regime == MarketRegime.UNCERTAINTY
        assert result.bullish_threshold == pytest.approx(0.15)
        assert result.bearish_threshold == pytest.approx(-0.15)
        assert result.contradiction_penalty_multiplier == pytest.approx(0.6)
        assert result.trend_indicator == 0.0
        assert result.volatility_ratio == 1.0
@@ -0,0 +1,535 @@
"""Unit tests for signal scoring upgrades and pipeline-wide behaviors.

Tests information gain, adaptive decay, macro exposure, macro integration,
graph distance, momentum, EV gate, and feature flag behaviors.

Requirements: 3.1, 3.4, 5.5, 5.6, 10.3, 10.4, 11.3, 13.3, 14.3, 14.4, 16.4, 16.5
"""
from __future__ import annotations

import math
from datetime import datetime, timezone

import pytest

from services.aggregation.interpolation import (
    _compute_multiplicative_exposure,
    integrate_macro_signals,
)
from services.aggregation.projection import (
    compute_ew_momentum,
    compute_trend_momentum,
)
from services.aggregation.scoring import (
    DEFAULT_BASE_RATE,
    ScoringConfig,
    SignalWeight,
    WeightedSignal,
    compute_adaptive_half_life,
    compute_info_gain,
    compute_regime_multiplier,
    compute_signal_weight,
)
from services.aggregation.signal_propagation import (
    compute_graph_distance_attenuation,
)
from services.recommendation.eligibility import (
    compute_expected_value,
    evaluate_eligibility,
)
from services.shared.schemas import (
    RecommendationMode,
    TrendDirection,
    TrendSummary,
    TrendWindow,
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_trend_summary(**overrides) -> TrendSummary:
    """Create a minimal TrendSummary for testing."""
    defaults = {
        "entity_id": "test-company",
        "ticker": "TEST",
        "window": TrendWindow.SEVEN_DAY,
        "trend_direction": TrendDirection.BULLISH,
        "trend_strength": 0.5,
        "confidence": 0.6,
        "contradiction_score": 0.1,
        "signal_count": 5,
        "unique_source_count": 3,
        "weighted_sentiment_avg": 0.4,
        "top_supporting_evidence": ["doc-1", "doc-2"],
        "top_opposing_evidence": ["doc-3"],
        "material_risks": [],
    }
    defaults.update(overrides)
    return TrendSummary(**defaults)


def _make_signal(
    sentiment: float,
    combined_weight: float = 1.0,
    impact: float = 1.0,
) -> WeightedSignal:
    """Create a minimal WeightedSignal for testing."""
    weight = SignalWeight(
        recency=1.0,
        credibility=1.0,
        novelty_bonus=0.0,
        confidence_gate=1.0,
        market_ctx_multiplier=1.0,
        combined=combined_weight,
    )
    return WeightedSignal(
        document_id="test-doc",
        weight=weight,
        sentiment_value=sentiment,
        impact_score=impact,
    )


# ---------------------------------------------------------------------------
# Information gain clamp (Req 3.4)
# ---------------------------------------------------------------------------


class TestInfoGainClamp:
    """Test info gain clamp: very rare event → factor ≤ 3.0."""

    def test_very_rare_event_clamped(self):
        """An event with extremely low base rate is clamped to 3.0."""
        # base_rate = 0.001 → -log₂(0.001) ≈ 9.97 → r = 1 + 0.3*9.97 ≈ 3.99
        # Should be clamped to 3.0
        result = compute_info_gain("unknown_type", lambda_param=0.3, max_gain=3.0, default_base_rate=0.001)
        assert result <= 3.0

    def test_m_and_a_high_gain(self):
        """M&A (base_rate=0.03) produces a high gain that stays within the 3.0 clamp."""
        result = compute_info_gain("m_and_a")
        assert result > 1.0
        assert result <= 3.0

    def test_earnings_low_gain(self):
        """Earnings (base_rate=0.25) produces modest gain."""
        result = compute_info_gain("earnings")
        assert result >= 1.0
        assert result < 2.0

    def test_none_event_type_returns_one(self):
        """None event type returns neutral factor 1.0."""
        assert compute_info_gain(None) == 1.0


# ---------------------------------------------------------------------------
# Default base rate (Req 3.2)
# ---------------------------------------------------------------------------


class TestDefaultBaseRate:
    """Test default base rate: unknown event type → 0.1."""

    def test_unknown_event_uses_default(self):
        """Unknown event type uses DEFAULT_BASE_RATE = 0.1."""
        result = compute_info_gain("completely_unknown_event")
        expected = 1.0 + 0.3 * (-math.log2(DEFAULT_BASE_RATE))
        assert result == pytest.approx(min(expected, 3.0), abs=0.01)

    def test_default_base_rate_value(self):
        """DEFAULT_BASE_RATE is 0.1."""
        assert DEFAULT_BASE_RATE == 0.1

||||
# ---------------------------------------------------------------------------
# Adaptive decay edge cases (Req 5.5, 5.6)
# ---------------------------------------------------------------------------


class TestAdaptiveDecayEdgeCases:
    """Test adaptive decay: all zeros → τ_base, all max → 6×τ_base."""

    def test_all_zeros_gives_base(self):
        """All β factors zero → τ_i = τ_base (Req 5.6)."""
        config = ScoringConfig(probabilistic=True)
        result = compute_adaptive_half_life(
            base_half_life=72.0,
            impact_score=0.0,
            info_gain_factor=1.0,  # r=1 → β_surprise=0
            market_multiplier=1.0,  # M=1 → β_market=0
            config=config,
        )
        assert result == pytest.approx(72.0)

    def test_all_max_gives_six_times_base(self):
        """All β factors at max → τ_i ≈ 6×τ_base (Req 5.5).

        β_impact = 1.0 * 1.0 = 1.0
        β_surprise = ((3.0 - 1.0) / 2.0) * 1.0 = 1.0
        β_market = ((1.45 - 1.0) / 0.45) * 0.5 = 0.5
        τ = 72 * (1+1) * (1+1) * (1+0.5) = 72 * 2 * 2 * 1.5 = 432 = 6 * 72
        """
        config = ScoringConfig(
            probabilistic=True,
            adaptive_decay_impact_scale=1.0,
            adaptive_decay_surprise_scale=1.0,
            adaptive_decay_market_scale=0.5,
        )
        result = compute_adaptive_half_life(
            base_half_life=72.0,
            impact_score=1.0,
            info_gain_factor=3.0,
            market_multiplier=1.45,
            config=config,
        )
        assert result == pytest.approx(72.0 * 6.0, rel=0.01)

    def test_adaptive_never_below_base(self):
        """Adaptive half-life is always >= base (Property 5)."""
        config = ScoringConfig(probabilistic=True)
        result = compute_adaptive_half_life(
            base_half_life=72.0,
            impact_score=0.5,
            info_gain_factor=2.0,
            market_multiplier=1.2,
            config=config,
        )
        assert result >= 72.0

||||
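The arithmetic in the `test_all_max_gives_six_times_base` docstring can be written as a standalone function: τ = τ_base·(1 + β_impact)·(1 + β_surprise)·(1 + β_market). The sketch below follows that docstring term by term. The function name is hypothetical, the scale defaults are simply the values the test passes in, and clamping each β at zero is an assumption made so the result never drops below the base half-life.

```python
def adaptive_half_life_sketch(
    base_half_life: float,
    impact_score: float,
    info_gain_factor: float,
    market_multiplier: float,
    impact_scale: float = 1.0,    # value used in the test config
    surprise_scale: float = 1.0,  # value used in the test config
    market_scale: float = 0.5,    # value used in the test config
) -> float:
    """Stretch the half-life for impactful, surprising, market-confirmed signals."""
    beta_impact = max(0.0, impact_score) * impact_scale
    # (r - 1) / 2 maps the info-gain range [1, 3] onto [0, 1].
    beta_surprise = max(0.0, (info_gain_factor - 1.0) / 2.0) * surprise_scale
    # (M - 1) / 0.45 maps the market multiplier range [1, 1.45] onto [0, 1].
    beta_market = max(0.0, (market_multiplier - 1.0) / 0.45) * market_scale
    return base_half_life * (1 + beta_impact) * (1 + beta_surprise) * (1 + beta_market)
```

With every input at its neutral value the three β terms vanish and the base half-life comes back unchanged; at the maxima the product is 2 × 2 × 1.5 = 6 times the base.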
# ---------------------------------------------------------------------------
# Zero overlap → zero macro impact (Req 10.3)
# ---------------------------------------------------------------------------


class TestZeroOverlapMacro:
    """Test zero overlap → zero macro impact."""

    def test_all_zero_overlaps(self):
        """All overlaps zero → exposure = 0.0."""
        exposure = _compute_multiplicative_exposure(0.0, 0.0, 0.0, 0.0)
        assert exposure == 0.0


# ---------------------------------------------------------------------------
# Max overlap → ≈severity×0.689 (Req 10.4)
# ---------------------------------------------------------------------------


class TestMaxOverlapMacro:
    """Test max overlap → ≈severity×0.689."""

    def test_all_max_overlaps(self):
        """All overlaps 1.0 → exposure ≈ 0.689.

        1 - (1-0.35)*(1-0.25)*(1-0.25)*(1-0.15) = 1 - 0.65*0.75*0.75*0.85 ≈ 0.689
        """
        exposure = _compute_multiplicative_exposure(1.0, 1.0, 1.0, 1.0)
        expected = 1.0 - (0.65 * 0.75 * 0.75 * 0.85)
        assert exposure == pytest.approx(expected, abs=0.001)
        assert exposure > 0.5  # significantly above zero

||||
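The expected value in `test_all_max_overlaps` comes from a noisy-OR style combination, exposure = 1 − Π(1 − wᵢ·oᵢ). A minimal sketch under that reading is below. The weights 0.35, 0.25, 0.25, 0.15 are taken from the test docstring; the function name and the generic overlap parameter names are hypothetical, since the diff does not show which overlap each weight belongs to.

```python
# Weights as they appear in the test docstring, in argument order.
EXPOSURE_WEIGHTS = (0.35, 0.25, 0.25, 0.15)


def multiplicative_exposure_sketch(
    overlap_a: float, overlap_b: float, overlap_c: float, overlap_d: float
) -> float:
    """Combine four overlap scores in [0, 1] into one exposure in [0, 1)."""
    remaining = 1.0  # share of the company still untouched by the event
    for weight, overlap in zip(EXPOSURE_WEIGHTS, (overlap_a, overlap_b, overlap_c, overlap_d)):
        remaining *= 1.0 - weight * overlap
    return 1.0 - remaining
```

Unlike a weighted sum, this form returns exactly zero when every overlap is zero and saturates below 1.0 when several channels overlap at once.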
# ---------------------------------------------------------------------------
# Macro fallback behaviors (Req 11.3, 11.4)
# ---------------------------------------------------------------------------


class TestMacroFallbackBehaviors:
    """Test macro fallback: only macro → additive, only company → no modifier."""

    def test_only_macro_additive_fallback(self):
        """Only macro signals → additive merge (Req 11.3)."""
        macro_signals = [_make_signal(sentiment=-1.0)]
        merged, modifier = integrate_macro_signals(
            company_signals=[],
            macro_signals=macro_signals,
            company_direction="neutral",
            macro_impacts=[],
            probabilistic=True,
        )
        # Macro-only: returns macro signals, modifier = 1.0
        assert len(merged) == 1
        assert modifier == 1.0

    def test_only_company_no_modifier(self):
        """Only company signals → modifier = 1.0 (Req 11.4)."""
        company_signals = [_make_signal(sentiment=1.0)]
        merged, modifier = integrate_macro_signals(
            company_signals=company_signals,
            macro_signals=[],
            company_direction="bullish",
            macro_impacts=[],
            probabilistic=True,
        )
        assert len(merged) == 1
        assert modifier == 1.0

    def test_heuristic_mode_additive_merge(self):
        """Heuristic mode: simple concatenation of all signals."""
        company = [_make_signal(sentiment=1.0)]
        macro = [_make_signal(sentiment=-1.0)]
        merged, modifier = integrate_macro_signals(
            company_signals=company,
            macro_signals=macro,
            company_direction="bullish",
            macro_impacts=[],
            probabilistic=False,
        )
        assert len(merged) == 2
        assert modifier == 1.0


# ---------------------------------------------------------------------------
# Graph distance cutoff (Req 12.3)
# ---------------------------------------------------------------------------


class TestGraphDistanceCutoff:
    """Test graph distance cutoff: d>3 → no propagation."""

    def test_distance_4_no_propagation(self):
        """Distance 4 → transfer strength = 0.0."""
        result = compute_graph_distance_attenuation(
            source_strength=1.0, correlation=1.0, distance=4,
        )
        assert result == 0.0

    def test_distance_3_propagates(self):
        """Distance 3 → still propagates (e^(-3) ≈ 0.05)."""
        result = compute_graph_distance_attenuation(
            source_strength=1.0, correlation=1.0, distance=3,
        )
        assert result > 0.0
        assert result == pytest.approx(math.exp(-3), abs=0.001)

    def test_distance_1_strongest(self):
        """Distance 1 → strongest propagation."""
        d1 = compute_graph_distance_attenuation(1.0, 1.0, 1)
        d2 = compute_graph_distance_attenuation(1.0, 1.0, 2)
        d3 = compute_graph_distance_attenuation(1.0, 1.0, 3)
        assert d1 > d2 > d3 > 0.0

    def test_distance_0_no_propagation(self):
        """Distance 0 → no propagation (self-loop)."""
        result = compute_graph_distance_attenuation(1.0, 1.0, 0)
        assert result == 0.0

||||
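The cutoff tests above are consistent with an exponential decay in hop count that is switched off outside 1 ≤ d ≤ 3. The sketch below shows one function with that behaviour. The product form strength × correlation × e^(−d) is an assumption: the tests only use strength and correlation equal to 1.0, so they confirm the e^(−d) factor and the cutoffs but not how the other two inputs enter. The function name is hypothetical.

```python
import math

MAX_PROPAGATION_DISTANCE = 3  # cutoff implied by the d=3 and d=4 tests


def graph_attenuation_sketch(
    source_strength: float, correlation: float, distance: int
) -> float:
    """Transfer strength reaching a node `distance` hops from the source."""
    if distance < 1 or distance > MAX_PROPAGATION_DISTANCE:
        return 0.0  # self-loops and far-away nodes receive nothing
    # Assumed combination: linear in strength and correlation.
    return source_strength * correlation * math.exp(-distance)
```

At distance 3 the transfer is already about 5% of the source strength, so a hard cutoff beyond that discards very little signal.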
# ---------------------------------------------------------------------------
# Momentum fallback (Req 13.3)
# ---------------------------------------------------------------------------


class TestMomentumFallback:
    """Test momentum fallback: <2 cycles → heuristic."""

    def test_empty_changes_returns_zero(self):
        """Empty list → 0.0 (fallback)."""
        assert compute_ew_momentum([]) == 0.0

    def test_single_change_returns_zero(self):
        """Single change → 0.0 (fewer than 2 cycles)."""
        assert compute_ew_momentum([0.5]) == 0.0

    def test_two_changes_computes(self):
        """Two changes → computes EW momentum."""
        result = compute_ew_momentum([0.3, 0.2])
        assert result != 0.0

    def test_heuristic_fallback_for_trend_momentum(self):
        """compute_trend_momentum with no previous data uses heuristic."""
        result = compute_trend_momentum(
            current_strength=0.6,
            current_direction="bullish",
            previous_strength=None,
            previous_direction=None,
        )
        # Heuristic: dir_sign * strength * 0.5 = 1.0 * 0.6 * 0.5 = 0.3
        assert result == pytest.approx(0.3, abs=0.01)


# ---------------------------------------------------------------------------
# EV threshold behavior (Req 14.3, 14.4)
# ---------------------------------------------------------------------------


class TestEVThresholdBehavior:
    """Test EV threshold: EV>0.005→proceed, EV≤0.005→informational."""

    def test_positive_ev_proceeds(self):
        """EV > 0.005 → recommendation proceeds normally."""
        summary = _make_trend_summary(
            trend_direction=TrendDirection.BULLISH,
            trend_strength=0.5,
            confidence=0.7,
        )
        result = evaluate_eligibility(
            summary,
            probabilistic=True,
            p_bull=0.8,
            sigma_20=0.02,
        )
        # With p_bull=0.8, strength=0.5, sigma_20=0.02, horizon=7d:
        # R_up = 0.5 * 0.02 * sqrt(7) ≈ 0.0265
        # R_down = 0.5 * 0.02 * sqrt(7) ≈ 0.0265
        # EV = 0.8 * 0.0265 - 0.2 * 0.0265 ≈ 0.0159
        assert result.ev_value is not None
        assert result.ev_value > 0.005
        assert result.pipeline_mode == "probabilistic"

    def test_low_ev_forces_informational(self):
        """EV ≤ 0.005 → forced to informational mode (Req 14.4)."""
        summary = _make_trend_summary(
            trend_direction=TrendDirection.BULLISH,
            trend_strength=0.5,
            confidence=0.7,
        )
        # p_bull near 0.5 → EV near 0
        result = evaluate_eligibility(
            summary,
            probabilistic=True,
            p_bull=0.5,
            sigma_20=0.001,  # very low vol → tiny EV
        )
        assert result.ev_value is not None
        assert result.ev_value <= 0.005
        assert result.mode == RecommendationMode.INFORMATIONAL

    def test_ev_computation_values(self):
        """Verify EV computation formula directly."""
        ev = compute_expected_value(
            p_bull=0.7,
            strength=0.5,
            sigma_20=0.02,
            horizon_days=7.0,
        )
        # R_up = 0.5 * 0.02 * sqrt(7) ≈ 0.02646
        # R_down = 0.5 * 0.02 * sqrt(7) ≈ 0.02646
        # EV = 0.7 * 0.02646 - 0.3 * 0.02646 ≈ 0.01058
        assert ev > 0.005
        assert ev == pytest.approx(0.7 * 0.5 * 0.02 * math.sqrt(7) - 0.3 * 0.5 * 0.02 * math.sqrt(7), abs=0.001)

||||
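The comments in the EV tests above spell out the formula being checked: R_up = R_down = strength·σ₂₀·√horizon and EV = p_bull·R_up − (1 − p_bull)·R_down, with 0.005 as the cut-off for proceeding. The sketch below restates it. The names `expected_value_sketch`, `passes_ev_gate`, and `EV_THRESHOLD` are hypothetical; the symmetric up and down moves come from the test comments and may not hold for every input in the real `compute_expected_value`.

```python
import math

EV_THRESHOLD = 0.005  # cut-off quoted in the test docstrings


def expected_value_sketch(
    p_bull: float, strength: float, sigma_20: float, horizon_days: float
) -> float:
    """Expected return of a bullish position under symmetric up/down moves."""
    # Volatility scales with the square root of the holding period.
    expected_move = strength * sigma_20 * math.sqrt(horizon_days)
    r_up = expected_move
    r_down = expected_move
    return p_bull * r_up - (1.0 - p_bull) * r_down


def passes_ev_gate(ev: float) -> bool:
    """True when the recommendation may proceed instead of going informational."""
    return ev > EV_THRESHOLD
```

Because the moves are symmetric, the expression reduces to (2·p_bull − 1)·strength·σ₂₀·√horizon, so a p_bull of 0.5 always gives an EV of zero regardless of volatility.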
# ---------------------------------------------------------------------------
# Feature flag behaviors (Req 16.4, 16.5)
# ---------------------------------------------------------------------------


class TestFeatureFlagBehaviors:
    """Test flag=false→heuristic, flag=true→probabilistic."""

    def test_heuristic_mode_binary_gate(self):
        """flag=false → uses binary confidence gate."""
        config = ScoringConfig(probabilistic=False)
        now = datetime.now(timezone.utc)

        # Below confidence floor → gate = 0
        result = compute_signal_weight(
            published_at=now,
            reference_time=now,
            window="7d",
            source_credibility=0.8,
            extraction_confidence=0.1,  # below floor of 0.2
            config=config,
        )
        assert result.confidence_gate == 0.0
        assert result.combined == 0.0
        assert result.sigmoid_gate is None

    def test_probabilistic_mode_sigmoid_gate(self):
        """flag=true → uses sigmoid confidence gate."""
        config = ScoringConfig(probabilistic=True)
        now = datetime.now(timezone.utc)

        result = compute_signal_weight(
            published_at=now,
            reference_time=now,
            window="7d",
            source_credibility=0.8,
            extraction_confidence=0.5,
            config=config,
        )
        assert result.sigmoid_gate is not None
        assert result.sigmoid_gate == pytest.approx(0.5, abs=0.01)
        assert result.combined > 0.0

    def test_heuristic_mode_no_info_gain(self):
        """flag=false → info_gain_factor stays at default 1.0."""
        config = ScoringConfig(probabilistic=False)
        now = datetime.now(timezone.utc)

        result = compute_signal_weight(
            published_at=now,
            reference_time=now,
            window="7d",
            source_credibility=0.8,
            extraction_confidence=0.8,
            event_type="m_and_a",
            config=config,
        )
        assert result.info_gain_factor == 1.0

    def test_probabilistic_mode_has_info_gain(self):
        """flag=true → info_gain_factor computed from event type."""
        config = ScoringConfig(probabilistic=True)
        now = datetime.now(timezone.utc)

        result = compute_signal_weight(
            published_at=now,
            reference_time=now,
            window="7d",
            source_credibility=0.8,
            extraction_confidence=0.8,
            event_type="m_and_a",
            config=config,
        )
        assert result.info_gain_factor > 1.0

    def test_heuristic_eligibility_skips_ev(self):
        """flag=false → EV gate is skipped entirely."""
        summary = _make_trend_summary()
        result = evaluate_eligibility(summary, probabilistic=False)
        assert result.ev_value is None
        assert result.pipeline_mode == "heuristic"

    def test_probabilistic_eligibility_computes_ev(self):
        """flag=true → EV is computed."""
        summary = _make_trend_summary()
        result = evaluate_eligibility(
            summary, probabilistic=True, p_bull=0.7, sigma_20=0.02,
        )
        assert result.ev_value is not None
        assert result.pipeline_mode == "probabilistic"


# ---------------------------------------------------------------------------
# Regime multiplier edge cases
# ---------------------------------------------------------------------------


class TestRegimeMultiplierEdgeCases:
    """Test regime multiplier with edge case inputs."""

    def test_no_returns_gives_one(self):
        """No returns → M_regime = 1.0."""
        assert compute_regime_multiplier(None, None) == 1.0

    def test_single_return_gives_one(self):
        """Single return → M_regime = 1.0 (need at least 2)."""
        assert compute_regime_multiplier([0.01], None) == 1.0

    def test_constant_returns_gives_one(self):
        """Constant returns (σ=0) → z_r=0 → M_regime = 1.0."""
        returns = [0.01] * 20
        result = compute_regime_multiplier(returns, None)
        assert result == pytest.approx(1.0)

    def test_clamped_to_max(self):
        """Extreme z-scores → clamped to 2.5."""
        # Create returns with extreme outlier
        returns = [0.001] * 19 + [10.0]
        result = compute_regime_multiplier(returns, None)
        assert result <= 2.5
@@ -0,0 +1,241 @@
"""Tests for source accuracy tracker — SourceAccuracy dataclass and
database functions."""
from __future__ import annotations

from datetime import datetime, timezone
from unittest.mock import AsyncMock

import pytest

from services.aggregation.source_accuracy import (
    SourceAccuracy,
    fetch_source_accuracy,
    update_source_accuracy,
)

# ---------------------------------------------------------------------------
# SourceAccuracy.accuracy_factor property
# ---------------------------------------------------------------------------


def test_accuracy_factor_low_sample_count():
    """When sample_count < 10, accuracy_factor returns neutral 1.0."""
    sa = SourceAccuracy(
        source_id="src-1",
        accuracy_ratio=0.9,
        sample_count=5,
        last_updated=datetime.now(timezone.utc),
    )
    assert sa.accuracy_factor == 1.0


def test_accuracy_factor_exactly_ten_samples():
    """When sample_count == 10, accuracy_factor uses the formula."""
    sa = SourceAccuracy(
        source_id="src-1",
        accuracy_ratio=0.8,
        sample_count=10,
        last_updated=datetime.now(timezone.utc),
    )
    assert abs(sa.accuracy_factor - 1.3) < 1e-9


def test_accuracy_factor_zero_accuracy():
    """0% accuracy with enough samples gives factor 0.5."""
    sa = SourceAccuracy(
        source_id="src-1",
        accuracy_ratio=0.0,
        sample_count=100,
        last_updated=datetime.now(timezone.utc),
    )
    assert abs(sa.accuracy_factor - 0.5) < 1e-9


def test_accuracy_factor_full_accuracy():
    """100% accuracy with enough samples gives factor 1.5."""
    sa = SourceAccuracy(
        source_id="src-1",
        accuracy_ratio=1.0,
        sample_count=100,
        last_updated=datetime.now(timezone.utc),
    )
    assert abs(sa.accuracy_factor - 1.5) < 1e-9


def test_accuracy_factor_clamps_corrupted_high():
    """Corrupted accuracy_ratio > 1.0 is clamped to 1.0 in the factor."""
    sa = SourceAccuracy(
        source_id="src-1",
        accuracy_ratio=2.5,
        sample_count=50,
        last_updated=datetime.now(timezone.utc),
    )
    # clamped to 1.0 → factor = 0.5 + 1.0 = 1.5
    assert abs(sa.accuracy_factor - 1.5) < 1e-9


def test_accuracy_factor_clamps_corrupted_negative():
    """Corrupted accuracy_ratio < 0.0 is clamped to 0.0 in the factor."""
    sa = SourceAccuracy(
        source_id="src-1",
        accuracy_ratio=-0.3,
        sample_count=50,
        last_updated=datetime.now(timezone.utc),
    )
    # clamped to 0.0 → factor = 0.5 + 0.0 = 0.5
    assert abs(sa.accuracy_factor - 0.5) < 1e-9


def test_accuracy_factor_nine_samples_neutral():
    """sample_count=9 is still below threshold, returns 1.0."""
    sa = SourceAccuracy(
        source_id="src-1",
        accuracy_ratio=0.0,
        sample_count=9,
        last_updated=datetime.now(timezone.utc),
    )
    assert sa.accuracy_factor == 1.0


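The `accuracy_factor` tests above pin down a small piecewise rule: neutral 1.0 below 10 samples, otherwise 0.5 plus the ratio clamped to [0.0, 1.0]. A minimal sketch of a property that would satisfy them is shown here. `SourceAccuracySketch` and `MIN_SAMPLES` are hypothetical names, and the real `SourceAccuracy` class in `services.aggregation.source_accuracy` may be written differently.

```python
from dataclasses import dataclass
from datetime import datetime, timezone

# Assumed threshold, inferred from the sample_count=9 and sample_count=10 tests.
MIN_SAMPLES = 10


@dataclass(frozen=True)
class SourceAccuracySketch:
    """Hypothetical stand-in for SourceAccuracy; not the real class."""

    source_id: str
    accuracy_ratio: float
    sample_count: int
    last_updated: datetime

    @property
    def accuracy_factor(self) -> float:
        # Too few samples to trust the ratio: stay neutral.
        if self.sample_count < MIN_SAMPLES:
            return 1.0
        # Clamp a possibly corrupted ratio into [0.0, 1.0],
        # then shift it so the factor lies in [0.5, 1.5].
        clamped = min(1.0, max(0.0, self.accuracy_ratio))
        return 0.5 + clamped


now = datetime.now(timezone.utc)
neutral = SourceAccuracySketch("src-1", 0.0, 9, now).accuracy_factor
boosted = SourceAccuracySketch("src-1", 2.5, 50, now).accuracy_factor
```

Under this rule a source with a 50% hit rate and enough samples lands exactly on the neutral factor 1.0, so only sources that beat or miss a coin flip move the score.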
# ---------------------------------------------------------------------------
# fetch_source_accuracy
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_fetch_source_accuracy_empty_ids():
    """Empty source_ids list returns empty dict without querying."""
    pool = AsyncMock()
    result = await fetch_source_accuracy(pool, [])
    assert result == {}
    pool.fetch.assert_not_called()


@pytest.mark.asyncio
async def test_fetch_source_accuracy_returns_records():
    """Successful fetch returns SourceAccuracy records keyed by source_id."""
    now = datetime.now(timezone.utc)
    pool = AsyncMock()
    pool.fetch = AsyncMock(return_value=[
        {
            "source_id": "src-a",
            "accuracy_ratio": 0.75,
            "sample_count": 20,
            "last_updated": now,
        },
        {
            "source_id": "src-b",
            "accuracy_ratio": 0.4,
            "sample_count": 15,
            "last_updated": now,
        },
    ])

    result = await fetch_source_accuracy(pool, ["src-a", "src-b"])

    assert len(result) == 2
    assert result["src-a"].accuracy_ratio == 0.75
    assert result["src-a"].sample_count == 20
    assert result["src-b"].accuracy_ratio == 0.4


@pytest.mark.asyncio
async def test_fetch_source_accuracy_clamps_corrupted():
    """Corrupted accuracy_ratio values are clamped to [0.0, 1.0]."""
    now = datetime.now(timezone.utc)
    pool = AsyncMock()
    pool.fetch = AsyncMock(return_value=[
        {
            "source_id": "src-bad",
            "accuracy_ratio": 1.5,
            "sample_count": 30,
            "last_updated": now,
        },
    ])

    result = await fetch_source_accuracy(pool, ["src-bad"])
    assert result["src-bad"].accuracy_ratio == 1.0


@pytest.mark.asyncio
async def test_fetch_source_accuracy_db_error_returns_empty():
    """When the database is unreachable, returns empty dict."""
    pool = AsyncMock()
    pool.fetch = AsyncMock(side_effect=Exception("connection refused"))

    result = await fetch_source_accuracy(pool, ["src-a"])
    assert result == {}


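Taken together, the four `fetch_source_accuracy` tests describe a fail-soft lookup: skip the query for an empty ID list, clamp each stored ratio on read, and return an empty dict on any database error. The sketch below illustrates that control flow only. `fetch_accuracy_sketch`, `StubPool`, `QUERY`, and the table and column names are all assumptions, and it returns plain dicts rather than `SourceAccuracy` records to stay self-contained.

```python
from __future__ import annotations

import asyncio
from typing import Any

# Hypothetical query; the real table and column names are not shown in this diff.
QUERY = (
    "SELECT source_id, accuracy_ratio, sample_count, last_updated "
    "FROM source_accuracy WHERE source_id = ANY($1)"
)


class StubPool:
    """Hypothetical in-memory pool exposing only an async fetch()."""

    def __init__(self, rows: list[dict[str, Any]], fail: bool = False) -> None:
        self.rows = rows
        self.fail = fail
        self.calls = 0

    async def fetch(self, query: str, *args: Any) -> list[dict[str, Any]]:
        self.calls += 1
        if self.fail:
            raise ConnectionError("connection refused")
        return self.rows


async def fetch_accuracy_sketch(
    pool: Any, source_ids: list[str]
) -> dict[str, dict[str, Any]]:
    if not source_ids:
        return {}  # nothing to look up, so the pool is never touched
    try:
        rows = await pool.fetch(QUERY, source_ids)
    except Exception:
        # Fail soft: callers treat a missing entry as "no accuracy data".
        return {}
    result: dict[str, dict[str, Any]] = {}
    for row in rows:
        record = dict(row)
        # Clamp corrupted ratios into [0.0, 1.0] on read.
        record["accuracy_ratio"] = min(1.0, max(0.0, record["accuracy_ratio"]))
        result[record["source_id"]] = record
    return result


pool = StubPool([{"source_id": "src-bad", "accuracy_ratio": 1.5, "sample_count": 30}])
clamped = asyncio.run(fetch_accuracy_sketch(pool, ["src-bad"]))
```

Returning an empty dict on failure means a scoring pass can presumably fall back to a neutral factor for every source instead of aborting, which matches the "database is unreachable" test.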
# ---------------------------------------------------------------------------
# update_source_accuracy
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_update_source_accuracy_empty_outcomes():
    """Empty outcomes list does nothing."""
    pool = AsyncMock()
    await update_source_accuracy(pool, "src-1", [])
    pool.execute.assert_not_called()


@pytest.mark.asyncio
async def test_update_source_accuracy_counts_correctly():
    """Correct and incorrect predictions are counted properly."""
    pool = AsyncMock()
    pool.execute = AsyncMock()

    outcomes = [
        ("bullish", 0.05),   # correct
        ("bullish", -0.02),  # incorrect
        ("bearish", -0.03),  # correct
        ("bearish", 0.01),   # incorrect
    ]

    await update_source_accuracy(pool, "src-1", outcomes)

    pool.execute.assert_called_once()
    call_args = pool.execute.call_args
    # accuracy_ratio = 2/4 = 0.5, total = 4
    assert abs(call_args[0][2] - 0.5) < 1e-9  # accuracy_ratio
    assert call_args[0][3] == 4  # total


@pytest.mark.asyncio
async def test_update_source_accuracy_skips_neutral():
    """Neutral predictions and zero returns are excluded."""
    pool = AsyncMock()
    pool.execute = AsyncMock()

    outcomes = [
        ("neutral", 0.05),  # skipped — neutral direction
        ("bullish", 0.0),   # skipped — zero return
        ("bullish", 0.03),  # counted — correct
    ]

    await update_source_accuracy(pool, "src-1", outcomes)

    pool.execute.assert_called_once()
    call_args = pool.execute.call_args
    # accuracy_ratio = 1/1 = 1.0, total = 1
    assert abs(call_args[0][2] - 1.0) < 1e-9
    assert call_args[0][3] == 1


@pytest.mark.asyncio
async def test_update_source_accuracy_all_neutral_skips():
    """When all outcomes are neutral/zero, no DB call is made."""
    pool = AsyncMock()
    await update_source_accuracy(pool, "src-1", [("neutral", 0.05)])
    pool.execute.assert_not_called()


@pytest.mark.asyncio
async def test_update_source_accuracy_db_error_logs_and_continues():
    """DB errors are logged but do not raise."""
    pool = AsyncMock()
    pool.execute = AsyncMock(side_effect=Exception("connection refused"))

    # Should not raise
    await update_source_accuracy(pool, "src-1", [("bullish", 0.05)])
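The `update_source_accuracy` tests imply a counting rule applied before anything is written: drop neutral predictions and zero returns, score a prediction as correct when its direction matches the sign of the realized return, and skip the write when nothing is left to count. A minimal sketch of that rule follows. `summarize_outcomes` is a hypothetical helper name, and how the real function merges the result with previously stored counts is not visible in these tests.

```python
from __future__ import annotations


def summarize_outcomes(
    outcomes: list[tuple[str, float]],
) -> tuple[float, int] | None:
    """Return (accuracy_ratio, total), or None when nothing is countable."""
    correct = 0
    total = 0
    for direction, realized_return in outcomes:
        # Neutral calls and flat returns carry no directional signal.
        if direction not in ("bullish", "bearish") or realized_return == 0.0:
            continue
        total += 1
        # Correct when a bullish call saw a gain or a bearish call saw a loss.
        if (direction == "bullish") == (realized_return > 0):
            correct += 1
    if total == 0:
        return None  # caller would skip the database write
    return correct / total, total


mixed = summarize_outcomes(
    [("bullish", 0.05), ("bullish", -0.02), ("bearish", -0.03), ("bearish", 0.01)]
)
```

For the four mixed outcomes this gives two correct out of four, the same ratio and total that `test_update_source_accuracy_counts_correctly` expects in the arguments to `pool.execute`.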