Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: unnest failed in some cases (table references failed when there were no other references to the referenced tables in a query) #290

Merged
merged 11 commits into from Aug 25, 2021
8 changes: 6 additions & 2 deletions setup.py
Expand Up @@ -45,7 +45,11 @@ def readme():
return f.read()


extras = dict(geography=["GeoAlchemy2", "shapely"], alembic=["alembic"], tests=["pytz"])
extras = dict(
geography=["GeoAlchemy2", "shapely"],
alembic=["alembic"],
tests=["packaging", "pytz"],
)
extras["all"] = set(itertools.chain.from_iterable(extras.values()))

setup(
Expand Down Expand Up @@ -85,7 +89,7 @@ def readme():
],
extras_require=extras,
python_requires=">=3.6, <3.10",
tests_require=["pytz"],
tests_require=["packaging", "pytz"],
entry_points={
"sqlalchemy.dialects": ["bigquery = sqlalchemy_bigquery:BigQueryDialect"]
},
Expand Down
42 changes: 22 additions & 20 deletions sqlalchemy_bigquery/__init__.py
Expand Up @@ -24,40 +24,42 @@

from .base import BigQueryDialect, dialect # noqa
from .base import (
STRING,
ARRAY,
BIGNUMERIC,
BOOL,
BOOLEAN,
BYTES,
DATE,
DATETIME,
FLOAT,
FLOAT64,
INT64,
INTEGER,
FLOAT64,
FLOAT,
TIMESTAMP,
DATETIME,
DATE,
BYTES,
TIME,
RECORD,
NUMERIC,
BIGNUMERIC,
RECORD,
STRING,
TIME,
TIMESTAMP,
)

__all__ = [
"ARRAY",
"BIGNUMERIC",
"BigQueryDialect",
"STRING",
"BOOL",
"BOOLEAN",
"BYTES",
"DATE",
"DATETIME",
"FLOAT",
"FLOAT64",
"INT64",
"INTEGER",
"FLOAT64",
"FLOAT",
"TIMESTAMP",
"DATETIME",
"DATE",
"BYTES",
"TIME",
"RECORD",
"NUMERIC",
"BIGNUMERIC",
"RECORD",
"STRING",
"TIME",
"TIMESTAMP",
]

try:
Expand Down
101 changes: 81 additions & 20 deletions sqlalchemy_bigquery/base.py
Expand Up @@ -62,6 +62,8 @@

FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+")

TABLE_VALUED_ALIAS_ALIASES = "bigquery_table_valued_alias_aliases"


def assert_(cond, message="Assertion failed"): # pragma: NO COVER
if not cond:
Expand Down Expand Up @@ -114,39 +116,41 @@ def format_label(self, label, name=None):


_type_map = {
"STRING": types.String,
"BOOL": types.Boolean,
"ARRAY": types.ARRAY,
"BIGNUMERIC": types.Numeric,
"BOOLEAN": types.Boolean,
"INT64": types.Integer,
"INTEGER": types.Integer,
"BOOL": types.Boolean,
"BYTES": types.BINARY,
"DATETIME": types.DATETIME,
"DATE": types.DATE,
"FLOAT64": types.Float,
"FLOAT": types.Float,
"INT64": types.Integer,
"INTEGER": types.Integer,
"NUMERIC": types.Numeric,
"RECORD": types.JSON,
"STRING": types.String,
"TIMESTAMP": types.TIMESTAMP,
"DATETIME": types.DATETIME,
"DATE": types.DATE,
"BYTES": types.BINARY,
"TIME": types.TIME,
"RECORD": types.JSON,
"NUMERIC": types.Numeric,
"BIGNUMERIC": types.Numeric,
}

# By convention, dialect-provided types are spelled with all upper case.
STRING = _type_map["STRING"]
BOOL = _type_map["BOOL"]
ARRAY = _type_map["ARRAY"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the new array type count as a feature? Perhaps it'd look better in the changelog as a separate PR?

If it is a feature and it doesn't make sense to split out, we could also try the multi-change commit message feature when we squash and merge this: https://github.com/googleapis/release-please#what-if-my-pr-contains-multiple-fixes-or-features

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a new feature. It was just missing from the types list.

SQLAlchemy has basic array support that works with BigQuery (thanks to some previous effort on our part).

BIGNUMERIC = _type_map["NUMERIC"]
BOOLEAN = _type_map["BOOLEAN"]
INT64 = _type_map["INT64"]
INTEGER = _type_map["INTEGER"]
BOOL = _type_map["BOOL"]
BYTES = _type_map["BYTES"]
DATETIME = _type_map["DATETIME"]
DATE = _type_map["DATE"]
FLOAT64 = _type_map["FLOAT64"]
FLOAT = _type_map["FLOAT"]
INT64 = _type_map["INT64"]
INTEGER = _type_map["INTEGER"]
NUMERIC = _type_map["NUMERIC"]
RECORD = _type_map["RECORD"]
STRING = _type_map["STRING"]
TIMESTAMP = _type_map["TIMESTAMP"]
DATETIME = _type_map["DATETIME"]
DATE = _type_map["DATE"]
BYTES = _type_map["BYTES"]
TIME = _type_map["TIME"]
RECORD = _type_map["RECORD"]
NUMERIC = _type_map["NUMERIC"]
BIGNUMERIC = _type_map["NUMERIC"]

try:
_type_map["GEOGRAPHY"] = GEOGRAPHY
Expand Down Expand Up @@ -246,6 +250,56 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
insert_stmt, asfrom=False, **kw
)

def visit_table_valued_alias(self, element, **kw):
    """Compile a table-valued-function alias (e.g. ``unnest(...) AS x``),
    prepending FROM-list entries for any tables referenced in the
    function's arguments, as BigQuery requires.

    When using table-valued functions, like UNNEST, BigQuery requires a
    FROM for any table referenced in the function, including expressions
    in function arguments.

    For example, given SQLAlchemy code:

        print(
            select([func.unnest(foo.c.objects).alias('foo_objects').column])
            .compile(engine))

    Left to its own devices, SQLAlchemy would output:

        SELECT `foo_objects`
        FROM unnest(`foo`.`objects`) AS `foo_objects`

    But BigQuery doesn't understand the `foo` reference unless
    we add a reference to `foo` in the FROM:

        SELECT foo_objects
        FROM `foo`, UNNEST(`foo`.`objects`) as foo_objects

    This is tricky because:
    1. We have to find the table references.
    2. We can't know practically if there's already a FROM for a table.

    We leverage visit_column to find a table reference. Whenever we find
    one, we create an alias for it, so as not to conflict with an existing
    reference if one is present.

    This requires communicating between this function and visit_column.
    We do this by sticking a dictionary in the keyword arguments.
    This dictionary:
    a. Tells visit_column that it's in a table-valued alias expression, and
    b. Gives it a place to record the aliases it creates.

    This function creates aliases in the FROM list for any aliases recorded
    by visit_column.
    """
    kw[TABLE_VALUED_ALIAS_ALIASES] = {}
    ret = super().visit_table_valued_alias(element, **kw)
    aliases = kw.pop(TABLE_VALUED_ALIAS_ALIASES)
    if aliases:
        aliases = ", ".join(
            f"{self.preparer.quote(tablename)} {self.preparer.quote(alias)}"
            for tablename, alias in aliases.items()
        )
        ret = f"{aliases}, {ret}"
    return ret

def visit_column(
self,
column,
Expand Down Expand Up @@ -281,6 +335,13 @@ def visit_column(
tablename = table.name
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
elif TABLE_VALUED_ALIAS_ALIASES in kwargs:
aliases = kwargs[TABLE_VALUED_ALIAS_ALIASES]
if tablename not in aliases:
aliases[tablename] = self.anon_map[
f"{TABLE_VALUED_ALIAS_ALIASES} {tablename}"
]
tablename = aliases[tablename]

return self.preparer.quote(tablename) + "." + name

Expand Down
Expand Up @@ -19,6 +19,7 @@

import datetime
import mock
import packaging.version
import pytest
import pytz
import sqlalchemy
Expand All @@ -41,7 +42,7 @@
)


if sqlalchemy.__version__ < "1.4":
if packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"):
from sqlalchemy.testing.suite import LimitOffsetTest as _LimitOffsetTest

class LimitOffsetTest(_LimitOffsetTest):
Expand Down
30 changes: 29 additions & 1 deletion tests/system/test_sqlalchemy_bigquery.py
Expand Up @@ -28,13 +28,13 @@
from sqlalchemy.sql import expression, select, literal_column
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import sessionmaker
import packaging.version
from pytz import timezone
import pytest
import sqlalchemy
import datetime
import decimal


ONE_ROW_CONTENTS_EXPANDED = [
588,
datetime.datetime(2013, 10, 10, 11, 27, 16, tzinfo=timezone("UTC")),
Expand Down Expand Up @@ -725,3 +725,31 @@ class MyTable(Base):
)

assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected


@pytest.mark.skipif(
    packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"),
    reason="unnest (and other table-valued-function) support required version 1.4",
)
def test_unnest(engine, bigquery_dataset):
    # Verify that unnesting an ARRAY column emits an extra FROM entry
    # (an alias for the referenced table) and returns the flattened rows.
    from sqlalchemy import select, func, String
    from sqlalchemy_bigquery import ARRAY

    connection = engine.connect()
    meta = MetaData()
    objects_table = Table(
        f"{bigquery_dataset}.test_unnest", meta, Column("objects", ARRAY(String)),
    )
    meta.create_all(engine)

    rows = [dict(objects=["a", "b", "c"]), dict(objects=["x", "y"])]
    connection.execute(objects_table.insert(), rows)

    stmt = select([func.unnest(objects_table.c.objects).alias("foo_objects").column])
    sql_text = str(stmt.compile(engine))

    # Collapse all whitespace runs so the comparison is layout-independent.
    normalized = " ".join(sql_text.strip().split())
    expected = (
        f"SELECT `foo_objects`"
        f" FROM"
        f" `{bigquery_dataset}.test_unnest` `{bigquery_dataset}.test_unnest_1`,"
        f" unnest(`{bigquery_dataset}.test_unnest_1`.`objects`) AS `foo_objects`"
    )
    assert normalized == expected

    fetched = sorted(row[0] for row in connection.execute(stmt))
    assert fetched == ["a", "b", "c", "x", "y"]
12 changes: 8 additions & 4 deletions tests/unit/conftest.py
Expand Up @@ -21,20 +21,24 @@
import mock
import sqlite3

import packaging.version
import pytest
import sqlalchemy

import fauxdbi

sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split(".")))
sqlalchemy_version = packaging.version.parse(sqlalchemy.__version__)
sqlalchemy_1_3_or_higher = pytest.mark.skipif(
sqlalchemy_version_info < (1, 3), reason="requires sqlalchemy 1.3 or higher"
sqlalchemy_version < packaging.version.parse("1.3"),
reason="requires sqlalchemy 1.3 or higher",
)
sqlalchemy_1_4_or_higher = pytest.mark.skipif(
sqlalchemy_version_info < (1, 4), reason="requires sqlalchemy 1.4 or higher"
sqlalchemy_version < packaging.version.parse("1.4"),
reason="requires sqlalchemy 1.4 or higher",
)
sqlalchemy_before_1_4 = pytest.mark.skipif(
sqlalchemy_version_info >= (1, 4), reason="requires sqlalchemy 1.3 or lower"
sqlalchemy_version >= packaging.version.parse("1.4"),
reason="requires sqlalchemy 1.3 or lower",
)


Expand Down