Files
stonks-oracle/services/lake_publisher/iceberg.py
T

421 lines
15 KiB
Python

"""Iceberg table creation and metadata management for analytical datasets.
Manages Iceberg tables in Trino's Iceberg catalog, providing:
- Table creation with proper schemas and partition specs
- Schema synchronization between PyArrow definitions and Iceberg tables
- Table metadata inspection (existence checks, schema retrieval, partition listing)
The Iceberg catalog complements the existing Hive-compatible partition layout.
Parquet files written by the lake publisher are stored in the same MinIO paths,
but Iceberg metadata enables schema evolution, snapshot isolation, and better
partition pruning via Trino's Iceberg connector.
Requirements: 9.4, 9.5, 10.1, N4, N6
Design ref: Section 5.3 (Lakehouse model), Section 4.12 (SQL Query Engine)
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
import pyarrow as pa
from trino.dbapi import connect as trino_connect
from services.lake_publisher.partitions import (
LAKEHOUSE_BUCKET,
TABLE_PARTITIONS,
WAREHOUSE_PREFIX,
PartitionSpec,
)
from services.lake_publisher.worker import (
COMPANY_EVENTS_SCHEMA,
DOCUMENTS_SCHEMA,
DOCUMENT_EXTRACTIONS_SCHEMA,
MARKET_BARS_SCHEMA,
MARKET_QUOTES_SCHEMA,
MODEL_PERFORMANCE_SCHEMA,
PNL_DAILY_SCHEMA,
POSITIONS_DAILY_SCHEMA,
PREDICTION_VS_OUTCOME_SCHEMA,
TRADE_FILLS_SCHEMA,
TRADE_ORDERS_SCHEMA,
TRADE_SIGNALS_SCHEMA,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default Trino catalog and schema names for the Iceberg tables. They are the
# defaults of IcebergManager.catalog / IcebergManager.schema and are used to
# build IcebergTableDef.qualified_name.
ICEBERG_CATALOG = "iceberg"
ICEBERG_SCHEMA = "stonks"
def _get_iceberg_catalog() -> str:
    """Resolve the Iceberg catalog name.

    The ``TRINO_ICEBERG_CATALOG`` environment variable wins when it is set;
    otherwise the module default ``ICEBERG_CATALOG`` is returned.
    """
    import os

    return os.environ.get("TRINO_ICEBERG_CATALOG", ICEBERG_CATALOG)
# Lookup from the string form of a PyArrow type to its Trino/Iceberg SQL type.
# Several Arrow spellings alias the same SQL type, so aliases are grouped.
_ARROW_TO_TRINO: dict[str, str] = {
    # Text
    **dict.fromkeys(("string", "utf8", "large_string", "large_utf8"), "VARCHAR"),
    # Floating point
    **dict.fromkeys(("float64", "double"), "DOUBLE"),
    **dict.fromkeys(("float32", "float"), "REAL"),
    # Signed integers, by width
    "int8": "TINYINT",
    "int16": "SMALLINT",
    "int32": "INTEGER",
    "int64": "BIGINT",
    # Boolean
    "bool": "BOOLEAN",
    # Calendar dates
    **dict.fromkeys(("date32", "date32[day]", "date64"), "DATE"),
}
def _arrow_type_to_trino(arrow_type: pa.DataType) -> str:
    """Convert a PyArrow data type to a Trino SQL type string.

    Resolution order:

    1. Timestamps map to ``TIMESTAMP(6)``, with ``WITH TIME ZONE`` appended
       when the Arrow type carries a timezone.
    2. An exact match of ``str(arrow_type)`` in ``_ARROW_TO_TRINO``.
    3. Type-family fallbacks for anything not listed by name.

    Args:
        arrow_type: The PyArrow type of a schema field.

    Returns:
        The Trino/Iceberg SQL type name to use in DDL.

    Raises:
        ValueError: If the type has no supported Trino equivalent.
    """
    # Inspect the type object instead of parsing its string form. This is the
    # only place timestamps are handled, so naive and tz-aware timestamps
    # cannot be mapped inconsistently by a later fallback.
    if pa.types.is_timestamp(arrow_type):
        if arrow_type.tz is not None:
            return "TIMESTAMP(6) WITH TIME ZONE"
        return "TIMESTAMP(6)"

    # Exact lookup by the type's string form.
    mapped = _ARROW_TO_TRINO.get(str(arrow_type))
    if mapped is not None:
        return mapped

    # Family fallbacks: any string, float, integer, boolean or date type that
    # is not listed by name is widened to the most general SQL type of its
    # family (e.g. every remaining integer type becomes BIGINT).
    if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type):
        return "VARCHAR"
    if pa.types.is_floating(arrow_type):
        return "DOUBLE"
    if pa.types.is_integer(arrow_type):
        return "BIGINT"
    if pa.types.is_boolean(arrow_type):
        return "BOOLEAN"
    if pa.types.is_date(arrow_type):
        return "DATE"

    raise ValueError(f"Unsupported PyArrow type for Iceberg DDL: {arrow_type}")
# Registry mapping table names to their PyArrow schemas.
# get_all_table_defs() and get_table_def() pair each schema with the partition
# spec registered under the same name in TABLE_PARTITIONS. A table that has a
# partition spec but no entry here is skipped (with a warning) by
# get_all_table_defs() and rejected with ValueError by get_table_def().
TABLE_SCHEMAS: dict[str, pa.Schema] = {
    "market_bars": MARKET_BARS_SCHEMA,
    "market_quotes": MARKET_QUOTES_SCHEMA,
    "company_events": COMPANY_EVENTS_SCHEMA,
    "documents": DOCUMENTS_SCHEMA,
    "document_extractions": DOCUMENT_EXTRACTIONS_SCHEMA,
    "trade_signals": TRADE_SIGNALS_SCHEMA,
    "trade_orders": TRADE_ORDERS_SCHEMA,
    "trade_fills": TRADE_FILLS_SCHEMA,
    "positions_daily": POSITIONS_DAILY_SCHEMA,
    "pnl_daily": PNL_DAILY_SCHEMA,
    "prediction_vs_outcome": PREDICTION_VS_OUTCOME_SCHEMA,
    "model_performance": MODEL_PERFORMANCE_SCHEMA,
}
@dataclass(frozen=True)
class IcebergTableDef:
    """Immutable description of one Iceberg table.

    Combines a table name, its PyArrow schema and its partition spec, and
    renders the Trino DDL needed to create the table.
    """

    table_name: str
    schema: pa.Schema
    partition_spec: PartitionSpec

    @property
    def qualified_name(self) -> str:
        """``catalog.schema.table`` built from the module-level defaults."""
        return f"{ICEBERG_CATALOG}.{ICEBERG_SCHEMA}.{self.table_name}"

    @property
    def location(self) -> str:
        """Storage location of the table under the lakehouse bucket."""
        return f"s3a://{LAKEHOUSE_BUCKET}/{WAREHOUSE_PREFIX}/{self.table_name}/"

    def column_defs_sql(self) -> list[str]:
        """Render one ``<name> <type>`` SQL fragment per schema field.

        Partition columns are part of the column list as well (Iceberg keeps
        them in the data files, unlike Hive external tables).
        """
        return [
            f" {field.name} {_arrow_type_to_trino(field.type)}"
            for field in self.schema
        ]

    def partition_keys_sql(self) -> str:
        """Render the ``partitioning`` table property, or ``""`` if unpartitioned."""
        quoted_keys = [f"'{key}'" for key in self.partition_spec.all_keys]
        if quoted_keys:
            key_list = ", ".join(quoted_keys)
            return f"partitioning = ARRAY[{key_list}]"
        return ""

    def create_table_sql(self) -> str:
        """Render the ``CREATE TABLE IF NOT EXISTS`` statement for this table."""
        column_sql = ",\n".join(self.column_defs_sql())

        # Table properties: file format and location always, partitioning
        # only when the spec actually has keys.
        properties = ["format = 'PARQUET'", f"location = '{self.location}'"]
        partitioning = self.partition_keys_sql()
        if partitioning:
            properties.append(partitioning)
        property_sql = ",\n ".join(properties)

        statement_lines = [
            f"CREATE TABLE IF NOT EXISTS {self.qualified_name} (",
            column_sql,
            ") WITH (",
            f" {property_sql}",
            ")",
        ]
        return "\n".join(statement_lines)
def get_all_table_defs() -> list[IcebergTableDef]:
    """Return an IcebergTableDef for each table in ``TABLE_PARTITIONS``.

    Tables that have a partition spec but no entry in ``TABLE_SCHEMAS`` are
    skipped, and a warning is logged for each of them.
    """
    table_defs: list[IcebergTableDef] = []
    for name, spec in TABLE_PARTITIONS.items():
        arrow_schema = TABLE_SCHEMAS.get(name)
        if arrow_schema is not None:
            table_def = IcebergTableDef(
                table_name=name,
                schema=arrow_schema,
                partition_spec=spec,
            )
            table_defs.append(table_def)
        else:
            logger.warning("No PyArrow schema for table %s, skipping", name)
    return table_defs
def get_table_def(table_name: str) -> IcebergTableDef:
    """Look up the IcebergTableDef for one table.

    Args:
        table_name: Name of a table registered in ``TABLE_PARTITIONS``.

    Returns:
        The table definition combining the registered schema and partition spec.

    Raises:
        ValueError: If the table has no partition spec, or has a partition
            spec but no PyArrow schema in ``TABLE_SCHEMAS``.
    """
    is_known = table_name in TABLE_PARTITIONS
    if not is_known:
        raise ValueError(f"Unknown table: {table_name}")

    arrow_schema = TABLE_SCHEMAS.get(table_name)
    if arrow_schema is None:
        raise ValueError(f"No PyArrow schema registered for table: {table_name}")

    spec = TABLE_PARTITIONS[table_name]
    return IcebergTableDef(
        table_name=table_name,
        schema=arrow_schema,
        partition_spec=spec,
    )
@dataclass
class IcebergManager:
    """Manages Iceberg tables via Trino's Iceberg catalog.

    Provides table creation, existence checks, schema inspection,
    and metadata operations against the Trino Iceberg connector.

    Every statement targets the ``catalog`` and ``schema`` configured on the
    instance and runs on its own connection, which is closed afterwards.
    """

    # Connection parameters passed to the Trino DBAPI ``connect`` call.
    host: str = "localhost"
    port: int = 8080
    user: str = "stonks"
    # Catalog and schema that all statements issued by this manager target.
    catalog: str = ICEBERG_CATALOG
    schema: str = ICEBERG_SCHEMA

    def _get_connection(self) -> Any:
        """Create a Trino DBAPI connection."""
        return trino_connect(
            host=self.host,
            port=self.port,
            user=self.user,
            catalog=self.catalog,
            schema=self.schema,
        )

    def _qualified(self, table_name: str) -> str:
        """Return ``catalog.schema.table`` for this manager's configured target."""
        return f"{self.catalog}.{self.schema}.{table_name}"

    def _metadata_table(self, table_name: str, suffix: str) -> str:
        """Return the reference to an Iceberg metadata table such as ``$snapshots``.

        Only the final ``table$suffix`` part is double-quoted. Quoting the
        whole dotted name would make Trino read it as a single identifier.
        """
        return f'{self.catalog}.{self.schema}."{table_name}${suffix}"'

    def _execute(self, sql: str) -> list[list[Any]]:
        """Execute a SQL statement and return all rows."""
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(sql)
            return cursor.fetchall()
        finally:
            conn.close()

    def _execute_no_fetch(self, sql: str) -> None:
        """Execute a DDL statement that returns no rows.

        Errors raised by ``execute`` propagate. An error raised while draining
        the (empty) result is logged instead of raised, so a driver that
        rejects fetching from a row-less statement does not fail the call.
        """
        conn = self._get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(sql)
            # DDL statements in Trino still need fetchall to complete
            try:
                cursor.fetchall()
            except Exception:
                # Previously swallowed silently; keep the best-effort
                # behaviour but leave a trace for diagnosis.
                logger.warning(
                    "Fetching the result of a no-row statement failed: %s",
                    sql,
                    exc_info=True,
                )
        finally:
            conn.close()

    def ensure_schema(self) -> None:
        """Create the Iceberg schema if it doesn't exist."""
        sql = f"CREATE SCHEMA IF NOT EXISTS {self.catalog}.{self.schema}"
        logger.info("Ensuring Iceberg schema: %s.%s", self.catalog, self.schema)
        self._execute_no_fetch(sql)

    def table_exists(self, table_name: str) -> bool:
        """Check if an Iceberg table exists in the configured catalog/schema."""
        sql = (
            f"SELECT table_name FROM {self.catalog}.information_schema.tables "
            f"WHERE table_schema = '{self.schema}' AND table_name = '{table_name}'"
        )
        return bool(self._execute(sql))

    def create_table(self, table_name: str) -> bool:
        """Create a single Iceberg table if it doesn't exist.

        Args:
            table_name: Name of a registered analytical table.

        Returns:
            True if the table was created, False if it already existed.

        Raises:
            ValueError: If ``table_name`` is not a registered table.
        """
        table_def = get_table_def(table_name)
        target = self._qualified(table_name)

        # Record whether the table is already there so the return value
        # matches the documented contract. The DDL below is still issued
        # because it is idempotent (IF NOT EXISTS).
        already_exists = self.table_exists(table_name)

        ddl = table_def.create_table_sql()
        default_name = table_def.qualified_name
        if target != default_name:
            # IcebergTableDef renders its DDL against the module-default
            # catalog/schema. Point the statement at this manager's target so
            # it agrees with table_exists/drop_table/sync. The first
            # occurrence of the name is the one after CREATE TABLE.
            # NOTE(review): the storage location still comes from
            # IcebergTableDef and does not vary with catalog/schema — confirm
            # that is intended.
            ddl = ddl.replace(default_name, target, 1)

        logger.info("Creating Iceberg table: %s", target)
        self._execute_no_fetch(ddl)
        logger.info("Iceberg table ready: %s", target)
        return not already_exists

    def create_all_tables(self) -> dict[str, bool]:
        """Create all registered Iceberg tables.

        Returns a dict mapping table_name -> True when the table is ready
        (newly created or already present) or False when creation raised.
        """
        self.ensure_schema()
        results: dict[str, bool] = {}
        for table_def in get_all_table_defs():
            try:
                self.create_table(table_def.table_name)
                results[table_def.table_name] = True
            except Exception:
                logger.exception("Failed to create Iceberg table: %s", table_def.table_name)
                results[table_def.table_name] = False
        return results

    def get_table_schema(self, table_name: str) -> list[dict[str, str]]:
        """Retrieve the column schema of an Iceberg table from Trino.

        Returns a list of dicts with 'column_name', 'data_type', and
        'is_nullable', ordered by column position. The list is empty when the
        table has no columns registered in ``information_schema``.
        """
        sql = (
            f"SELECT column_name, data_type, is_nullable "
            f"FROM {self.catalog}.information_schema.columns "
            f"WHERE table_schema = '{self.schema}' AND table_name = '{table_name}' "
            f"ORDER BY ordinal_position"
        )
        rows = self._execute(sql)
        return [
            {"column_name": r[0], "data_type": r[1], "is_nullable": r[2]}
            for r in rows
        ]

    def get_table_snapshots(self, table_name: str) -> list[dict[str, Any]]:
        """List Iceberg snapshots for a table (useful for auditing and rollback).

        Returns snapshot metadata from Trino's $snapshots metadata table as
        dicts with 'snapshot_id', 'parent_id', 'operation', 'manifest_list'
        and 'summary'. Returns an empty list if the metadata cannot be read.
        """
        # Select the columns by name: the metadata table starts with
        # committed_at, so positional access on SELECT * would be off by one.
        sql = (
            "SELECT snapshot_id, parent_id, operation, manifest_list, summary "
            f"FROM {self._metadata_table(table_name, 'snapshots')}"
        )
        try:
            rows = self._execute(sql)
        except Exception:
            logger.debug(
                "Could not read snapshots for %s (table may be empty)",
                table_name,
                exc_info=True,
            )
            return []
        return [
            {
                "snapshot_id": r[0],
                "parent_id": r[1],
                "operation": r[2],
                "manifest_list": r[3],
                "summary": r[4],
            }
            for r in rows
        ]

    def get_table_partitions(self, table_name: str) -> list[dict[str, Any]]:
        """List partition values for an Iceberg table.

        Returns partition metadata from Trino's $partitions metadata table,
        one ``{"row": <raw row>}`` dict per partition. Returns an empty list
        if the metadata cannot be read.
        """
        sql = f"SELECT * FROM {self._metadata_table(table_name, 'partitions')}"
        try:
            rows = self._execute(sql)
        except Exception:
            logger.debug(
                "Could not read partitions for %s (table may be empty)",
                table_name,
                exc_info=True,
            )
            return []
        return [{"row": r} for r in rows]

    def list_tables(self) -> list[str]:
        """List all tables in the Iceberg schema, sorted by name."""
        sql = (
            f"SELECT table_name FROM {self.catalog}.information_schema.tables "
            f"WHERE table_schema = '{self.schema}' ORDER BY table_name"
        )
        rows = self._execute(sql)
        return [r[0] for r in rows]

    def drop_table(self, table_name: str) -> None:
        """Drop an Iceberg table (for testing/reset purposes)."""
        qualified = self._qualified(table_name)
        logger.warning("Dropping Iceberg table: %s", qualified)
        self._execute_no_fetch(f"DROP TABLE IF EXISTS {qualified}")

    def sync_table_schema(self, table_name: str) -> list[str]:
        """Compare the expected PyArrow schema with the actual Iceberg table schema.

        If columns are missing from the Iceberg table, adds them via ALTER TABLE.
        This supports forward-only schema evolution — columns are never dropped.

        Args:
            table_name: Name of a registered analytical table.

        Returns:
            The names of the columns that were added, in schema order.

        Raises:
            ValueError: If ``table_name`` is not a registered table.
        """
        table_def = get_table_def(table_name)
        # Alter the same table that get_table_schema() inspected, i.e. the one
        # in this manager's configured catalog/schema.
        qualified = self._qualified(table_name)

        # Compare case-insensitively: Trino reports column names in lower
        # case, so a mixed-case Arrow field must not look "missing".
        existing_names = {
            col["column_name"].lower() for col in self.get_table_schema(table_name)
        }

        added: list[str] = []
        for field in table_def.schema:
            col_name = field.name
            if col_name.lower() in existing_names:
                continue
            trino_type = _arrow_type_to_trino(field.type)
            alter_sql = f"ALTER TABLE {qualified} ADD COLUMN {col_name} {trino_type}"
            logger.info("Adding column %s to %s", col_name, qualified)
            self._execute_no_fetch(alter_sql)
            existing_names.add(col_name.lower())
            added.append(col_name)
        return added

    def sync_all_schemas(self) -> dict[str, list[str]]:
        """Sync schemas for all registered tables. Returns table_name -> added columns.

        Tables that do not exist yet, and tables whose sync raised, map to an
        empty list.
        """
        results: dict[str, list[str]] = {}
        for table_def in get_all_table_defs():
            name = table_def.table_name
            try:
                if self.table_exists(name):
                    results[name] = self.sync_table_schema(name)
                else:
                    logger.info("Table %s doesn't exist yet, skipping sync", name)
                    results[name] = []
            except Exception:
                logger.exception("Failed to sync schema for %s", name)
                results[name] = []
        return results
def create_iceberg_manager_from_config(
    host: str = "localhost",
    port: int = 8080,
    user: str = "stonks",
) -> IcebergManager:
    """Build an IcebergManager for the given Trino endpoint.

    Only the connection parameters are set here; the manager's catalog and
    schema keep their defaults.
    """
    connection_params = {"host": host, "port": port, "user": user}
    return IcebergManager(**connection_params)