78 lines
2.3 KiB
Python
78 lines
2.3 KiB
Python
"""Database migration runner using asyncpg.

Applies all SQL migration files from infra/migrations/ in sorted filename order.
Each file is split into statements on semicolons and executed statement-by-statement.

Applied migrations are not tracked, so every run re-executes every file: each
migration must be idempotent (use IF NOT EXISTS / CREATE OR REPLACE patterns).

Usage:
    python -m services.shared.migrate
"""
|
|
import asyncio
|
|
import glob
|
|
import logging
|
|
import os
|
|
import sys
|
|
|
|
import asyncpg
|
|
|
|
# Fixed logger name rather than __name__: under `python -m` the module's
# __name__ is "__main__", which would make the %(name)s field in log lines useless.
logger: logging.Logger = logging.getLogger(name="migrate")
|
|
|
|
|
|
def _is_ident_char(ch: str) -> bool:
    """Return True if *ch* may appear inside an unquoted PostgreSQL identifier."""
    return ch.isalnum() or ch in "_$"


def _skip_quoted(sql: str, pos: int, quote: str, backslash_escapes: bool) -> int:
    """Return the index just past the quoted literal/identifier opening at *pos*.

    A doubled quote character stands for a literal quote. When
    *backslash_escapes* is true (E'...' strings), a backslash escapes the
    following character. An unterminated literal runs to the end of *sql*.
    """
    i = pos + 1
    n = len(sql)
    while i < n:
        ch = sql[i]
        if backslash_escapes and ch == "\\":
            i += 2
        elif ch == quote:
            if i + 1 < n and sql[i + 1] == quote:
                i += 2  # doubled quote -> literal quote character, keep scanning
            else:
                return i + 1
        else:
            i += 1
    return n


def _skip_block_comment(sql: str, pos: int) -> int:
    """Return the index just past the /* ... */ comment opening at *pos*.

    PostgreSQL block comments nest, so nesting depth is tracked. An
    unterminated comment runs to the end of *sql*.
    """
    depth = 0
    i = pos
    n = len(sql)
    while i < n:
        if sql.startswith("/*", i):
            depth += 1
            i += 2
        elif sql.startswith("*/", i):
            depth -= 1
            i += 2
            if depth == 0:
                return i
        else:
            i += 1
    return n


def _dollar_tag_at(sql: str, pos: int) -> str:
    """Return the dollar-quote delimiter ($$ or $tag$) starting at *pos*, or "".

    A tag must not start with a digit, so positional parameters such as $1
    are not mistaken for a dollar quote.
    """
    n = len(sql)
    end = pos + 1
    if end < n and (sql[end].isalpha() or sql[end] == "_"):
        end += 1
        while end < n and (sql[end].isalnum() or sql[end] == "_"):
            end += 1
    if end < n and sql[end] == "$":
        return sql[pos:end + 1]
    return ""


def _split_sql_statements(sql: str) -> "list[str]":
    """Split a SQL script into statements on top-level semicolons.

    Semicolons inside single-quoted strings (including E'...' escape
    strings), double-quoted identifiers, dollar-quoted bodies ($$...$$ /
    $tag$...$tag$), ``--`` line comments and (nested) ``/* */`` block
    comments do not end a statement. Returned statements are stripped and
    carry no trailing semicolon. Chunks containing only whitespace and/or
    comments are dropped.
    """
    statements = []
    n = len(sql)
    start = 0          # index where the current statement's text begins
    has_code = False   # current chunk has something besides whitespace/comments
    i = 0
    while i < n:
        ch = sql[i]
        if sql.startswith("--", i):
            newline = sql.find("\n", i)
            i = n if newline == -1 else newline + 1
        elif sql.startswith("/*", i):
            i = _skip_block_comment(sql, i)
        elif ch == ";":
            if has_code:
                statements.append(sql[start:i].strip())
            start, has_code = i + 1, False
            i += 1
        else:
            if not ch.isspace():
                has_code = True
            if ch == "'":
                # Only E'...' strings treat backslash as an escape; plain
                # strings follow standard_conforming_strings (PostgreSQL default).
                is_escape_string = (
                    i > 0
                    and sql[i - 1] in "eE"
                    and (i == 1 or not _is_ident_char(sql[i - 2]))
                )
                i = _skip_quoted(sql, i, "'", backslash_escapes=is_escape_string)
            elif ch == '"':
                i = _skip_quoted(sql, i, '"', backslash_escapes=False)
            elif ch == "$" and (i == 0 or not _is_ident_char(sql[i - 1])):
                # A '$' right after an identifier character is part of that
                # identifier, not the start of a dollar quote.
                tag = _dollar_tag_at(sql, i)
                if tag:
                    close = sql.find(tag, i + len(tag))
                    i = n if close == -1 else close + len(tag)
                else:
                    i += 1
            else:
                i += 1
    if has_code:
        statements.append(sql[start:].strip())
    return statements


async def _apply_migration_file(conn: "asyncpg.Connection", path: str) -> bool:
    """Execute every statement of one migration file; return True on success.

    No explicit transaction is opened here, so unless the file itself issues
    BEGIN/COMMIT, each statement commits on its own. If a statement fails, the
    earlier statements of the file stay applied and the rest are skipped.
    Server-reported SQL errors (asyncpg.PostgresError) are logged and turned
    into a False return. Any other error (e.g. a lost connection) propagates.
    """
    name = os.path.basename(path)
    with open(path, encoding="utf-8") as f:
        sql = f.read()
    # asyncpg can run a multi-statement script in a single execute() call, but
    # that runs as one implicit transaction. Executing statement by statement
    # keeps per-statement autocommit, which some commands require
    # (e.g. CREATE INDEX CONCURRENTLY).
    statements = _split_sql_statements(sql)
    for index, stmt in enumerate(statements, start=1):
        try:
            await conn.execute(stmt)
        except asyncpg.PostgresError as exc:
            logger.warning(
                " ⚠ %s: statement %d/%d failed, skipping rest of file: %s",
                name, index, len(statements), exc,
            )
            return False
    logger.info(" ✓ %s (%d statements)", name, len(statements))
    return True


async def run_migrations() -> None:
    """Apply every ``*.sql`` file in ``infra/migrations/`` in sorted filename order.

    Connection settings are read from POSTGRES_HOST, POSTGRES_PORT,
    POSTGRES_USER, POSTGRES_PASSWORD and POSTGRES_DB (defaults: "localhost",
    5432, "stonks", "", "stonks"). Files are ordered by plain string sort, so
    numeric prefixes should be zero-padded. Applied migrations are not
    recorded, so every run re-executes every file.

    This is best effort: a file whose SQL fails is logged and the remaining
    files still run. A closing summary names the files that failed.

    Raises:
        SystemExit: if the migrations directory does not exist.
    """
    host = os.getenv("POSTGRES_HOST", "localhost")
    port = int(os.getenv("POSTGRES_PORT", "5432"))
    user = os.getenv("POSTGRES_USER", "stonks")
    password = os.getenv("POSTGRES_PASSWORD", "")
    database = os.getenv("POSTGRES_DB", "stonks")

    # Resolved relative to this source file: two directories up, then infra/migrations.
    migrations_dir = os.path.normpath(
        os.path.join(os.path.dirname(__file__), "..", "..", "infra", "migrations")
    )
    if not os.path.isdir(migrations_dir):
        logger.error("Migrations directory not found: %s", migrations_dir)
        sys.exit(1)

    files = sorted(glob.glob(os.path.join(migrations_dir, "*.sql")))
    if not files:
        logger.warning("No migration files found in %s", migrations_dir)
        return

    logger.info("Connecting to %s@%s:%d/%s", user, host, port, database)
    conn = await asyncpg.connect(
        host=host, port=port, user=user, password=password, database=database
    )
    failed = []
    try:
        for path in files:
            if not await _apply_migration_file(conn, path):
                failed.append(os.path.basename(path))
    finally:
        await conn.close()

    if failed:
        logger.warning(
            "Migrations finished with %d of %d files failing: %s",
            len(failed), len(files), ", ".join(failed),
        )
    else:
        logger.info("Migrations complete (%d files)", len(files))
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point.

    Configures root logging at INFO level with ``HH:MM:SS name message``
    lines, then drives :func:`run_migrations` to completion via asyncio.run().
    """
    log_settings = {
        "format": "%(asctime)s %(name)s %(message)s",
        "datefmt": "%H:%M:%S",
        "level": logging.INFO,
    }
    logging.basicConfig(**log_settings)
    asyncio.run(run_migrations())
|
|
|
|
|
|
# Run the migrations only when executed as a script/module
# (e.g. `python -m services.shared.migrate`), not when imported.
if __name__ == "__main__":
    main()
|