diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 8a362104..1d0cda07 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -15,10 +15,17 @@ import pkg_resources import re -from sqlalchemy import types, ForeignKeyConstraint +from alembic.ddl.base import ( + ColumnNullable, + ColumnType, + alter_column, + alter_table, + format_type, +) +from sqlalchemy import ForeignKeyConstraint, types, util from sqlalchemy.engine.base import Engine from sqlalchemy.engine.default import DefaultDialect -from sqlalchemy import util +from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.compiler import ( selectable, DDLCompiler, @@ -27,6 +34,7 @@ SQLCompiler, RESERVED_WORDS, ) + from google.cloud import spanner_dbapi from google.cloud.sqlalchemy_spanner._opentelemetry_tracing import trace_call @@ -864,3 +872,28 @@ def do_execute_no_params(self, cursor, statement, context=None): } with trace_call("SpannerSqlAlchemy.ExecuteNoParams", trace_attributes): cursor.execute(statement) + + +# Alembic ALTER operation override +@compiles(ColumnNullable, "spanner") +def visit_column_nullable( + element: "ColumnNullable", compiler: "SpannerDDLCompiler", **kw +) -> str: + return "%s %s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + format_type(compiler, element.existing_type), + "" if element.nullable else "NOT NULL", + ) + + +# Alembic ALTER operation override +@compiles(ColumnType, "spanner") +def visit_column_type( + element: "ColumnType", compiler: "SpannerDDLCompiler", **kw +) -> str: + return "%s %s %s" % ( + alter_table(compiler, element.table_name, element.schema), + alter_column(compiler, element.column_name), + "%s" % format_type(compiler, element.type_), + ) diff --git a/migration_test_cleanup.py b/migration_test_cleanup.py index 485f2a42..01a9a528 100644 --- a/migration_test_cleanup.py +++ b/migration_test_cleanup.py @@ -27,11 +27,11 @@ config.read("setup.cfg") db_url = config.get("db", "default") -project = re.findall(r'projects(.*?)instances', db_url) -instance_id = re.findall(r'instances(.*?)databases', db_url) +project = re.findall(r"projects(.*?)instances", db_url) +instance_id = re.findall(r"instances(.*?)databases", db_url) -client = spanner.Client(project="".join(project).replace('/', '')) -instance = client.instance(instance_id="".join(instance_id).replace('/', '')) +client = spanner.Client(project="".join(project).replace("/", "")) +instance = client.instance(instance_id="".join(instance_id).replace("/", "")) database = instance.database("compliance-test") database.update_ddl(["DROP TABLE account", "DROP TABLE alembic_version"]).result(120) diff --git a/noxfile.py b/noxfile.py index d10f44cd..1e5eb9ed 100644 --- a/noxfile.py +++ b/noxfile.py @@ -60,6 +60,12 @@ class = StreamHandler sa.Column('id', sa.Integer, primary_key=True), sa.Column('name', sa.String(50), nullable=False), sa.Column('description', sa.Unicode(200)), + ) + op.alter_column( + 'account', + 'name', + existing_type=sa.String(50), + nullable=True, )""" diff --git a/setup.py b/setup.py index 3852af3d..09d67985 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,11 @@ name = "sqlalchemy-spanner" description = "SQLAlchemy dialect integrated into Cloud Spanner database" -dependencies = ["sqlalchemy>=1.1.13, <=1.3.23", "google-cloud-spanner>=3.3.0"] +dependencies = [ + "sqlalchemy>=1.1.13, <=1.3.23", + "google-cloud-spanner>=3.3.0", + "alembic", +] extras = { "tracing": [ "opentelemetry-api >= 1.1.0",