Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: unnest failed in some cases (with table references failed when there were no other references to referenced tables in a query) #290

Merged
merged 11 commits into from Aug 25, 2021
42 changes: 22 additions & 20 deletions sqlalchemy_bigquery/__init__.py
Expand Up @@ -24,40 +24,42 @@

from .base import BigQueryDialect, dialect # noqa
from .base import (
STRING,
ARRAY,
BIGNUMERIC,
BOOL,
BOOLEAN,
BYTES,
DATE,
DATETIME,
FLOAT,
FLOAT64,
INT64,
INTEGER,
FLOAT64,
FLOAT,
TIMESTAMP,
DATETIME,
DATE,
BYTES,
TIME,
RECORD,
NUMERIC,
BIGNUMERIC,
RECORD,
STRING,
TIME,
TIMESTAMP,
)

__all__ = [
"ARRAY",
"BIGNUMERIC",
"BigQueryDialect",
"STRING",
"BOOL",
"BOOLEAN",
"BYTES",
"DATE",
"DATETIME",
"FLOAT",
"FLOAT64",
"INT64",
"INTEGER",
"FLOAT64",
"FLOAT",
"TIMESTAMP",
"DATETIME",
"DATE",
"BYTES",
"TIME",
"RECORD",
"NUMERIC",
"BIGNUMERIC",
"RECORD",
"STRING",
"TIME",
"TIMESTAMP",
]

try:
Expand Down
101 changes: 81 additions & 20 deletions sqlalchemy_bigquery/base.py
Expand Up @@ -62,6 +62,8 @@

FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+")

TABLE_VALUED_ALIAS_ALIASES = "bigquery_table_valued_alias_aliases"


def assert_(cond, message="Assertion failed"): # pragma: NO COVER
if not cond:
Expand Down Expand Up @@ -114,39 +116,41 @@ def format_label(self, label, name=None):


_type_map = {
"STRING": types.String,
"BOOL": types.Boolean,
"ARRAY": types.ARRAY,
"BIGNUMERIC": types.Numeric,
"BOOLEAN": types.Boolean,
"INT64": types.Integer,
"INTEGER": types.Integer,
"BOOL": types.Boolean,
"BYTES": types.BINARY,
"DATETIME": types.DATETIME,
"DATE": types.DATE,
"FLOAT64": types.Float,
"FLOAT": types.Float,
"INT64": types.Integer,
"INTEGER": types.Integer,
"NUMERIC": types.Numeric,
"RECORD": types.JSON,
"STRING": types.String,
"TIMESTAMP": types.TIMESTAMP,
"DATETIME": types.DATETIME,
"DATE": types.DATE,
"BYTES": types.BINARY,
"TIME": types.TIME,
"RECORD": types.JSON,
"NUMERIC": types.Numeric,
"BIGNUMERIC": types.Numeric,
}

# By convention, dialect-provided types are spelled with all upper case.
STRING = _type_map["STRING"]
BOOL = _type_map["BOOL"]
ARRAY = _type_map["ARRAY"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the new array type count as a feature? Perhaps it'd look better in the changelog as a separate PR?

If it is a feature and it doesn't make sense to split out, we could also try the multi-change commit message feature when we squash and merge this: https://github.com/googleapis/release-please#what-if-my-pr-contains-multiple-fixes-or-features

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a new feature. It was just missing from the types list.

SQLAlchemy has basic array support that works with BigQuery (thanks to some previous effort on our part).

BIGNUMERIC = _type_map["NUMERIC"]
BOOLEAN = _type_map["BOOLEAN"]
INT64 = _type_map["INT64"]
INTEGER = _type_map["INTEGER"]
BOOL = _type_map["BOOL"]
BYTES = _type_map["BYTES"]
DATETIME = _type_map["DATETIME"]
DATE = _type_map["DATE"]
FLOAT64 = _type_map["FLOAT64"]
FLOAT = _type_map["FLOAT"]
INT64 = _type_map["INT64"]
INTEGER = _type_map["INTEGER"]
NUMERIC = _type_map["NUMERIC"]
RECORD = _type_map["RECORD"]
STRING = _type_map["STRING"]
TIMESTAMP = _type_map["TIMESTAMP"]
DATETIME = _type_map["DATETIME"]
DATE = _type_map["DATE"]
BYTES = _type_map["BYTES"]
TIME = _type_map["TIME"]
RECORD = _type_map["RECORD"]
NUMERIC = _type_map["NUMERIC"]
BIGNUMERIC = _type_map["NUMERIC"]

try:
_type_map["GEOGRAPHY"] = GEOGRAPHY
Expand Down Expand Up @@ -246,6 +250,56 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
insert_stmt, asfrom=False, **kw
)

def visit_table_valued_alias(self, element, **kw):
    """Compile a table-valued alias (e.g. ``unnest(...) AS alias``), adding
    FROM-list references for any tables used inside the function arguments.

    BigQuery requires that every table referenced by a table-valued function,
    including expressions in function arguments, also appear in the FROM list.
    """
    # When using table-valued functions, like UNNEST, BigQuery requires a
    # FROM for any table referenced in the function, including expressions
    # in function arguments.
    #
    # For example, given SQLAlchemy code:
    #
    #     print(
    #         select([func.unnest(foo.c.objects).alias('foo_objects').column])
    #         .compile(engine))
    #
    # Left to its own devices, SQLAlchemy would output:
    #
    #     SELECT `foo_objects`
    #     FROM unnest(`foo`.`objects`) AS `foo_objects`
    #
    # But BigQuery doesn't understand the `foo` reference unless
    # we add a reference to `foo` in the FROM:
    #
    #     SELECT foo_objects
    #     FROM `foo`, UNNEST(`foo`.`objects`) as foo_objects
    #
    # This is tricky because:
    # 1. We have to find the table references.
    # 2. We can't know practically if there's already a FROM for a table.
    #
    # We leverage visit_column to find a table reference. Whenever we find
    # one, we create an alias for it, so as not to conflict with an existing
    # reference if one is present.
    #
    # This requires communicating between this function and visit_column.
    # We do this by sticking a dictionary in the keyword arguments.
    # This dictionary:
    # a. Tells visit_column that it's in a table-valued alias expression, and
    # b. Gives it a place to record the aliases it creates.
    #
    # This function creates aliases in the FROM list for any aliases recorded
    # by visit_column.

    kw[TABLE_VALUED_ALIAS_ALIASES] = {}
    ret = super().visit_table_valued_alias(element, **kw)
    aliases = kw.pop(TABLE_VALUED_ALIAS_ALIASES)
    if aliases:
        # Prefix the compiled alias expression with `table` `alias` pairs so
        # each referenced table gets its own FROM entry.
        aliases = ", ".join(
            f"{self.preparer.quote(tablename)} {self.preparer.quote(alias)}"
            for tablename, alias in aliases.items()
        )
        ret = f"{aliases}, {ret}"
    return ret

def visit_column(
self,
column,
Expand Down Expand Up @@ -281,6 +335,13 @@ def visit_column(
tablename = table.name
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
elif TABLE_VALUED_ALIAS_ALIASES in kwargs:
aliases = kwargs[TABLE_VALUED_ALIAS_ALIASES]
if tablename not in aliases:
aliases[tablename] = self.anon_map[
f"{TABLE_VALUED_ALIAS_ALIASES} {tablename}"
]
tablename = aliases[tablename]

return self.preparer.quote(tablename) + "." + name

Expand Down
29 changes: 29 additions & 0 deletions tests/system/test_sqlalchemy_bigquery.py
Expand Up @@ -34,6 +34,7 @@
import datetime
import decimal

sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split(".")))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works, though I've been trying to use https://packaging.pypa.io/en/latest/version.html which is the canonical version parser. We already have it pulled in via a transitive dependency through setuptools, I believe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. I'll switch to that.

WRT how we get it, IMO, anything we import should be a direct dependency.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. I'm fine explicitly including it in our dependencies / test dependencies.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


ONE_ROW_CONTENTS_EXPANDED = [
588,
Expand Down Expand Up @@ -725,3 +726,31 @@ class MyTable(Base):
)

assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected


@pytest.mark.skipif(
    sqlalchemy_version_info < (1, 4),
    reason="unnest (and other table-valued-function) support required version 1.4",
)
def test_unnest(engine, bigquery_dataset):
    """UNNEST of an ARRAY column compiles with the source table in FROM and
    returns the flattened values."""
    from sqlalchemy import select, func, String
    from sqlalchemy_bigquery import ARRAY

    connection = engine.connect()
    metadata = MetaData()
    unnest_table = Table(
        f"{bigquery_dataset}.test_unnest", metadata, Column("objects", ARRAY(String)),
    )
    metadata.create_all(engine)
    rows = [dict(objects=["a", "b", "c"]), dict(objects=["x", "y"])]
    connection.execute(unnest_table.insert(), rows)

    stmt = select([func.unnest(unnest_table.c.objects).alias("foo_objects").column])
    sql = str(stmt.compile(engine))
    normalized = " ".join(sql.strip().split())
    expected = (
        f"SELECT `foo_objects`"
        f" FROM"
        f" `{bigquery_dataset}.test_unnest` `{bigquery_dataset}.test_unnest_1`,"
        f" unnest(`{bigquery_dataset}.test_unnest_1`.`objects`) AS `foo_objects`"
    )
    assert normalized == expected
    assert sorted(value for (value,) in connection.execute(stmt)) == [
        "a",
        "b",
        "c",
        "x",
        "y",
    ]
51 changes: 51 additions & 0 deletions tests/unit/test_select.py
Expand Up @@ -374,3 +374,54 @@ def nstr(q):
nstr(q.compile(faux_conn.engine, compile_kwargs={"literal_binds": True}))
== "SELECT `test`.`val` FROM `test` WHERE `test`.`val` IN (2)"
)


@sqlalchemy_1_4_or_higher
@pytest.mark.parametrize("alias", [True, False])
def test_unnest(faux_conn, alias):
    """unnest() over an ARRAY column adds an aliased FROM entry for the table."""
    from sqlalchemy import String
    from sqlalchemy_bigquery import ARRAY

    source = setup_table(faux_conn, "t", sqlalchemy.Column("objects", ARRAY(String)))
    unnest_call = sqlalchemy.func.unnest(source.c.objects)
    selectable = (
        unnest_call.alias("foo_objects").column
        if alias
        else unnest_call.column_valued("foo_objects")
    )
    sql = str(sqlalchemy.select(selectable).compile(faux_conn.engine))
    assert " ".join(sql.strip().split()) == (
        "SELECT `foo_objects` FROM `t` `t_1`, unnest(`t_1`.`objects`) AS `foo_objects`"
    )


@sqlalchemy_1_4_or_higher
@pytest.mark.parametrize("alias", [True, False])
def test_table_valued_alias_w_multiple_references_to_the_same_table(faux_conn, alias):
    """Repeated references to one table inside a table-valued function share a
    single FROM alias."""
    from sqlalchemy import String
    from sqlalchemy_bigquery import ARRAY

    source = setup_table(faux_conn, "t", sqlalchemy.Column("objects", ARRAY(String)))
    func_call = sqlalchemy.func.foo(source.c.objects, source.c.objects)
    selectable = (
        func_call.alias("foo_objects").column
        if alias
        else func_call.column_valued("foo_objects")
    )
    sql = str(sqlalchemy.select(selectable).compile(faux_conn.engine))
    assert " ".join(sql.strip().split()) == (
        "SELECT `foo_objects` "
        "FROM `t` `t_1`, foo(`t_1`.`objects`, `t_1`.`objects`) AS `foo_objects`"
    )


@sqlalchemy_1_4_or_higher
@pytest.mark.parametrize("alias", [True, False])
def test_unnest_w_no_table_references(faux_conn, alias):
    """unnest() of a literal array adds no extra FROM entries."""
    func_call = sqlalchemy.func.unnest([1, 2, 3])
    selectable = func_call.alias().column if alias else func_call.column_valued()
    sql = str(sqlalchemy.select(selectable).compile(faux_conn.engine))
    assert " ".join(sql.strip().split()) == (
        "SELECT `anon_1` FROM unnest(%(unnest_1)s) AS `anon_1`"
    )