Skip to content

Commit

Permalink
fix: unnest failed in some cases (with table references failed when t…
Browse files Browse the repository at this point in the history
…here were no other references to referenced tables in a query) (#290)
  • Loading branch information
jimfulton committed Aug 25, 2021
1 parent 5e9f4c2 commit 9b5b002
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 50 deletions.
8 changes: 6 additions & 2 deletions setup.py
Expand Up @@ -45,7 +45,11 @@ def readme():
return f.read()


# Optional dependency groups.  "all" is the union of every other group so
# users can install the complete feature set with a single extra.
extras = dict(
    geography=["GeoAlchemy2", "shapely"],
    alembic=["alembic"],
    tests=["packaging", "pytz"],
)
extras["all"] = set(itertools.chain.from_iterable(extras.values()))

setup(
Expand Down Expand Up @@ -85,7 +89,7 @@ def readme():
],
extras_require=extras,
python_requires=">=3.6, <3.10",
tests_require=["pytz"],
tests_require=["packaging", "pytz"],
entry_points={
"sqlalchemy.dialects": ["bigquery = sqlalchemy_bigquery:BigQueryDialect"]
},
Expand Down
42 changes: 22 additions & 20 deletions sqlalchemy_bigquery/__init__.py
Expand Up @@ -24,40 +24,42 @@

from .base import BigQueryDialect, dialect # noqa
from .base import (
STRING,
ARRAY,
BIGNUMERIC,
BOOL,
BOOLEAN,
BYTES,
DATE,
DATETIME,
FLOAT,
FLOAT64,
INT64,
INTEGER,
FLOAT64,
FLOAT,
TIMESTAMP,
DATETIME,
DATE,
BYTES,
TIME,
RECORD,
NUMERIC,
BIGNUMERIC,
RECORD,
STRING,
TIME,
TIMESTAMP,
)

# Public names re-exported by this package, sorted alphabetically
# (case-insensitive) with no duplicates so the API surface is easy to scan.
__all__ = [
    "ARRAY",
    "BIGNUMERIC",
    "BigQueryDialect",
    "BOOL",
    "BOOLEAN",
    "BYTES",
    "DATE",
    "DATETIME",
    "FLOAT",
    "FLOAT64",
    "INT64",
    "INTEGER",
    "NUMERIC",
    "RECORD",
    "STRING",
    "TIME",
    "TIMESTAMP",
]

try:
Expand Down
101 changes: 81 additions & 20 deletions sqlalchemy_bigquery/base.py
Expand Up @@ -62,6 +62,8 @@

FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+")

TABLE_VALUED_ALIAS_ALIASES = "bigquery_table_valued_alias_aliases"


def assert_(cond, message="Assertion failed"): # pragma: NO COVER
if not cond:
Expand Down Expand Up @@ -114,39 +116,41 @@ def format_label(self, label, name=None):


# Mapping from BigQuery type names (as reported by the backend) to
# SQLAlchemy generic types.  Keys are unique and sorted alphabetically;
# BigQuery's legacy and standard spellings (e.g. BOOL/BOOLEAN,
# FLOAT/FLOAT64, INT64/INTEGER) map to the same SQLAlchemy type.
_type_map = {
    "ARRAY": types.ARRAY,
    "BIGNUMERIC": types.Numeric,
    "BOOL": types.Boolean,
    "BOOLEAN": types.Boolean,
    "BYTES": types.BINARY,
    "DATE": types.DATE,
    "DATETIME": types.DATETIME,
    "FLOAT": types.Float,
    "FLOAT64": types.Float,
    "INT64": types.Integer,
    "INTEGER": types.Integer,
    "NUMERIC": types.Numeric,
    "RECORD": types.JSON,
    "STRING": types.String,
    "TIME": types.TIME,
    "TIMESTAMP": types.TIMESTAMP,
}

# By convention, dialect-provided types are spelled with all upper case.
# Each constant is looked up under its own key in _type_map (BIGNUMERIC
# previously read _type_map["NUMERIC"]; both keys hold types.Numeric, so
# behavior is unchanged — the lookup is now just self-consistent).
ARRAY = _type_map["ARRAY"]
BIGNUMERIC = _type_map["BIGNUMERIC"]
BOOL = _type_map["BOOL"]
BOOLEAN = _type_map["BOOLEAN"]
BYTES = _type_map["BYTES"]
DATE = _type_map["DATE"]
DATETIME = _type_map["DATETIME"]
FLOAT = _type_map["FLOAT"]
FLOAT64 = _type_map["FLOAT64"]
INT64 = _type_map["INT64"]
INTEGER = _type_map["INTEGER"]
NUMERIC = _type_map["NUMERIC"]
RECORD = _type_map["RECORD"]
STRING = _type_map["STRING"]
TIME = _type_map["TIME"]
TIMESTAMP = _type_map["TIMESTAMP"]

try:
_type_map["GEOGRAPHY"] = GEOGRAPHY
Expand Down Expand Up @@ -246,6 +250,56 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
insert_stmt, asfrom=False, **kw
)

def visit_table_valued_alias(self, element, **kw):
# When using table-valued functions, like UNNEST, BigQuery requires a
# FROM for any table referenced in the function, including expressions
# in function arguments.
#
# For example, given SQLAlchemy code:
#
# print(
# select([func.unnest(foo.c.objects).alias('foo_objects').column])
# .compile(engine))
#
# Left to it's own devices, SQLAlchemy would outout:
#
# SELECT `foo_objects`
# FROM unnest(`foo`.`objects`) AS `foo_objects`
#
# But BigQuery diesn't understand the `foo` reference unless
# we add as reference to `foo` in the FROM:
#
# SELECT foo_objects
# FROM `foo`, UNNEST(`foo`.`objects`) as foo_objects
#
# This is tricky because:
# 1. We have to find the table references.
# 2. We can't know practically if there's already a FROM for a table.
#
# We leverage visit_column to find a table reference. Whenever we find
# one, we create an alias for it, so as not to conflict with an existing
# reference if one is present.
#
# This requires communicating between this function and visit_column.
# We do this by sticking a dictionary in the keyword arguments.
# This dictionary:
# a. Tells visit_column that it's an a table-valued alias expresssion, and
# b. Gives it a place to record the aliases it creates.
#
# This function creates aliases in the FROM list for any aliases recorded
# by visit_column.

kw[TABLE_VALUED_ALIAS_ALIASES] = {}
ret = super().visit_table_valued_alias(element, **kw)
aliases = kw.pop(TABLE_VALUED_ALIAS_ALIASES)
if aliases:
aliases = ", ".join(
f"{self.preparer.quote(tablename)} {self.preparer.quote(alias)}"
for tablename, alias in aliases.items()
)
ret = f"{aliases}, {ret}"
return ret

def visit_column(
self,
column,
Expand Down Expand Up @@ -281,6 +335,13 @@ def visit_column(
tablename = table.name
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
elif TABLE_VALUED_ALIAS_ALIASES in kwargs:
aliases = kwargs[TABLE_VALUED_ALIAS_ALIASES]
if tablename not in aliases:
aliases[tablename] = self.anon_map[
f"{TABLE_VALUED_ALIAS_ALIASES} {tablename}"
]
tablename = aliases[tablename]

return self.preparer.quote(tablename) + "." + name

Expand Down
Expand Up @@ -19,6 +19,7 @@

import datetime
import mock
import packaging.version
import pytest
import pytz
import sqlalchemy
Expand All @@ -41,7 +42,7 @@
)


if sqlalchemy.__version__ < "1.4":
if packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"):
from sqlalchemy.testing.suite import LimitOffsetTest as _LimitOffsetTest

class LimitOffsetTest(_LimitOffsetTest):
Expand Down
30 changes: 29 additions & 1 deletion tests/system/test_sqlalchemy_bigquery.py
Expand Up @@ -28,13 +28,13 @@
from sqlalchemy.sql import expression, select, literal_column
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import sessionmaker
import packaging.version
from pytz import timezone
import pytest
import sqlalchemy
import datetime
import decimal


ONE_ROW_CONTENTS_EXPANDED = [
588,
datetime.datetime(2013, 10, 10, 11, 27, 16, tzinfo=timezone("UTC")),
Expand Down Expand Up @@ -725,3 +725,31 @@ class MyTable(Base):
)

assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected


@pytest.mark.skipif(
    packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"),
    reason="unnest (and other table-valued-function) support required version 1.4",
)
def test_unnest(engine, bigquery_dataset):
    """UNNEST over a table column adds the referenced table to the FROM list.

    BigQuery requires any table referenced inside a table-valued function
    (such as UNNEST) to also appear in the FROM clause, so the compiler is
    expected to emit an aliased reference to the source table before the
    UNNEST expression.
    """
    from sqlalchemy import select, func, String
    from sqlalchemy_bigquery import ARRAY

    conn = engine.connect()
    metadata = MetaData()
    table = Table(
        f"{bigquery_dataset}.test_unnest", metadata, Column("objects", ARRAY(String)),
    )
    metadata.create_all(engine)
    conn.execute(
        table.insert(), [dict(objects=["a", "b", "c"]), dict(objects=["x", "y"])]
    )
    query = select([func.unnest(table.c.objects).alias("foo_objects").column])
    compiled = str(query.compile(engine))
    # Normalize whitespace before comparing: the compiler's line breaks and
    # indentation are not part of the contract.  (Only the last two pieces
    # interpolate values, so only they carry an f prefix.)
    assert " ".join(compiled.strip().split()) == (
        "SELECT `foo_objects`"
        " FROM"
        f" `{bigquery_dataset}.test_unnest` `{bigquery_dataset}.test_unnest_1`,"
        f" unnest(`{bigquery_dataset}.test_unnest_1`.`objects`) AS `foo_objects`"
    )
    assert sorted(r[0] for r in conn.execute(query)) == ["a", "b", "c", "x", "y"]
12 changes: 8 additions & 4 deletions tests/unit/conftest.py
Expand Up @@ -21,20 +21,24 @@
import mock
import sqlite3

import packaging.version
import pytest
import sqlalchemy

import fauxdbi

# Parsed SQLAlchemy version.  packaging.version handles pre-releases and
# multi-digit components correctly, which neither plain string comparison
# nor a tuple-of-ints split can be trusted to do.
sqlalchemy_version = packaging.version.parse(sqlalchemy.__version__)

# pytest skip markers keyed off the installed SQLAlchemy version.
sqlalchemy_1_3_or_higher = pytest.mark.skipif(
    sqlalchemy_version < packaging.version.parse("1.3"),
    reason="requires sqlalchemy 1.3 or higher",
)
sqlalchemy_1_4_or_higher = pytest.mark.skipif(
    sqlalchemy_version < packaging.version.parse("1.4"),
    reason="requires sqlalchemy 1.4 or higher",
)
sqlalchemy_before_1_4 = pytest.mark.skipif(
    sqlalchemy_version >= packaging.version.parse("1.4"),
    reason="requires sqlalchemy 1.3 or lower",
)


Expand Down

0 comments on commit 9b5b002

Please sign in to comment.