From af3b97bfa4b3ed4b223384c9ed3fa0643204d8c9 Mon Sep 17 00:00:00 2001 From: Ilya Gurov Date: Thu, 16 Sep 2021 05:46:25 +0300 Subject: [PATCH] fix: array columns reflection (#119) Fixes #118 --- .../sqlalchemy_spanner/sqlalchemy_spanner.py | 42 +++++++++++++------ test/test_suite.py | 28 +++++++++++++ 2 files changed, 57 insertions(+), 13 deletions(-) diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 45993354..8a362104 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -42,6 +42,7 @@ "STRING": types.String, "TIME": types.TIME, "TIMESTAMP": types.TIMESTAMP, + "ARRAY": types.ARRAY, } _type_map_inv = { @@ -476,28 +477,43 @@ def get_columns(self, connection, table_name, schema=None, **kw): columns = snap.execute_sql(sql) for col in columns: - if col[1].startswith("STRING"): - end = col[1].index(")") - size = int_from_size(col[1][7:end]) - type_ = _type_map["STRING"](length=size) - # add test creating a table with bytes - elif col[1].startswith("BYTES"): - end = col[1].index(")") - size = int_from_size(col[1][6:end]) - type_ = _type_map["BYTES"](length=size) - else: - type_ = _type_map[col[1]] - cols_desc.append( { "name": col[0], - "type": type_, + "type": self._designate_type(col[1]), "nullable": col[2] == "YES", "default": None, } ) return cols_desc + def _designate_type(self, str_repr): + """ + Designate an SQLAlchemy data type from a Spanner + string representation. + + Args: + str_repr (str): String representation of a type. + + Returns: + An SQLAlchemy data type. + """ + if str_repr.startswith("STRING"): + end = str_repr.index(")") + size = int_from_size(str_repr[7:end]) + return _type_map["STRING"](length=size) + # add test creating a table with bytes + elif str_repr.startswith("BYTES"): + end = str_repr.index(")") + size = int_from_size(str_repr[6:end]) + return _type_map["BYTES"](length=size) + elif str_repr.startswith("ARRAY"): + inner_type_str = str_repr[6:-1] + inner_type = self._designate_type(inner_type_str) + return _type_map["ARRAY"](inner_type) + else: + return _type_map[str_repr] + @engine_to_connection def get_indexes(self, connection, table_name, schema=None, **kw): """Get the table indexes. diff --git a/test/test_suite.py b/test/test_suite.py index 74345b72..aa8d3f25 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -49,6 +49,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relation from sqlalchemy.orm import Session +from sqlalchemy.types import ARRAY from sqlalchemy.types import Integer from sqlalchemy.types import Numeric from sqlalchemy.types import Text @@ -901,6 +902,33 @@ def test_binary_reflection(self): assert isinstance(typ, LargeBinary) eq_(typ.length, 20) + @testing.requires.table_reflection + def test_array_reflection(self): + """Check array columns reflection.""" + orig_meta = self.metadata + + str_array = ARRAY(String(16)) + int_array = ARRAY(Integer) + Table( + "arrays_test", + orig_meta, + Column("id", Integer, primary_key=True), + Column("str_array", str_array), + Column("int_array", int_array), + ) + orig_meta.create_all() + + # autoload the table and check its columns reflection + tab = Table("arrays_test", orig_meta, autoload=True) + col_types = [col.type for col in tab.columns] + for type_ in ( + str_array, + int_array, + ): + assert type_ in col_types + + tab.drop() + class CompositeKeyReflectionTest(_CompositeKeyReflectionTest): @testing.requires.foreign_key_constraint_reflection