"""Unit tests for services.signal_engine.signals.elliott_wave — Elliott Wave evaluator. Requirements: 2.5, 2.6, 2.7 """ from __future__ import annotations from datetime import datetime, timezone from services.signal_engine.models import OHLCVBar, SignalDirection from services.signal_engine.signals.elliott_wave import ( DEFAULT_MIN_BARS, WAVE_TYPE_CORRECTIVE, WAVE_TYPE_IMPULSE, ElliottWaveEvaluator, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _bar( close: float, high: float | None = None, low: float | None = None, ) -> OHLCVBar: """Create a minimal OHLCVBar for testing.""" h = high if high is not None else close lo = low if low is not None else close return OHLCVBar( timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), open=close, high=h, low=lo, close=close, volume=1000.0, ) def _make_impulse_up_bars(n: int = 50) -> list[OHLCVBar]: """Create synthetic bars forming a bullish 5-wave impulse pattern. Wave structure (bullish impulse): Wave 1: 100 → 120 (up) Wave 2: 120 → 108 (down, retracement) Wave 3: 108 → 140 (up, largest wave) Wave 4: 140 → 130 (down, retracement) Wave 5: 130 → 150 (up, new high) """ # Define price waypoints for each wave waypoints = [ (0.00, 100.0), # start (0.20, 120.0), # wave 1 peak (0.35, 108.0), # wave 2 trough (0.60, 140.0), # wave 3 peak (0.75, 130.0), # wave 4 trough (1.00, 150.0), # wave 5 peak ] return _interpolate_bars(waypoints, n) def _make_impulse_down_bars(n: int = 50) -> list[OHLCVBar]: """Create synthetic bars forming a bearish 5-wave impulse pattern. Wave structure (bearish impulse): Wave 1: 150 → 130 (down) Wave 2: 130 → 142 (up, retracement) Wave 3: 142 → 110 (down, largest wave) Wave 4: 110 → 120 (up, retracement) Wave 5: 120 → 100 (down, new low) """ waypoints = [ (0.00, 150.0), (0.20, 130.0), (0.35, 142.0), (0.60, 110.0), (0.75, 120.0), (1.00, 100.0), ] return _interpolate_bars(waypoints, n) def _make_corrective_bars(n: int = 50) -> list[OHLCVBar]: """Create synthetic bars forming a corrective A-B-C pattern after an uptrend. First half: uptrend (impulse context) Second half: A-B-C correction Wave A: 150 → 130 (down) Wave B: 130 → 140 (up, partial retracement) Wave C: 140 → 120 (down, new low) """ waypoints = [ (0.00, 100.0), # start of uptrend (0.40, 150.0), # end of uptrend / start of correction (0.60, 130.0), # wave A trough (0.75, 140.0), # wave B peak (1.00, 120.0), # wave C trough ] return _interpolate_bars(waypoints, n) def _interpolate_bars( waypoints: list[tuple[float, float]], n: int, ) -> list[OHLCVBar]: """Interpolate price waypoints into n OHLCV bars with realistic high/low.""" bars: list[OHLCVBar] = [] for i in range(n): frac = i / max(1, n - 1) # Find the two surrounding waypoints price = waypoints[-1][1] # default to last for j in range(len(waypoints) - 1): t0, p0 = waypoints[j] t1, p1 = waypoints[j + 1] if t0 <= frac <= t1: seg_frac = (frac - t0) / (t1 - t0) if t1 > t0 else 0.0 price = p0 + seg_frac * (p1 - p0) break # Add some spread for high/low spread = max(1.0, abs(price) * 0.01) bars.append(_bar(price, high=price + spread, low=price - spread)) return bars # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- def test_default_min_bars() -> None: assert DEFAULT_MIN_BARS == 30 # --------------------------------------------------------------------------- # Insufficient data → None (Requirement 2.6) # --------------------------------------------------------------------------- def test_returns_none_when_insufficient_bars() -> None: """Requirement 2.6: return None when fewer than min_bars.""" evaluator = ElliottWaveEvaluator() bars = [_bar(100.0) for _ in range(29)] assert evaluator.evaluate(bars, "D") is None def test_returns_none_with_empty_bars() -> None: evaluator = ElliottWaveEvaluator() assert evaluator.evaluate([], "D") is None def test_returns_none_with_one_bar() -> None: evaluator = ElliottWaveEvaluator() assert evaluator.evaluate([_bar(100.0)], "D") is None # --------------------------------------------------------------------------- # Flat market → None # --------------------------------------------------------------------------- def test_returns_none_for_flat_market() -> None: """Flat prices have no wave structure.""" evaluator = ElliottWaveEvaluator() bars = [_bar(100.0, high=100.0, low=100.0) for _ in range(40)] assert evaluator.evaluate(bars, "D") is None # --------------------------------------------------------------------------- # Impulse wave detection (Requirement 2.5) # --------------------------------------------------------------------------- def test_detects_bullish_impulse_wave() -> None: """Requirement 2.5: detect impulse waves (5-wave structure).""" evaluator = ElliottWaveEvaluator() bars = _make_impulse_up_bars(n=50) result = evaluator.evaluate(bars, "D") assert result is not None assert result.signal_type == "elliott_wave" assert result.direction == SignalDirection.BULLISH assert result.metadata["wave_type"] == WAVE_TYPE_IMPULSE def test_detects_bearish_impulse_wave() -> None: """Requirement 2.5: detect bearish impulse waves.""" evaluator = ElliottWaveEvaluator() bars = _make_impulse_down_bars(n=50) result = evaluator.evaluate(bars, "D") assert result is not None assert result.signal_type == "elliott_wave" assert result.direction == SignalDirection.BEARISH assert result.metadata["wave_type"] == WAVE_TYPE_IMPULSE # --------------------------------------------------------------------------- # Corrective wave detection (Requirement 2.5) # --------------------------------------------------------------------------- def test_detects_corrective_wave() -> None: """Requirement 2.5: detect corrective waves (3-wave structure).""" evaluator = ElliottWaveEvaluator() bars = _make_corrective_bars(n=50) result = evaluator.evaluate(bars, "D") assert result is not None assert result.signal_type == "elliott_wave" assert result.metadata["wave_type"] in (WAVE_TYPE_CORRECTIVE, WAVE_TYPE_IMPULSE) # --------------------------------------------------------------------------- # Signal structure validation (Requirement 2.7) # --------------------------------------------------------------------------- def test_signal_result_structure() -> None: """Requirement 2.7: SignalResult has all required fields.""" evaluator = ElliottWaveEvaluator() bars = _make_impulse_up_bars(n=50) result = evaluator.evaluate(bars, "D") assert result is not None assert result.signal_type == "elliott_wave" assert result.timeframe == "D" assert 0.0 <= result.strength <= 1.0 assert 0.0 <= result.confidence <= 1.0 assert result.direction in ( SignalDirection.BULLISH, SignalDirection.BEARISH, SignalDirection.NEUTRAL, ) def test_strength_in_unit_interval() -> None: """Strength must be in [0, 1].""" evaluator = ElliottWaveEvaluator() bars = _make_impulse_up_bars(n=50) result = evaluator.evaluate(bars, "D") assert result is not None assert 0.0 <= result.strength <= 1.0 def test_confidence_in_unit_interval() -> None: """Confidence must be in [0, 1].""" evaluator = ElliottWaveEvaluator() bars = _make_impulse_up_bars(n=50) result = evaluator.evaluate(bars, "D") assert result is not None assert 0.0 <= result.confidence <= 1.0 # --------------------------------------------------------------------------- # Metadata (Requirement 2.7) # --------------------------------------------------------------------------- def test_metadata_contains_required_fields() -> None: """Metadata should include wave_count, wave_type, current_wave_position, pivots.""" evaluator = ElliottWaveEvaluator() bars = _make_impulse_up_bars(n=50) result = evaluator.evaluate(bars, "D") assert result is not None meta = result.metadata assert "wave_count" in meta assert "wave_type" in meta assert "current_wave_position" in meta assert "pivots" in meta assert isinstance(meta["pivots"], list) assert len(meta["pivots"]) > 0 # --------------------------------------------------------------------------- # Timeframe passthrough # --------------------------------------------------------------------------- def test_timeframe_passthrough() -> None: """The timeframe label is passed through to the result.""" evaluator = ElliottWaveEvaluator() bars = _make_impulse_up_bars(n=50) for tf in ("M30", "H1", "H4", "D", "W", "M"): result = evaluator.evaluate(bars, tf) assert result is not None assert result.timeframe == tf # --------------------------------------------------------------------------- # Custom min_bars # --------------------------------------------------------------------------- def test_custom_min_bars() -> None: """ElliottWaveEvaluator with a custom min_bars should use that value.""" evaluator = ElliottWaveEvaluator(min_bars=60) assert evaluator.min_bars == 60 # 50 bars should be insufficient bars = _make_impulse_up_bars(n=50) assert evaluator.evaluate(bars, "D") is None def test_custom_zigzag_pct() -> None: """Custom zigzag_pct should be stored and used.""" evaluator = ElliottWaveEvaluator(zigzag_pct=0.10) assert evaluator.zigzag_pct == 0.10