"""Tests for Iceberg table creation and metadata management.""" from datetime import date import pyarrow as pa from services.lake_publisher.iceberg import ( ICEBERG_CATALOG, ICEBERG_SCHEMA, TABLE_SCHEMAS, IcebergManager, IcebergTableDef, _arrow_type_to_trino, get_all_table_defs, get_table_def, ) from services.lake_publisher.partitions import TABLE_PARTITIONS, PartitionSpec # --------------------------------------------------------------------------- # _arrow_type_to_trino # --------------------------------------------------------------------------- def test_arrow_to_trino_string(): assert _arrow_type_to_trino(pa.string()) == "VARCHAR" def test_arrow_to_trino_float64(): assert _arrow_type_to_trino(pa.float64()) == "DOUBLE" def test_arrow_to_trino_int64(): assert _arrow_type_to_trino(pa.int64()) == "BIGINT" def test_arrow_to_trino_int32(): assert _arrow_type_to_trino(pa.int32()) == "INTEGER" def test_arrow_to_trino_bool(): assert _arrow_type_to_trino(pa.bool_()) == "BOOLEAN" def test_arrow_to_trino_date32(): assert _arrow_type_to_trino(pa.date32()) == "DATE" def test_arrow_to_trino_timestamp_utc(): assert _arrow_type_to_trino(pa.timestamp("us", tz="UTC")) == "TIMESTAMP(6) WITH TIME ZONE" def test_arrow_to_trino_timestamp_no_tz(): assert _arrow_type_to_trino(pa.timestamp("us")) == "TIMESTAMP(6)" # --------------------------------------------------------------------------- # TABLE_SCHEMAS registry # --------------------------------------------------------------------------- def test_table_schemas_covers_all_partitions(): """Every table in TABLE_PARTITIONS should have a corresponding PyArrow schema.""" for table_name in TABLE_PARTITIONS: assert table_name in TABLE_SCHEMAS, f"Missing schema for {table_name}" # --------------------------------------------------------------------------- # IcebergTableDef # --------------------------------------------------------------------------- def test_table_def_qualified_name(): td = get_table_def("trade_signals") assert td.qualified_name == f"{ICEBERG_CATALOG}.{ICEBERG_SCHEMA}.trade_signals" def test_table_def_location(): td = get_table_def("trade_signals") assert td.location == "s3a://stonks-lakehouse/warehouse/trade_signals/" def test_table_def_column_defs(): td = get_table_def("trade_signals") cols = td.column_defs_sql() col_names = [c.strip().split()[0] for c in cols] assert "signal_id" in col_names assert "ticker" in col_names assert "dt" in col_names def test_table_def_partition_keys_dt_only(): td = get_table_def("trade_signals") part = td.partition_keys_sql() assert "'dt'" in part assert "model_version" not in part def test_table_def_partition_keys_with_extra(): td = get_table_def("document_extractions") part = td.partition_keys_sql() assert "'dt'" in part assert "'model_version'" in part def test_create_table_sql_structure(): td = get_table_def("market_bars") sql = td.create_table_sql() assert "CREATE TABLE IF NOT EXISTS" in sql assert f"{ICEBERG_CATALOG}.{ICEBERG_SCHEMA}.market_bars" in sql assert "format = 'PARQUET'" in sql assert "s3a://stonks-lakehouse/warehouse/market_bars/" in sql assert "partitioning" in sql def test_create_table_sql_columns_match_schema(): td = get_table_def("market_bars") sql = td.create_table_sql() # All columns from the PyArrow schema should appear for i in range(len(td.schema)): col_name = td.schema.field(i).name assert col_name in sql, f"Column {col_name} missing from DDL" # --------------------------------------------------------------------------- # get_all_table_defs / get_table_def # --------------------------------------------------------------------------- def test_get_all_table_defs_count(): defs = get_all_table_defs() assert len(defs) == len(TABLE_PARTITIONS) def test_get_table_def_unknown_raises(): try: get_table_def("nonexistent_table") assert False, "Should have raised ValueError" except ValueError: pass def test_get_all_table_defs_all_generate_valid_sql(): """Every table def should produce syntactically reasonable DDL.""" for td in get_all_table_defs(): sql = td.create_table_sql() assert "CREATE TABLE IF NOT EXISTS" in sql assert td.table_name in sql assert "PARQUET" in sql # --------------------------------------------------------------------------- # IcebergManager (unit tests with no Trino connection) # --------------------------------------------------------------------------- def test_iceberg_manager_defaults(): mgr = IcebergManager() assert mgr.host == "localhost" assert mgr.port == 8080 assert mgr.catalog == ICEBERG_CATALOG assert mgr.schema == ICEBERG_SCHEMA