Skip to content

Commit

Permalink
feat: support computed columns (#139)
Browse files Browse the repository at this point in the history
Closes #137
  • Loading branch information
Ilya Gurov committed Nov 19, 2021
1 parent d80cb27 commit 046ca97
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 8 deletions.
7 changes: 7 additions & 0 deletions google/cloud/sqlalchemy_spanner/requirements.py
Expand Up @@ -18,6 +18,13 @@

class Requirements(SuiteRequirements):
@property
def computed_columns(self):
return exclusions.open()

@property
def computed_columns_stored(self):
return exclusions.open()

def sane_rowcount(self):
return exclusions.closed()

Expand Down
29 changes: 21 additions & 8 deletions google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Expand Up @@ -267,6 +267,13 @@ def limit_clause(self, select, **kw):
class SpannerDDLCompiler(DDLCompiler):
"""Spanner DDL statements compiler."""

def visit_computed_column(self, generated, **kw):
"""Computed column operator."""
text = "AS (%s) STORED" % self.sql_compiler.process(
generated.sqltext, include_table=False, literal_binds=True
)
return text

def visit_drop_table(self, drop_table):
"""
Cloud Spanner doesn't drop tables which have indexes
Expand Down Expand Up @@ -492,7 +499,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
list: The table every column dict-like description.
"""
sql = """
SELECT column_name, spanner_type, is_nullable
SELECT column_name, spanner_type, is_nullable, generation_expression
FROM information_schema.columns
WHERE
table_catalog = ''
Expand All @@ -512,14 +519,20 @@ def get_columns(self, connection, table_name, schema=None, **kw):
columns = snap.execute_sql(sql)

for col in columns:
cols_desc.append(
{
"name": col[0],
"type": self._designate_type(col[1]),
"nullable": col[2] == "YES",
"default": None,
col_desc = {
"name": col[0],
"type": self._designate_type(col[1]),
"nullable": col[2] == "YES",
"default": None,
}

if col[3] is not None:
col_desc["computed"] = {
"persisted": True,
"sqltext": col[3],
}
)
cols_desc.append(col_desc)

return cols_desc

def _designate_type(self, str_repr):
Expand Down
98 changes: 98 additions & 0 deletions test/test_suite.py
Expand Up @@ -29,11 +29,13 @@
from sqlalchemy import ForeignKey
from sqlalchemy import MetaData
from sqlalchemy.schema import DDL
from sqlalchemy.schema import Computed
from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import provide_metadata, emits_warning
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_true
from sqlalchemy.testing.provision import temp_table_keyword_args
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
Expand All @@ -54,6 +56,9 @@
from sqlalchemy.types import Numeric
from sqlalchemy.types import Text
from sqlalchemy.testing import requires
from sqlalchemy.testing.fixtures import (
ComputedReflectionFixtureTest as _ComputedReflectionFixtureTest,
)

from google.api_core.datetime_helpers import DatetimeWithNanoseconds

Expand Down Expand Up @@ -89,6 +94,7 @@
QuotedNameArgumentTest as _QuotedNameArgumentTest,
ComponentReflectionTest as _ComponentReflectionTest,
CompositeKeyReflectionTest as _CompositeKeyReflectionTest,
ComputedReflectionTest as _ComputedReflectionTest,
)
from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest
from sqlalchemy.testing.suite.test_types import ( # noqa: F401, F403
Expand Down Expand Up @@ -1608,3 +1614,95 @@ def test_staleness(self):

with self._engine.connect() as connection:
assert connection.connection.staleness is None


class ComputedReflectionFixtureTest(_ComputedReflectionFixtureTest):
@classmethod
def define_tables(cls, metadata):
"""SPANNER OVERRIDE:
Avoid using default values for computed columns.
"""
Table(
"computed_default_table",
metadata,
Column("id", Integer, primary_key=True),
Column("normal", Integer),
Column("computed_col", Integer, Computed("normal + 42")),
Column("with_default", Integer),
)

t = Table(
"computed_column_table",
metadata,
Column("id", Integer, primary_key=True),
Column("normal", Integer),
Column("computed_no_flag", Integer, Computed("normal + 42")),
)

if testing.requires.schemas.enabled:
t2 = Table(
"computed_column_table",
metadata,
Column("id", Integer, primary_key=True),
Column("normal", Integer),
Column("computed_no_flag", Integer, Computed("normal / 42")),
schema=config.test_schema,
)

if testing.requires.computed_columns_virtual.enabled:
t.append_column(
Column(
"computed_virtual",
Integer,
Computed("normal + 2", persisted=False),
)
)
if testing.requires.schemas.enabled:
t2.append_column(
Column(
"computed_virtual",
Integer,
Computed("normal / 2", persisted=False),
)
)
if testing.requires.computed_columns_stored.enabled:
t.append_column(
Column(
"computed_stored", Integer, Computed("normal - 42", persisted=True),
)
)
if testing.requires.schemas.enabled:
t2.append_column(
Column(
"computed_stored",
Integer,
Computed("normal * 42", persisted=True),
)
)


class ComputedReflectionTest(_ComputedReflectionTest, ComputedReflectionFixtureTest):
@pytest.mark.skip("Default values are not supported.")
def test_computed_col_default_not_set(self):
pass

def test_get_column_returns_computed(self):
"""
SPANNER OVERRIDE:
In Spanner all the generated columns are STORED,
meaning there are no persisted and not persisted
(in the terms of the SQLAlchemy) columns. The
method override omits the persistence reflection checks.
"""
insp = inspect(config.db)

cols = insp.get_columns("computed_default_table")
data = {c["name"]: c for c in cols}
for key in ("id", "normal", "with_default"):
is_true("computed" not in data[key])
compData = data["computed_col"]
is_true("computed" in compData)
is_true("sqltext" in compData["computed"])
eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42")

0 comments on commit 046ca97

Please sign in to comment.