From 9b5b0025ec0b65177c0df02013ac387b3d3de472 Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Wed, 25 Aug 2021 11:29:04 -0600 Subject: [PATCH] fix: unnest failed in some cases (with table references failed when there were no other references to refrenced tables in a query) (#290) --- setup.py | 8 +- sqlalchemy_bigquery/__init__.py | 42 ++++---- sqlalchemy_bigquery/base.py | 101 ++++++++++++++---- .../test_dialect_compliance.py | 3 +- tests/system/test_sqlalchemy_bigquery.py | 30 +++++- tests/unit/conftest.py | 12 ++- tests/unit/test_select.py | 62 ++++++++++- 7 files changed, 208 insertions(+), 50 deletions(-) diff --git a/setup.py b/setup.py index fd8a8acd..f70c3a0d 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,11 @@ def readme(): return f.read() -extras = dict(geography=["GeoAlchemy2", "shapely"], alembic=["alembic"], tests=["pytz"]) +extras = dict( + geography=["GeoAlchemy2", "shapely"], + alembic=["alembic"], + tests=["packaging", "pytz"], +) extras["all"] = set(itertools.chain.from_iterable(extras.values())) setup( @@ -85,7 +89,7 @@ def readme(): ], extras_require=extras, python_requires=">=3.6, <3.10", - tests_require=["pytz"], + tests_require=["packaging", "pytz"], entry_points={ "sqlalchemy.dialects": ["bigquery = sqlalchemy_bigquery:BigQueryDialect"] }, diff --git a/sqlalchemy_bigquery/__init__.py b/sqlalchemy_bigquery/__init__.py index a9321ed5..f0defda1 100644 --- a/sqlalchemy_bigquery/__init__.py +++ b/sqlalchemy_bigquery/__init__.py @@ -24,40 +24,42 @@ from .base import BigQueryDialect, dialect # noqa from .base import ( - STRING, + ARRAY, + BIGNUMERIC, BOOL, BOOLEAN, + BYTES, + DATE, + DATETIME, + FLOAT, + FLOAT64, INT64, INTEGER, - FLOAT64, - FLOAT, - TIMESTAMP, - DATETIME, - DATE, - BYTES, - TIME, - RECORD, NUMERIC, - BIGNUMERIC, + RECORD, + STRING, + TIME, + TIMESTAMP, ) __all__ = [ + "ARRAY", + "BIGNUMERIC", "BigQueryDialect", - "STRING", "BOOL", "BOOLEAN", + "BYTES", + "DATE", + "DATETIME", + "FLOAT", + "FLOAT64", "INT64", "INTEGER", - "FLOAT64", - "FLOAT", - "TIMESTAMP", - "DATETIME", - "DATE", - "BYTES", - "TIME", - "RECORD", "NUMERIC", - "BIGNUMERIC", + "RECORD", + "STRING", + "TIME", + "TIMESTAMP", ] try: diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index d1b0a75b..98edfb9e 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -62,6 +62,8 @@ FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+") +TABLE_VALUED_ALIAS_ALIASES = "bigquery_table_valued_alias_aliases" + def assert_(cond, message="Assertion failed"): # pragma: NO COVER if not cond: @@ -114,39 +116,41 @@ def format_label(self, label, name=None): _type_map = { - "STRING": types.String, - "BOOL": types.Boolean, + "ARRAY": types.ARRAY, + "BIGNUMERIC": types.Numeric, "BOOLEAN": types.Boolean, - "INT64": types.Integer, - "INTEGER": types.Integer, + "BOOL": types.Boolean, + "BYTES": types.BINARY, + "DATETIME": types.DATETIME, + "DATE": types.DATE, "FLOAT64": types.Float, "FLOAT": types.Float, + "INT64": types.Integer, + "INTEGER": types.Integer, + "NUMERIC": types.Numeric, + "RECORD": types.JSON, + "STRING": types.String, "TIMESTAMP": types.TIMESTAMP, - "DATETIME": types.DATETIME, - "DATE": types.DATE, - "BYTES": types.BINARY, "TIME": types.TIME, - "RECORD": types.JSON, - "NUMERIC": types.Numeric, - "BIGNUMERIC": types.Numeric, } # By convention, dialect-provided types are spelled with all upper case. -STRING = _type_map["STRING"] -BOOL = _type_map["BOOL"] +ARRAY = _type_map["ARRAY"] +BIGNUMERIC = _type_map["NUMERIC"] BOOLEAN = _type_map["BOOLEAN"] -INT64 = _type_map["INT64"] -INTEGER = _type_map["INTEGER"] +BOOL = _type_map["BOOL"] +BYTES = _type_map["BYTES"] +DATETIME = _type_map["DATETIME"] +DATE = _type_map["DATE"] FLOAT64 = _type_map["FLOAT64"] FLOAT = _type_map["FLOAT"] +INT64 = _type_map["INT64"] +INTEGER = _type_map["INTEGER"] +NUMERIC = _type_map["NUMERIC"] +RECORD = _type_map["RECORD"] +STRING = _type_map["STRING"] TIMESTAMP = _type_map["TIMESTAMP"] -DATETIME = _type_map["DATETIME"] -DATE = _type_map["DATE"] -BYTES = _type_map["BYTES"] TIME = _type_map["TIME"] -RECORD = _type_map["RECORD"] -NUMERIC = _type_map["NUMERIC"] -BIGNUMERIC = _type_map["NUMERIC"] try: _type_map["GEOGRAPHY"] = GEOGRAPHY @@ -246,6 +250,56 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw): insert_stmt, asfrom=False, **kw ) + def visit_table_valued_alias(self, element, **kw): + # When using table-valued functions, like UNNEST, BigQuery requires a + # FROM for any table referenced in the function, including expressions + # in function arguments. + # + # For example, given SQLAlchemy code: + # + # print( + # select([func.unnest(foo.c.objects).alias('foo_objects').column]) + # .compile(engine)) + # + # Left to it's own devices, SQLAlchemy would outout: + # + # SELECT `foo_objects` + # FROM unnest(`foo`.`objects`) AS `foo_objects` + # + # But BigQuery diesn't understand the `foo` reference unless + # we add as reference to `foo` in the FROM: + # + # SELECT foo_objects + # FROM `foo`, UNNEST(`foo`.`objects`) as foo_objects + # + # This is tricky because: + # 1. We have to find the table references. + # 2. We can't know practically if there's already a FROM for a table. + # + # We leverage visit_column to find a table reference. Whenever we find + # one, we create an alias for it, so as not to conflict with an existing + # reference if one is present. + # + # This requires communicating between this function and visit_column. + # We do this by sticking a dictionary in the keyword arguments. + # This dictionary: + # a. Tells visit_column that it's an a table-valued alias expresssion, and + # b. Gives it a place to record the aliases it creates. + # + # This function creates aliases in the FROM list for any aliases recorded + # by visit_column. + + kw[TABLE_VALUED_ALIAS_ALIASES] = {} + ret = super().visit_table_valued_alias(element, **kw) + aliases = kw.pop(TABLE_VALUED_ALIAS_ALIASES) + if aliases: + aliases = ", ".join( + f"{self.preparer.quote(tablename)} {self.preparer.quote(alias)}" + for tablename, alias in aliases.items() + ) + ret = f"{aliases}, {ret}" + return ret + def visit_column( self, column, @@ -281,6 +335,13 @@ def visit_column( tablename = table.name if isinstance(tablename, elements._truncated_label): tablename = self._truncated_identifier("alias", tablename) + elif TABLE_VALUED_ALIAS_ALIASES in kwargs: + aliases = kwargs[TABLE_VALUED_ALIAS_ALIASES] + if tablename not in aliases: + aliases[tablename] = self.anon_map[ + f"{TABLE_VALUED_ALIAS_ALIASES} {tablename}" + ] + tablename = aliases[tablename] return self.preparer.quote(tablename) + "." + name diff --git a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py index c126c4f7..156e6167 100644 --- a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py +++ b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py @@ -19,6 +19,7 @@ import datetime import mock +import packaging.version import pytest import pytz import sqlalchemy @@ -41,7 +42,7 @@ ) -if sqlalchemy.__version__ < "1.4": +if packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"): from sqlalchemy.testing.suite import LimitOffsetTest as _LimitOffsetTest class LimitOffsetTest(_LimitOffsetTest): diff --git a/tests/system/test_sqlalchemy_bigquery.py b/tests/system/test_sqlalchemy_bigquery.py index 1390d11c..63dc220b 100644 --- a/tests/system/test_sqlalchemy_bigquery.py +++ b/tests/system/test_sqlalchemy_bigquery.py @@ -28,13 +28,13 @@ from sqlalchemy.sql import expression, select, literal_column from sqlalchemy.exc import NoSuchTableError from sqlalchemy.orm import sessionmaker +import packaging.version from pytz import timezone import pytest import sqlalchemy import datetime import decimal - ONE_ROW_CONTENTS_EXPANDED = [ 588, datetime.datetime(2013, 10, 10, 11, 27, 16, tzinfo=timezone("UTC")), @@ -725,3 +725,31 @@ class MyTable(Base): ) assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected + + +@pytest.mark.skipif( + packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"), + reason="unnest (and other table-valued-function) support required version 1.4", +) +def test_unnest(engine, bigquery_dataset): + from sqlalchemy import select, func, String + from sqlalchemy_bigquery import ARRAY + + conn = engine.connect() + metadata = MetaData() + table = Table( + f"{bigquery_dataset}.test_unnest", metadata, Column("objects", ARRAY(String)), + ) + metadata.create_all(engine) + conn.execute( + table.insert(), [dict(objects=["a", "b", "c"]), dict(objects=["x", "y"])] + ) + query = select([func.unnest(table.c.objects).alias("foo_objects").column]) + compiled = str(query.compile(engine)) + assert " ".join(compiled.strip().split()) == ( + f"SELECT `foo_objects`" + f" FROM" + f" `{bigquery_dataset}.test_unnest` `{bigquery_dataset}.test_unnest_1`," + f" unnest(`{bigquery_dataset}.test_unnest_1`.`objects`) AS `foo_objects`" + ) + assert sorted(r[0] for r in conn.execute(query)) == ["a", "b", "c", "x", "y"] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index e5de882d..886e9aee 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -21,20 +21,24 @@ import mock import sqlite3 +import packaging.version import pytest import sqlalchemy import fauxdbi -sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split("."))) +sqlalchemy_version = packaging.version.parse(sqlalchemy.__version__) sqlalchemy_1_3_or_higher = pytest.mark.skipif( - sqlalchemy_version_info < (1, 3), reason="requires sqlalchemy 1.3 or higher" + sqlalchemy_version < packaging.version.parse("1.3"), + reason="requires sqlalchemy 1.3 or higher", ) sqlalchemy_1_4_or_higher = pytest.mark.skipif( - sqlalchemy_version_info < (1, 4), reason="requires sqlalchemy 1.4 or higher" + sqlalchemy_version < packaging.version.parse("1.4"), + reason="requires sqlalchemy 1.4 or higher", ) sqlalchemy_before_1_4 = pytest.mark.skipif( - sqlalchemy_version_info >= (1, 4), reason="requires sqlalchemy 1.3 or lower" + sqlalchemy_version >= packaging.version.parse("1.4"), + reason="requires sqlalchemy 1.3 or lower", ) diff --git a/tests/unit/test_select.py b/tests/unit/test_select.py index ad61d5a8..10669864 100644 --- a/tests/unit/test_select.py +++ b/tests/unit/test_select.py @@ -20,6 +20,7 @@ import datetime from decimal import Decimal +import packaging.version import pytest import sqlalchemy @@ -292,7 +293,10 @@ def test_select_in_param_empty(faux_conn): assert not isin assert faux_conn.test_data["execute"][-1] == ( "SELECT %(param_1:INT64)s IN(NULL) AND (1 != 1) AS `anon_1`" - if sqlalchemy.__version__ >= "1.4" + if ( + packaging.version.parse(sqlalchemy.__version__) + >= packaging.version.parse("1.4") + ) else "SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`", {"param_1": 1}, ) @@ -352,7 +356,10 @@ def test_select_notin_param_empty(faux_conn): assert isnotin assert faux_conn.test_data["execute"][-1] == ( "SELECT (%(param_1:INT64)s NOT IN(NULL) OR (1 = 1)) AS `anon_1`" - if sqlalchemy.__version__ >= "1.4" + if ( + packaging.version.parse(sqlalchemy.__version__) + >= packaging.version.parse("1.4") + ) else "SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`", {"param_1": 1}, ) @@ -374,3 +381,54 @@ def nstr(q): nstr(q.compile(faux_conn.engine, compile_kwargs={"literal_binds": True})) == "SELECT `test`.`val` FROM `test` WHERE `test`.`val` IN (2)" ) + + +@sqlalchemy_1_4_or_higher +@pytest.mark.parametrize("alias", [True, False]) +def test_unnest(faux_conn, alias): + from sqlalchemy import String + from sqlalchemy_bigquery import ARRAY + + table = setup_table(faux_conn, "t", sqlalchemy.Column("objects", ARRAY(String))) + fcall = sqlalchemy.func.unnest(table.c.objects) + if alias: + query = fcall.alias("foo_objects").column + else: + query = fcall.column_valued("foo_objects") + compiled = str(sqlalchemy.select(query).compile(faux_conn.engine)) + assert " ".join(compiled.strip().split()) == ( + "SELECT `foo_objects` FROM `t` `t_1`, unnest(`t_1`.`objects`) AS `foo_objects`" + ) + + +@sqlalchemy_1_4_or_higher +@pytest.mark.parametrize("alias", [True, False]) +def test_table_valued_alias_w_multiple_references_to_the_same_table(faux_conn, alias): + from sqlalchemy import String + from sqlalchemy_bigquery import ARRAY + + table = setup_table(faux_conn, "t", sqlalchemy.Column("objects", ARRAY(String))) + fcall = sqlalchemy.func.foo(table.c.objects, table.c.objects) + if alias: + query = fcall.alias("foo_objects").column + else: + query = fcall.column_valued("foo_objects") + compiled = str(sqlalchemy.select(query).compile(faux_conn.engine)) + assert " ".join(compiled.strip().split()) == ( + "SELECT `foo_objects` " + "FROM `t` `t_1`, foo(`t_1`.`objects`, `t_1`.`objects`) AS `foo_objects`" + ) + + +@sqlalchemy_1_4_or_higher +@pytest.mark.parametrize("alias", [True, False]) +def test_unnest_w_no_table_references(faux_conn, alias): + fcall = sqlalchemy.func.unnest([1, 2, 3]) + if alias: + query = fcall.alias().column + else: + query = fcall.column_valued() + compiled = str(sqlalchemy.select(query).compile(faux_conn.engine)) + assert " ".join(compiled.strip().split()) == ( + "SELECT `anon_1` FROM unnest(%(unnest_1)s) AS `anon_1`" + )