"""Tests for lake publisher worker — writing prediction facts as Parquet to MinIO.""" from datetime import date, datetime, timezone from unittest.mock import MagicMock import pyarrow.parquet as pq from services.lake_publisher.partitions import ( LAKEHOUSE_BUCKET, TABLE_PARTITIONS, PartitionSpec, partition_path, partition_values, s3_uri, ) from services.lake_publisher.worker import ( _parse_horizon_days, _partition_path, build_model_performance_row, build_position_daily_row, build_trade_fill_row, build_trade_order_row, build_trade_signal_row, publish_market_bars_batch, publish_model_performance, publish_model_performance_batch, publish_position_daily, publish_positions_daily_batch, publish_prediction_fact, publish_recommendation_facts, publish_trade_fill, publish_trade_order, publish_trade_signal, publish_trade_signals_batch, ) from services.shared.schemas import ( ActionType, ModelMetadata, PositionSizing, Recommendation, RecommendationMode, ) NOW = datetime(2026, 4, 11, 14, 30, 0, tzinfo=timezone.utc) def _make_rec( ticker: str = "AAPL", action: ActionType = ActionType.BUY, confidence: float = 0.72, time_horizon: str = "swing_1d_10d", rec_id: str = "rec-001", ) -> Recommendation: return Recommendation( recommendation_id=rec_id, ticker=ticker, action=action, mode=RecommendationMode.PAPER_ELIGIBLE, confidence=confidence, time_horizon=time_horizon, thesis="[risk:low] Test thesis", invalidation_conditions=["price drops below support"], position_sizing=PositionSizing(portfolio_pct=0.02, max_loss_pct=0.005), evidence_refs=["doc1", "doc2"], model_metadata=ModelMetadata(provider="deterministic", model_name="eligibility-v1"), generated_at=NOW, ) # --------------------------------------------------------------------------- # _parse_horizon_days # --------------------------------------------------------------------------- def test_parse_horizon_days_swing(): assert _parse_horizon_days("swing_1d_10d") == 10 def test_parse_horizon_days_position(): assert 
_parse_horizon_days("position_10d_30d") == 30 def test_parse_horizon_days_intraday(): assert _parse_horizon_days("scalp_intraday") == 1 def test_parse_horizon_days_empty(): assert _parse_horizon_days("") == 0 def test_parse_horizon_days_no_numbers(): assert _parse_horizon_days("unknown") == 0 # --------------------------------------------------------------------------- # Partition module tests # --------------------------------------------------------------------------- def test_partition_spec_all_keys(): spec = PartitionSpec("test_table", extra_keys=("model_version",)) assert spec.all_keys == ("dt", "model_version") def test_partition_spec_dt_only(): spec = PartitionSpec("simple") assert spec.all_keys == ("dt",) def test_table_partitions_registry(): assert "market_bars" in TABLE_PARTITIONS assert "model_performance" in TABLE_PARTITIONS assert TABLE_PARTITIONS["model_performance"].extra_keys == ("model_version",) assert TABLE_PARTITIONS["prediction_vs_outcome"].extra_keys == ("model_version",) assert TABLE_PARTITIONS["document_extractions"].extra_keys == ("model_version",) assert TABLE_PARTITIONS["trade_signals"].extra_keys == () def test_partition_values_dt_only(): pv = partition_values(NOW) assert pv == {"dt": date(2026, 4, 11)} def test_partition_values_with_extras(): pv = partition_values(NOW, {"model_version": "v2"}) assert pv == {"dt": date(2026, 4, 11), "model_version": "v2"} def test_s3_uri(): assert s3_uri("warehouse/t/dt=2026-04-11/part-abc.parquet") == \ "s3://stonks-lakehouse/warehouse/t/dt=2026-04-11/part-abc.parquet" # --------------------------------------------------------------------------- # _partition_path (via partitions module) # --------------------------------------------------------------------------- def test_partition_path_basic(): path = partition_path("trade_signals", NOW) assert path.startswith("warehouse/trade_signals/dt=2026-04-11/") assert path.endswith(".parquet") def test_partition_path_with_extra_partitions(): path = 
partition_path("model_performance", NOW, {"model_version": "v1"}) assert "model_version=v1" in path def test_partition_path_custom_file_id(): path = partition_path("trade_signals", NOW, file_id="custom123") assert "part-custom123.parquet" in path def test_partition_path_legacy_wrapper(): """The _partition_path wrapper in worker.py still works.""" path = _partition_path("trade_signals", NOW) assert path.startswith("warehouse/trade_signals/dt=2026-04-11/") # --------------------------------------------------------------------------- # build_trade_signal_row # --------------------------------------------------------------------------- def test_build_trade_signal_row(): rec = _make_rec() row = build_trade_signal_row(rec, trend_direction="bullish", trend_strength=0.68) assert row["signal_id"] == "rec-001" assert row["ticker"] == "AAPL" assert row["trend_direction"] == "bullish" assert row["trend_strength"] == 0.68 assert row["confidence"] == 0.72 assert row["action"] == "buy" assert row["time_horizon"] == "swing_1d_10d" assert row["generated_at"] == NOW assert row["dt"] == date(2026, 4, 11) # --------------------------------------------------------------------------- # publish_trade_signal # --------------------------------------------------------------------------- def test_publish_trade_signal_writes_parquet(): client = MagicMock() rec = _make_rec() ref = publish_trade_signal(client, rec, trend_direction="bullish", trend_strength=0.68) assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_signals/") assert client.put_object.call_count == 1 # Verify the written bytes are valid Parquet put_call = client.put_object.call_args assert put_call[0][0] == LAKEHOUSE_BUCKET written_buf = put_call[0][2] written_buf.seek(0) table = pq.read_table(written_buf) assert table.num_rows == 1 assert table.column("ticker")[0].as_py() == "AAPL" assert table.column("action")[0].as_py() == "buy" # --------------------------------------------------------------------------- # 
# ---------------------------------------------------------------------------
# publish_prediction_fact
# ---------------------------------------------------------------------------


def test_publish_prediction_fact_writes_parquet():
    """The prediction fact lands in prediction_vs_outcome with a pending outcome."""
    mock_client = MagicMock()
    uri = publish_prediction_fact(mock_client, _make_rec())

    assert uri.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/prediction_vs_outcome/")
    # prediction_vs_outcome is additionally partitioned by model version.
    assert "model_version=" in uri
    assert mock_client.put_object.call_count == 1

    # Index 2 of the positional put_object arguments is the Parquet buffer.
    buffer = mock_client.put_object.call_args.args[2]
    buffer.seek(0)
    table = pq.read_table(buffer)
    assert table.num_rows == 1

    first_row = {
        column: table.column(column)[0].as_py()
        for column in ("predicted_action", "outcome", "horizon_days", "dt")
    }
    assert first_row == {
        "predicted_action": "buy",
        "outcome": "pending",
        "horizon_days": 10,  # from the default "swing_1d_10d" horizon
        "dt": date(2026, 4, 11),
    }


# ---------------------------------------------------------------------------
# publish_recommendation_facts
# ---------------------------------------------------------------------------


def test_publish_recommendation_facts_writes_both_tables():
    """One call publishes to both trade_signals and prediction_vs_outcome."""
    mock_client = MagicMock()
    refs = publish_recommendation_facts(mock_client, _make_rec(), "bullish", 0.68)

    for table_name in ("trade_signals", "prediction_vs_outcome"):
        assert table_name in refs
    # One upload per table.
    assert mock_client.put_object.call_count == 2


# ---------------------------------------------------------------------------
# build_trade_order_row
# ---------------------------------------------------------------------------


def test_build_trade_order_row():
    """Every order input is echoed into the row, plus the ``dt`` partition."""
    inputs = {
        "order_id": "ord-001",
        "ticker": "AAPL",
        "side": "buy",
        "order_type": "market",
        "quantity": 10.0,
        "limit_price": None,
        "status": "filled",
        "broker_account": "acct-001",
        "submitted_at": NOW,
    }
    row = build_trade_order_row(**inputs)

    # The missing limit price is checked by identity; everything else by value.
    assert row["limit_price"] is None
    expected = {key: value for key, value in inputs.items() if key != "limit_price"}
    expected["dt"] = date(2026, 4, 11)  # calendar date of NOW
    assert {key: row[key] for key in expected} == expected
# ---------------------------------------------------------------------------
# Upload inspection helper
# ---------------------------------------------------------------------------


def _read_uploaded_table(client: MagicMock):
    """Return the Parquet table written by the most recent ``put_object`` call.

    The publish functions under test pass the Parquet buffer as the third
    positional argument of ``client.put_object``; the bucket is the first.
    The buffer is rewound before reading in case the writer left its cursor
    at the end.
    """
    buffer = client.put_object.call_args.args[2]
    buffer.seek(0)
    return pq.read_table(buffer)


# ---------------------------------------------------------------------------
# publish_trade_order
# ---------------------------------------------------------------------------


def test_publish_trade_order_writes_parquet():
    """A single order is uploaded once, to the lakehouse bucket, as Parquet."""
    client = MagicMock()
    ref = publish_trade_order(
        client,
        order_id="ord-001",
        ticker="AAPL",
        side="buy",
        order_type="market",
        quantity=10.0,
        limit_price=None,
        status="filled",
        broker_account="acct-001",
        submitted_at=NOW,
    )
    assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_orders/")
    assert client.put_object.call_count == 1
    assert client.put_object.call_args.args[0] == LAKEHOUSE_BUCKET

    table = _read_uploaded_table(client)
    assert table.num_rows == 1
    assert table.column("ticker")[0].as_py() == "AAPL"
    assert table.column("side")[0].as_py() == "buy"
    assert table.column("status")[0].as_py() == "filled"


# ---------------------------------------------------------------------------
# build_trade_fill_row
# ---------------------------------------------------------------------------


def test_build_trade_fill_row():
    """Fill identifiers, price and quantity are carried into the row."""
    row = build_trade_fill_row(
        fill_id="fill-001",
        order_id="ord-001",
        ticker="AAPL",
        side="buy",
        fill_price=150.25,
        fill_quantity=10.0,
        broker_account="acct-001",
        filled_at=NOW,
    )
    assert row["fill_id"] == "fill-001"
    assert row["order_id"] == "ord-001"
    assert row["fill_price"] == 150.25
    assert row["fill_quantity"] == 10.0


# ---------------------------------------------------------------------------
# publish_trade_fill
# ---------------------------------------------------------------------------


def test_publish_trade_fill_writes_parquet():
    """A single fill is uploaded once to the trade_fills table as Parquet."""
    client = MagicMock()
    ref = publish_trade_fill(
        client,
        fill_id="fill-001",
        order_id="ord-001",
        ticker="AAPL",
        side="buy",
        fill_price=150.25,
        fill_quantity=10.0,
        broker_account="acct-001",
        filled_at=NOW,
    )
    assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_fills/")
    assert client.put_object.call_count == 1

    table = _read_uploaded_table(client)
    assert table.num_rows == 1
    assert table.column("fill_price")[0].as_py() == 150.25
    assert table.column("ticker")[0].as_py() == "AAPL"


# ---------------------------------------------------------------------------
# build_position_daily_row
# ---------------------------------------------------------------------------


def test_build_position_daily_row():
    """Position snapshot inputs are carried into the daily row."""
    row = build_position_daily_row(
        ticker="AAPL",
        quantity=100.0,
        avg_entry_price=145.00,
        close_price=150.00,
        unrealized_pnl=500.0,
        broker_account="acct-001",
        snapshot_at=NOW,
    )
    assert row["ticker"] == "AAPL"
    assert row["quantity"] == 100.0
    assert row["unrealized_pnl"] == 500.0


# ---------------------------------------------------------------------------
# publish_position_daily
# ---------------------------------------------------------------------------


def test_publish_position_daily_writes_parquet():
    """A single position snapshot is uploaded once to positions_daily."""
    client = MagicMock()
    ref = publish_position_daily(
        client,
        ticker="AAPL",
        quantity=100.0,
        avg_entry_price=145.00,
        close_price=150.00,
        unrealized_pnl=500.0,
        broker_account="acct-001",
        snapshot_at=NOW,
    )
    assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/positions_daily/")
    assert client.put_object.call_count == 1

    table = _read_uploaded_table(client)
    assert table.num_rows == 1
    assert table.column("ticker")[0].as_py() == "AAPL"
    assert table.column("close_price")[0].as_py() == 150.00


# ---------------------------------------------------------------------------
# publish_positions_daily_batch
# ---------------------------------------------------------------------------


def test_publish_positions_daily_batch_writes_parquet():
    """Several positions are written as one Parquet object with one row each."""
    client = MagicMock()
    positions = [
        {
            "ticker": "AAPL",
            "quantity": 100.0,
            "avg_entry_price": 145.0,
            "close_price": 150.0,
            "unrealized_pnl": 500.0,
        },
        {
            "ticker": "MSFT",
            "quantity": 50.0,
            "avg_entry_price": 300.0,
            "close_price": 310.0,
            "unrealized_pnl": 500.0,
        },
    ]
    ref = publish_positions_daily_batch(client, positions, "acct-001", NOW)
    assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/positions_daily/")
    assert client.put_object.call_count == 1

    table = _read_uploaded_table(client)
    assert table.num_rows == 2


def test_publish_positions_daily_batch_empty():
    """An empty batch uploads nothing and returns an empty reference."""
    client = MagicMock()
    ref = publish_positions_daily_batch(client, [], "acct-001", NOW)
    assert ref == ""
    assert client.put_object.call_count == 0


# ---------------------------------------------------------------------------
# build_model_performance_row
# ---------------------------------------------------------------------------


def test_build_model_performance_row():
    """Model-run metrics are carried into the row; schema_version feeds model_version."""
    row = build_model_performance_row(
        document_id="doc-001",
        model_name="gpt-oss:20b",
        success=True,
        total_duration_ms=1500,
        recorded_at=NOW,
        ticker="AAPL",
        prompt_version="document-intel-v2",
        schema_version="2.0.0",
        attempt_count=2,
        confidence=0.86,
        validation_status="valid",
        retry_count=1,
        input_token_estimate=500,
        output_token_estimate=200,
        company_count=3,
    )
    assert row["document_id"] == "doc-001"
    assert row["model_name"] == "gpt-oss:20b"
    assert row["success"] is True
    assert row["total_duration_ms"] == 1500
    assert row["attempt_count"] == 2
    assert row["confidence"] == 0.86
    assert row["company_count"] == 3
    assert row["dt"] == date(2026, 4, 11)
    assert row["model_version"] == "2.0.0"


# ---------------------------------------------------------------------------
# publish_model_performance
# ---------------------------------------------------------------------------


def test_publish_model_performance_writes_parquet():
    """A model-performance record is uploaded under its model_version partition."""
    client = MagicMock()
    ref = publish_model_performance(
        client,
        document_id="doc-001",
        model_name="gpt-oss:20b",
        success=True,
        total_duration_ms=1500,
        recorded_at=NOW,
        ticker="AAPL",
        prompt_version="document-intel-v2",
        schema_version="2.0.0",
        confidence=0.86,
        validation_status="valid",
    )
    assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/model_performance/")
    assert "model_version=2.0.0" in ref
    assert client.put_object.call_count == 1

    table = _read_uploaded_table(client)
    assert table.num_rows == 1
    assert table.column("model_name")[0].as_py() == "gpt-oss:20b"
    assert table.column("success")[0].as_py() is True
    assert table.column("confidence")[0].as_py() == 0.86


# ---------------------------------------------------------------------------
# Batch publish helpers
# ---------------------------------------------------------------------------


def test_publish_market_bars_batch():
    """Multiple bars are written as a single Parquet object under market_bars."""
    client = MagicMock()
    bars: list[dict[str, object]] = [
        {
            "ticker": "AAPL",
            "open_price": 150.0,
            "high_price": 155.0,
            "low_price": 149.0,
            "close_price": 153.0,
            "volume": 1000000,
            "vwap": 152.0,
            "trade_count": 5000,
            "bar_timestamp": NOW,
            "bar_interval": "1d",
            "source": "test",
            "dt": date(2026, 4, 11),
        },
        {
            "ticker": "MSFT",
            "open_price": 300.0,
            "high_price": 310.0,
            "low_price": 298.0,
            "close_price": 305.0,
            "volume": 800000,
            "vwap": 304.0,
            "trade_count": 4000,
            "bar_timestamp": NOW,
            "bar_interval": "1d",
            "source": "test",
            "dt": date(2026, 4, 11),
        },
    ]
    ref = publish_market_bars_batch(client, bars, NOW)
    assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/market_bars/")
    assert client.put_object.call_count == 1

    table = _read_uploaded_table(client)
    assert table.num_rows == 2


def test_publish_batch_empty_returns_empty():
    """An empty market-bars batch uploads nothing and returns an empty reference."""
    client = MagicMock()
    ref = publish_market_bars_batch(client, [], NOW)
    assert ref == ""
    assert client.put_object.call_count == 0


def test_publish_trade_signals_batch():
    """A batch of signal rows is uploaded once, and the payload holds those rows."""
    client = MagicMock()
    rec = _make_rec()
    rows = [build_trade_signal_row(rec, "bullish", 0.68)]
    ref = publish_trade_signals_batch(client, rows, NOW)
    assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_signals/")
    assert client.put_object.call_count == 1

    # Check the uploaded payload, not just that an upload happened. Otherwise an
    # empty or malformed Parquet object would still pass this test.
    table = _read_uploaded_table(client)
    assert table.num_rows == 1
    assert table.column("ticker")[0].as_py() == "AAPL"
    assert table.column("action")[0].as_py() == "buy"


def test_publish_model_performance_batch():
    """A model-performance batch is written under the explicit model_version partition."""
    client = MagicMock()
    rows = [
        build_model_performance_row(
            document_id="doc-001",
            model_name="gpt-oss:20b",
            success=True,
            total_duration_ms=1500,
            recorded_at=NOW,
        ),
        build_model_performance_row(
            document_id="doc-002",
            model_name="gpt-oss:20b",
            success=False,
            total_duration_ms=3000,
            recorded_at=NOW,
        ),
    ]
    ref = publish_model_performance_batch(client, rows, NOW, model_version="v2")
    assert ref.startswith(f"s3://{LAKEHOUSE_BUCKET}/warehouse/model_performance/")
    assert "model_version=v2" in ref
    assert client.put_object.call_count == 1

    table = _read_uploaded_table(client)
    assert table.num_rows == 2