597 lines
18 KiB
Python
597 lines
18 KiB
Python
"""Tests for lake publisher worker — writing prediction facts as Parquet to MinIO."""
|
|
from datetime import date, datetime, timezone
|
|
from unittest.mock import MagicMock
|
|
|
|
import pyarrow.parquet as pq
|
|
|
|
from services.lake_publisher.partitions import (
|
|
LAKEHOUSE_BUCKET,
|
|
TABLE_PARTITIONS,
|
|
PartitionSpec,
|
|
partition_path,
|
|
partition_values,
|
|
s3_uri,
|
|
)
|
|
from services.lake_publisher.worker import (
|
|
_parse_horizon_days,
|
|
_partition_path,
|
|
build_trade_signal_row,
|
|
publish_trade_signal,
|
|
publish_prediction_fact,
|
|
publish_recommendation_facts,
|
|
build_trade_order_row,
|
|
publish_trade_order,
|
|
build_trade_fill_row,
|
|
publish_trade_fill,
|
|
build_position_daily_row,
|
|
publish_position_daily,
|
|
publish_positions_daily_batch,
|
|
build_model_performance_row,
|
|
publish_model_performance,
|
|
publish_market_bars_batch,
|
|
publish_trade_signals_batch,
|
|
publish_model_performance_batch,
|
|
)
|
|
from services.shared.schemas import (
|
|
ActionType,
|
|
ModelMetadata,
|
|
PositionSizing,
|
|
Recommendation,
|
|
RecommendationMode,
|
|
)
|
|
|
|
# Fixed, timezone-aware reference instant (2026-04-11 14:30:00 UTC) used by the tests in this module.
NOW = datetime(
    year=2026,
    month=4,
    day=11,
    hour=14,
    minute=30,
    second=0,
    tzinfo=timezone.utc,
)
|
|
|
|
|
|
def _make_rec(
    ticker: str = "AAPL",
    action: ActionType = ActionType.BUY,
    confidence: float = 0.72,
    time_horizon: str = "swing_1d_10d",
    rec_id: str = "rec-001",
) -> Recommendation:
    """Build a ``Recommendation`` fixture for the tests in this module.

    Only the ticker, action, confidence, time horizon and recommendation id
    can be overridden; every other field is a fixed value and
    ``generated_at`` is always ``NOW``.
    """
    sizing = PositionSizing(portfolio_pct=0.02, max_loss_pct=0.005)
    metadata = ModelMetadata(provider="deterministic", model_name="eligibility-v1")
    fields = {
        "recommendation_id": rec_id,
        "ticker": ticker,
        "action": action,
        "mode": RecommendationMode.PAPER_ELIGIBLE,
        "confidence": confidence,
        "time_horizon": time_horizon,
        "thesis": "[risk:low] Test thesis",
        "invalidation_conditions": ["price drops below support"],
        "position_sizing": sizing,
        "evidence_refs": ["doc1", "doc2"],
        "model_metadata": metadata,
        "generated_at": NOW,
    }
    return Recommendation(**fields)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _parse_horizon_days
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_parse_horizon_days_swing():
    """``"swing_1d_10d"`` is expected to parse to 10 days."""
    horizon_days = _parse_horizon_days("swing_1d_10d")
    assert horizon_days == 10
|
|
|
|
|
|
def test_parse_horizon_days_position():
    """``"position_10d_30d"`` is expected to parse to 30 days."""
    horizon_days = _parse_horizon_days("position_10d_30d")
    assert horizon_days == 30
|
|
|
|
|
|
def test_parse_horizon_days_intraday():
    """``"scalp_intraday"`` is expected to parse to a single day."""
    horizon_days = _parse_horizon_days("scalp_intraday")
    assert horizon_days == 1
|
|
|
|
|
|
def test_parse_horizon_days_empty():
    """An empty horizon label is expected to parse to 0."""
    horizon_days = _parse_horizon_days("")
    assert horizon_days == 0
|
|
|
|
|
|
def test_parse_horizon_days_no_numbers():
    """The label ``"unknown"`` (no digits) is expected to parse to 0."""
    horizon_days = _parse_horizon_days("unknown")
    assert horizon_days == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Partition module tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_partition_spec_all_keys():
    """With one extra key, ``all_keys`` is ``("dt", "model_version")``."""
    expected_keys = ("dt", "model_version")
    partition_spec = PartitionSpec("test_table", extra_keys=("model_version",))
    assert partition_spec.all_keys == expected_keys
|
|
|
|
|
|
def test_partition_spec_dt_only():
    """Without extra keys, ``all_keys`` contains only ``"dt"``."""
    partition_spec = PartitionSpec("simple")
    expected_keys = ("dt",)
    assert partition_spec.all_keys == expected_keys
|
|
|
|
|
|
def test_table_partitions_registry():
    """Spot-check registered table names and their extra partition keys."""
    # Membership checks come first, as a missing table should fail here.
    for table_name in ("market_bars", "model_performance"):
        assert table_name in TABLE_PARTITIONS

    expected_extra_keys = {
        "model_performance": ("model_version",),
        "prediction_vs_outcome": ("model_version",),
        "document_extractions": ("model_version",),
        "trade_signals": (),
    }
    for table_name, extra_keys in expected_extra_keys.items():
        assert TABLE_PARTITIONS[table_name].extra_keys == extra_keys
|
|
|
|
|
|
def test_partition_values_dt_only():
    """Called with only a timestamp, the result holds just the ``dt`` date."""
    expected = {"dt": date(2026, 4, 11)}
    assert partition_values(NOW) == expected
|
|
|
|
|
|
def test_partition_values_with_extras():
    """Extra partition values appear alongside the ``dt`` date."""
    extras = {"model_version": "v2"}
    expected = {"dt": date(2026, 4, 11), "model_version": "v2"}
    assert partition_values(NOW, extras) == expected
|
|
|
|
|
|
def test_s3_uri():
    """``s3_uri`` turns an object key into an ``s3://stonks-lakehouse/`` URI."""
    object_key = "warehouse/t/dt=2026-04-11/part-abc.parquet"
    assert s3_uri(object_key) == (
        "s3://stonks-lakehouse/warehouse/t/dt=2026-04-11/part-abc.parquet"
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _partition_path (via partitions module)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_partition_path_basic():
    """The path starts at the table's ``dt=`` directory and ends in ``.parquet``."""
    object_path = partition_path("trade_signals", NOW)
    expected_prefix = "warehouse/trade_signals/dt=2026-04-11/"
    assert object_path.startswith(expected_prefix)
    assert object_path.endswith(".parquet")
|
|
|
|
|
|
def test_partition_path_with_extra_partitions():
    """An extra partition value shows up as ``key=value`` in the path."""
    extras = {"model_version": "v1"}
    object_path = partition_path("model_performance", NOW, extras)
    assert "model_version=v1" in object_path
|
|
|
|
|
|
def test_partition_path_custom_file_id():
    """A caller-supplied ``file_id`` is embedded in the ``part-*.parquet`` name."""
    object_path = partition_path("trade_signals", NOW, file_id="custom123")
    expected_file_name = "part-custom123.parquet"
    assert expected_file_name in object_path
|
|
|
|
|
|
def test_partition_path_legacy_wrapper():
    """The ``_partition_path`` wrapper imported from worker.py still yields the expected prefix."""
    expected_prefix = "warehouse/trade_signals/dt=2026-04-11/"
    legacy_path = _partition_path("trade_signals", NOW)
    assert legacy_path.startswith(expected_prefix)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# build_trade_signal_row
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_build_trade_signal_row():
    """The signal row carries the recommendation's fields plus the trend inputs."""
    recommendation = _make_rec()
    row = build_trade_signal_row(
        recommendation,
        trend_direction="bullish",
        trend_strength=0.68,
    )

    expected_fields = {
        "signal_id": "rec-001",
        "ticker": "AAPL",
        "trend_direction": "bullish",
        "trend_strength": 0.68,
        "confidence": 0.72,
        "action": "buy",
        "time_horizon": "swing_1d_10d",
        "generated_at": NOW,
        "dt": date(2026, 4, 11),
    }
    for field_name, expected_value in expected_fields.items():
        assert row[field_name] == expected_value
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# publish_trade_signal
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_publish_trade_signal_writes_parquet():
    """Publishing one signal issues a single ``put_object`` with a readable Parquet payload."""
    client = MagicMock()
    recommendation = _make_rec()

    uri = publish_trade_signal(
        client,
        recommendation,
        trend_direction="bullish",
        trend_strength=0.68,
    )

    expected_prefix = f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_signals/"
    assert uri.startswith(expected_prefix)
    assert client.put_object.call_count == 1

    # Verify the written bytes are valid Parquet by reading back the uploaded buffer.
    put_args = client.put_object.call_args[0]
    assert put_args[0] == LAKEHOUSE_BUCKET
    payload = put_args[2]
    payload.seek(0)
    table = pq.read_table(payload)

    assert table.num_rows == 1
    first_row = {
        column_name: table.column(column_name)[0].as_py()
        for column_name in ("ticker", "action")
    }
    assert first_row == {"ticker": "AAPL", "action": "buy"}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# publish_prediction_fact
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_publish_prediction_fact_writes_parquet():
    """A prediction fact is uploaded once, under a ``model_version=`` partition."""
    client = MagicMock()
    recommendation = _make_rec()

    uri = publish_prediction_fact(client, recommendation)

    expected_prefix = f"s3://{LAKEHOUSE_BUCKET}/warehouse/prediction_vs_outcome/"
    assert uri.startswith(expected_prefix)
    assert "model_version=" in uri
    assert client.put_object.call_count == 1

    # Read the uploaded buffer back to inspect the single written row.
    payload = client.put_object.call_args[0][2]
    payload.seek(0)
    table = pq.read_table(payload)

    assert table.num_rows == 1
    expected_row = {
        "predicted_action": "buy",
        "outcome": "pending",
        "horizon_days": 10,
        "dt": date(2026, 4, 11),
    }
    for column_name, expected_value in expected_row.items():
        assert table.column(column_name)[0].as_py() == expected_value
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# publish_recommendation_facts
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_publish_recommendation_facts_writes_both_tables():
    """Both table names are present in the result and two uploads are made."""
    client = MagicMock()
    recommendation = _make_rec()

    refs = publish_recommendation_facts(client, recommendation, "bullish", 0.68)

    for table_name in ("trade_signals", "prediction_vs_outcome"):
        assert table_name in refs
    assert client.put_object.call_count == 2
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# build_trade_order_row
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_build_trade_order_row():
    """Every order input is echoed into the row, and ``dt`` is the submission date."""
    order_fields = {
        "order_id": "ord-001",
        "ticker": "AAPL",
        "side": "buy",
        "order_type": "market",
        "quantity": 10.0,
        "limit_price": None,
        "status": "filled",
        "broker_account": "acct-001",
        "submitted_at": NOW,
    }

    row = build_trade_order_row(**order_fields)

    for field_name, expected_value in order_fields.items():
        if field_name == "limit_price":
            # Checked separately below with an identity comparison.
            continue
        assert row[field_name] == expected_value
    assert row["limit_price"] is None
    assert row["dt"] == date(2026, 4, 11)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# publish_trade_order
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_publish_trade_order_writes_parquet():
    """Publishing an order uploads one single-row Parquet object to the lakehouse bucket."""
    client = MagicMock()
    order_fields = {
        "order_id": "ord-001",
        "ticker": "AAPL",
        "side": "buy",
        "order_type": "market",
        "quantity": 10.0,
        "limit_price": None,
        "status": "filled",
        "broker_account": "acct-001",
        "submitted_at": NOW,
    }

    uri = publish_trade_order(client, **order_fields)

    expected_prefix = f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_orders/"
    assert uri.startswith(expected_prefix)
    assert client.put_object.call_count == 1

    # Read the uploaded buffer back to inspect the written row.
    put_args = client.put_object.call_args[0]
    assert put_args[0] == LAKEHOUSE_BUCKET
    payload = put_args[2]
    payload.seek(0)
    table = pq.read_table(payload)

    assert table.num_rows == 1
    expected_row = {"ticker": "AAPL", "side": "buy", "status": "filled"}
    for column_name, expected_value in expected_row.items():
        assert table.column(column_name)[0].as_py() == expected_value
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# build_trade_fill_row
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_build_trade_fill_row():
    """The fill row exposes the fill id, order id, price and quantity it was given."""
    fill_fields = {
        "fill_id": "fill-001",
        "order_id": "ord-001",
        "ticker": "AAPL",
        "side": "buy",
        "fill_price": 150.25,
        "fill_quantity": 10.0,
        "broker_account": "acct-001",
        "filled_at": NOW,
    }

    row = build_trade_fill_row(**fill_fields)

    # Only these four fields are checked by this test.
    for field_name in ("fill_id", "order_id", "fill_price", "fill_quantity"):
        assert row[field_name] == fill_fields[field_name]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# publish_trade_fill
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_publish_trade_fill_writes_parquet():
    """Publishing a fill uploads one single-row Parquet object under ``trade_fills``."""
    client = MagicMock()
    fill_fields = {
        "fill_id": "fill-001",
        "order_id": "ord-001",
        "ticker": "AAPL",
        "side": "buy",
        "fill_price": 150.25,
        "fill_quantity": 10.0,
        "broker_account": "acct-001",
        "filled_at": NOW,
    }

    uri = publish_trade_fill(client, **fill_fields)

    expected_prefix = f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_fills/"
    assert uri.startswith(expected_prefix)
    assert client.put_object.call_count == 1

    # Read the uploaded buffer back to inspect the written row.
    payload = client.put_object.call_args[0][2]
    payload.seek(0)
    table = pq.read_table(payload)

    assert table.num_rows == 1
    expected_row = {"fill_price": 150.25, "ticker": "AAPL"}
    for column_name, expected_value in expected_row.items():
        assert table.column(column_name)[0].as_py() == expected_value
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# build_position_daily_row
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_build_position_daily_row():
    """The position row exposes the ticker, quantity and unrealized PnL it was given."""
    position_fields = {
        "ticker": "AAPL",
        "quantity": 100.0,
        "avg_entry_price": 145.00,
        "close_price": 150.00,
        "unrealized_pnl": 500.0,
        "broker_account": "acct-001",
        "snapshot_at": NOW,
    }

    row = build_position_daily_row(**position_fields)

    # Only these three fields are checked by this test.
    for field_name in ("ticker", "quantity", "unrealized_pnl"):
        assert row[field_name] == position_fields[field_name]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# publish_position_daily
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_publish_position_daily_writes_parquet():
    """Publishing one position uploads a single-row Parquet object under ``positions_daily``."""
    client = MagicMock()
    position_fields = {
        "ticker": "AAPL",
        "quantity": 100.0,
        "avg_entry_price": 145.00,
        "close_price": 150.00,
        "unrealized_pnl": 500.0,
        "broker_account": "acct-001",
        "snapshot_at": NOW,
    }

    uri = publish_position_daily(client, **position_fields)

    expected_prefix = f"s3://{LAKEHOUSE_BUCKET}/warehouse/positions_daily/"
    assert uri.startswith(expected_prefix)
    assert client.put_object.call_count == 1

    # Read the uploaded buffer back to inspect the written row.
    payload = client.put_object.call_args[0][2]
    payload.seek(0)
    table = pq.read_table(payload)

    assert table.num_rows == 1
    expected_row = {"ticker": "AAPL", "close_price": 150.00}
    for column_name, expected_value in expected_row.items():
        assert table.column(column_name)[0].as_py() == expected_value
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# publish_positions_daily_batch
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_publish_positions_daily_batch_writes_parquet():
    """Two positions are written as one Parquet object containing two rows."""
    client = MagicMock()

    # (ticker, quantity, average entry price, close price)
    holdings = (
        ("AAPL", 100.0, 145.0, 150.0),
        ("MSFT", 50.0, 300.0, 310.0),
    )
    positions = [
        {
            "ticker": ticker,
            "quantity": quantity,
            "avg_entry_price": entry_price,
            "close_price": close_price,
            "unrealized_pnl": 500.0,
        }
        for ticker, quantity, entry_price, close_price in holdings
    ]

    uri = publish_positions_daily_batch(client, positions, "acct-001", NOW)

    expected_prefix = f"s3://{LAKEHOUSE_BUCKET}/warehouse/positions_daily/"
    assert uri.startswith(expected_prefix)
    assert client.put_object.call_count == 1

    # Read the uploaded buffer back to count the written rows.
    payload = client.put_object.call_args[0][2]
    payload.seek(0)
    assert pq.read_table(payload).num_rows == 2
|
|
|
|
|
|
def test_publish_positions_daily_batch_empty():
    """An empty position list returns ``""`` and triggers no upload."""
    client = MagicMock()
    no_positions = []

    uri = publish_positions_daily_batch(client, no_positions, "acct-001", NOW)

    assert uri == ""
    assert client.put_object.call_count == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# build_model_performance_row
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_build_model_performance_row():
    """The row echoes selected inputs; ``model_version`` equals the schema version passed in."""
    inputs = {
        "document_id": "doc-001",
        "model_name": "gpt-oss:20b",
        "success": True,
        "total_duration_ms": 1500,
        "recorded_at": NOW,
        "ticker": "AAPL",
        "prompt_version": "document-intel-v2",
        "schema_version": "2.0.0",
        "attempt_count": 2,
        "confidence": 0.86,
        "validation_status": "valid",
        "retry_count": 1,
        "input_token_estimate": 500,
        "output_token_estimate": 200,
        "company_count": 3,
    }

    row = build_model_performance_row(**inputs)

    assert row["document_id"] == "doc-001"
    assert row["model_name"] == "gpt-oss:20b"
    # Identity check: the flag must be the boolean True, not merely truthy.
    assert row["success"] is True
    expected_fields = {
        "total_duration_ms": 1500,
        "attempt_count": 2,
        "confidence": 0.86,
        "company_count": 3,
        "dt": date(2026, 4, 11),
        "model_version": "2.0.0",
    }
    for field_name, expected_value in expected_fields.items():
        assert row[field_name] == expected_value
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# publish_model_performance
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_publish_model_performance_writes_parquet():
    """One record is uploaded under a ``model_version=2.0.0`` partition as a single row."""
    client = MagicMock()
    inputs = {
        "document_id": "doc-001",
        "model_name": "gpt-oss:20b",
        "success": True,
        "total_duration_ms": 1500,
        "recorded_at": NOW,
        "ticker": "AAPL",
        "prompt_version": "document-intel-v2",
        "schema_version": "2.0.0",
        "confidence": 0.86,
        "validation_status": "valid",
    }

    uri = publish_model_performance(client, **inputs)

    expected_prefix = f"s3://{LAKEHOUSE_BUCKET}/warehouse/model_performance/"
    assert uri.startswith(expected_prefix)
    assert "model_version=2.0.0" in uri
    assert client.put_object.call_count == 1

    # Read the uploaded buffer back to inspect the written row.
    payload = client.put_object.call_args[0][2]
    payload.seek(0)
    table = pq.read_table(payload)

    assert table.num_rows == 1
    assert table.column("model_name")[0].as_py() == "gpt-oss:20b"
    # Identity check: the stored flag must round-trip as the boolean True.
    assert table.column("success")[0].as_py() is True
    assert table.column("confidence")[0].as_py() == 0.86
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Batch publish helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_publish_market_bars_batch():
    """Two market bars are written as one Parquet object containing two rows."""
    client = MagicMock()

    # (ticker, open, high, low, close, volume, vwap, trade count)
    quotes = (
        ("AAPL", 150.0, 155.0, 149.0, 153.0, 1000000, 152.0, 5000),
        ("MSFT", 300.0, 310.0, 298.0, 305.0, 800000, 304.0, 4000),
    )
    bars: list[dict[str, object]] = [
        {
            "ticker": ticker,
            "open_price": open_price,
            "high_price": high_price,
            "low_price": low_price,
            "close_price": close_price,
            "volume": volume,
            "vwap": vwap,
            "trade_count": trade_count,
            "bar_timestamp": NOW,
            "bar_interval": "1d",
            "source": "test",
            "dt": date(2026, 4, 11),
        }
        for (
            ticker,
            open_price,
            high_price,
            low_price,
            close_price,
            volume,
            vwap,
            trade_count,
        ) in quotes
    ]

    uri = publish_market_bars_batch(client, bars, NOW)

    expected_prefix = f"s3://{LAKEHOUSE_BUCKET}/warehouse/market_bars/"
    assert uri.startswith(expected_prefix)
    assert client.put_object.call_count == 1

    # Read the uploaded buffer back to count the written rows.
    payload = client.put_object.call_args[0][2]
    payload.seek(0)
    assert pq.read_table(payload).num_rows == 2
|
|
|
|
|
|
def test_publish_batch_empty_returns_empty():
    """An empty bar list returns ``""`` and triggers no upload."""
    client = MagicMock()
    no_bars = []

    uri = publish_market_bars_batch(client, no_bars, NOW)

    assert uri == ""
    assert client.put_object.call_count == 0
|
|
|
|
|
|
def test_publish_trade_signals_batch():
    """A one-row signal batch is uploaded once under ``trade_signals``."""
    client = MagicMock()
    recommendation = _make_rec()
    signal_row = build_trade_signal_row(recommendation, "bullish", 0.68)

    uri = publish_trade_signals_batch(client, [signal_row], NOW)

    expected_prefix = f"s3://{LAKEHOUSE_BUCKET}/warehouse/trade_signals/"
    assert uri.startswith(expected_prefix)
    assert client.put_object.call_count == 1
|
|
|
|
|
|
def test_publish_model_performance_batch():
    """Two performance rows are written as one object under ``model_version=v2``."""
    client = MagicMock()

    # (document id, success flag, total duration in ms)
    outcomes = (
        ("doc-001", True, 1500),
        ("doc-002", False, 3000),
    )
    rows = [
        build_model_performance_row(
            document_id=document_id,
            model_name="gpt-oss:20b",
            success=succeeded,
            total_duration_ms=duration_ms,
            recorded_at=NOW,
        )
        for document_id, succeeded, duration_ms in outcomes
    ]

    uri = publish_model_performance_batch(client, rows, NOW, model_version="v2")

    expected_prefix = f"s3://{LAKEHOUSE_BUCKET}/warehouse/model_performance/"
    assert uri.startswith(expected_prefix)
    assert "model_version=v2" in uri
    assert client.put_object.call_count == 1

    # Read the uploaded buffer back to count the written rows.
    payload = client.put_object.call_args[0][2]
    payload.seek(0)
    assert pq.read_table(payload).num_rows == 2
|