162 lines
4.8 KiB
Python
162 lines
4.8 KiB
Python
"""Tests for Iceberg table creation and metadata management."""
|
|
from datetime import date
|
|
|
|
import pyarrow as pa
|
|
|
|
from services.lake_publisher.iceberg import (
|
|
ICEBERG_CATALOG,
|
|
ICEBERG_SCHEMA,
|
|
TABLE_SCHEMAS,
|
|
IcebergManager,
|
|
IcebergTableDef,
|
|
_arrow_type_to_trino,
|
|
get_all_table_defs,
|
|
get_table_def,
|
|
)
|
|
from services.lake_publisher.partitions import TABLE_PARTITIONS, PartitionSpec
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _arrow_type_to_trino
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_arrow_to_trino_string():
    """A PyArrow string type should translate to Trino VARCHAR."""
    trino_type = _arrow_type_to_trino(pa.string())
    assert trino_type == "VARCHAR"
|
|
|
|
|
|
def test_arrow_to_trino_float64():
    """A PyArrow float64 type should translate to Trino DOUBLE."""
    trino_type = _arrow_type_to_trino(pa.float64())
    assert trino_type == "DOUBLE"
|
|
|
|
|
|
def test_arrow_to_trino_int64():
    """A PyArrow int64 type should translate to Trino BIGINT."""
    trino_type = _arrow_type_to_trino(pa.int64())
    assert trino_type == "BIGINT"
|
|
|
|
|
|
def test_arrow_to_trino_int32():
    """A PyArrow int32 type should translate to Trino INTEGER."""
    trino_type = _arrow_type_to_trino(pa.int32())
    assert trino_type == "INTEGER"
|
|
|
|
|
|
def test_arrow_to_trino_bool():
    """A PyArrow boolean type should translate to Trino BOOLEAN."""
    trino_type = _arrow_type_to_trino(pa.bool_())
    assert trino_type == "BOOLEAN"
|
|
|
|
|
|
def test_arrow_to_trino_date32():
    """A PyArrow date32 type should translate to Trino DATE."""
    trino_type = _arrow_type_to_trino(pa.date32())
    assert trino_type == "DATE"
|
|
|
|
|
|
def test_arrow_to_trino_timestamp_utc():
    """A microsecond UTC timestamp should translate to a zoned Trino timestamp."""
    utc_micros = pa.timestamp("us", tz="UTC")
    trino_type = _arrow_type_to_trino(utc_micros)
    assert trino_type == "TIMESTAMP(6) WITH TIME ZONE"
|
|
|
|
|
|
def test_arrow_to_trino_timestamp_no_tz():
    """A microsecond timestamp without a zone should translate to TIMESTAMP(6)."""
    naive_micros = pa.timestamp("us")
    trino_type = _arrow_type_to_trino(naive_micros)
    assert trino_type == "TIMESTAMP(6)"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# TABLE_SCHEMAS registry
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_table_schemas_covers_all_partitions():
    """Each table named in TABLE_PARTITIONS must also be a key of TABLE_SCHEMAS."""
    for table_name in TABLE_PARTITIONS:
        has_schema = table_name in TABLE_SCHEMAS
        assert has_schema, f"Missing schema for {table_name}"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# IcebergTableDef
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_table_def_qualified_name():
    """The qualified name should be ``<catalog>.<schema>.<table>``."""
    table_def = get_table_def("trade_signals")
    expected_name = f"{ICEBERG_CATALOG}.{ICEBERG_SCHEMA}.trade_signals"
    assert table_def.qualified_name == expected_name
|
|
|
|
|
|
def test_table_def_location():
    """The trade_signals table should live under the warehouse prefix on S3."""
    expected_location = "s3a://stonks-lakehouse/warehouse/trade_signals/"
    table_def = get_table_def("trade_signals")
    assert table_def.location == expected_location
|
|
|
|
|
|
def test_table_def_column_defs():
    """The first token of each column definition should be the column name."""
    table_def = get_table_def("trade_signals")
    # Each entry is expected to look like "<name> <type>"; keep the name only.
    declared_names = [
        column_sql.strip().split()[0] for column_sql in table_def.column_defs_sql()
    ]
    for expected_column in ("signal_id", "ticker", "dt"):
        assert expected_column in declared_names
|
|
|
|
|
|
def test_table_def_partition_keys_dt_only():
    """trade_signals should be partitioned by ``dt`` and not by model_version."""
    partition_sql = get_table_def("trade_signals").partition_keys_sql()
    assert "'dt'" in partition_sql
    assert "model_version" not in partition_sql
|
|
|
|
|
|
def test_table_def_partition_keys_with_extra():
    """document_extractions should list both ``dt`` and ``model_version`` keys."""
    partition_sql = get_table_def("document_extractions").partition_keys_sql()
    for quoted_key in ("'dt'", "'model_version'"):
        assert quoted_key in partition_sql
|
|
|
|
|
|
def test_create_table_sql_structure():
    """The market_bars DDL should contain each expected structural fragment."""
    ddl = get_table_def("market_bars").create_table_sql()
    expected_fragments = (
        "CREATE TABLE IF NOT EXISTS",
        f"{ICEBERG_CATALOG}.{ICEBERG_SCHEMA}.market_bars",
        "format = 'PARQUET'",
        "s3a://stonks-lakehouse/warehouse/market_bars/",
        "partitioning",
    )
    for fragment in expected_fragments:
        assert fragment in ddl
|
|
|
|
|
|
def test_create_table_sql_columns_match_schema():
    """Every column name in the market_bars schema should appear in its DDL."""
    td = get_table_def("market_bars")
    sql = td.create_table_sql()
    # Walk the schema's field names directly instead of indexing by position.
    # NOTE(review): assumes td.schema is a pyarrow.Schema (it exposes `.names`).
    for col_name in td.schema.names:
        assert col_name in sql, f"Column {col_name} missing from DDL"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# get_all_table_defs / get_table_def
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_get_all_table_defs_count():
    """get_all_table_defs() should return one entry per TABLE_PARTITIONS entry."""
    expected_count = len(TABLE_PARTITIONS)
    assert len(get_all_table_defs()) == expected_count
|
|
|
|
|
|
def test_get_table_def_unknown_raises():
    """Requesting an unknown table name should raise ValueError."""
    try:
        get_table_def("nonexistent_table")
    except ValueError:
        return  # expected outcome: the lookup was rejected
    # Raise explicitly rather than `assert False`: assert statements are
    # stripped under `python -O`, which would let this test pass silently.
    raise AssertionError("Should have raised ValueError")
|
|
|
|
|
|
def test_get_all_table_defs_all_generate_valid_sql():
    """The DDL of every table def should name the table and mention PARQUET."""
    for table_def in get_all_table_defs():
        ddl = table_def.create_table_sql()
        assert "CREATE TABLE IF NOT EXISTS" in ddl
        for fragment in (table_def.table_name, "PARQUET"):
            assert fragment in ddl
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# IcebergManager (unit tests with no Trino connection)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_iceberg_manager_defaults():
    """A manager built with no arguments should expose the default settings."""
    manager = IcebergManager()
    expected_defaults = {
        "host": "localhost",
        "port": 8080,
        "catalog": ICEBERG_CATALOG,
        "schema": ICEBERG_SCHEMA,
    }
    # Check attributes one at a time, in the same order as listed above.
    for attribute, expected_value in expected_defaults.items():
        assert getattr(manager, attribute) == expected_value
|