From 123318269876e7f76c7f0f2daa5f5b365026cd3f Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Thu, 26 Aug 2021 16:22:52 -0600 Subject: [PATCH] fix: the unnest function lost needed type information (#298) --- sqlalchemy_bigquery/base.py | 17 +++++++++ tests/unit/test_sqlalchemy_bigquery.py | 51 ++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index e4f86e7b..e03d074e 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -35,6 +35,8 @@ from google.api_core.exceptions import NotFound import sqlalchemy +import sqlalchemy.sql.expression +import sqlalchemy.sql.functions import sqlalchemy.sql.sqltypes import sqlalchemy.sql.type_api from sqlalchemy.exc import NoSuchTableError @@ -1092,6 +1094,21 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): return view.view_query +class unnest(sqlalchemy.sql.functions.GenericFunction): + def __init__(self, *args, **kwargs): + expr = kwargs.pop("expr", None) + if expr is not None: + args = (expr,) + args + if len(args) != 1: + raise TypeError("The unnest function requires a single argument.") + arg = args[0] + if isinstance(arg, sqlalchemy.sql.expression.ColumnElement): + if not isinstance(arg.type, sqlalchemy.sql.sqltypes.ARRAY): + raise TypeError("The argument to unnest must have an ARRAY type.") + self.type = arg.type.item_type + super().__init__(*args, **kwargs) + + dialect = BigQueryDialect try: diff --git a/tests/unit/test_sqlalchemy_bigquery.py b/tests/unit/test_sqlalchemy_bigquery.py index a4c81367..75cbec42 100644 --- a/tests/unit/test_sqlalchemy_bigquery.py +++ b/tests/unit/test_sqlalchemy_bigquery.py @@ -10,6 +10,7 @@ from google.cloud import bigquery from google.cloud.bigquery.dataset import DatasetListItem from google.cloud.bigquery.table import TableListItem +import packaging.version import pytest import sqlalchemy @@ -178,3 +179,53 @@ def test_follow_dialect_attribute_convention(): assert sqlalchemy_bigquery.dialect is sqlalchemy_bigquery.BigQueryDialect assert sqlalchemy_bigquery.base.dialect is sqlalchemy_bigquery.BigQueryDialect + + +@pytest.mark.parametrize( + "args,kw,error", + [ + ((), {}, "The unnest function requires a single argument."), + ((1, 1), {}, "The unnest function requires a single argument."), + ((1,), {"expr": 1}, "The unnest function requires a single argument."), + ((1, 1), {"expr": 1}, "The unnest function requires a single argument."), + ( + (), + {"expr": sqlalchemy.Column("x", sqlalchemy.String)}, + "The argument to unnest must have an ARRAY type.", + ), + ( + (sqlalchemy.Column("x", sqlalchemy.String),), + {}, + "The argument to unnest must have an ARRAY type.", + ), + ], +) +def test_unnest_function_errors(args, kw, error): + # Make sure the unnest function is registered with SQLAlchemy, which + # happens when sqlalchemy_bigquery is imported. + import sqlalchemy_bigquery # noqa + + with pytest.raises(TypeError, match=error): + sqlalchemy.func.unnest(*args, **kw) + + +@pytest.mark.parametrize( + "args,kw", + [ + ((), {"expr": sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String))}), + ((sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String)),), {}), + ], +) +def test_unnest_function(args, kw): + # Make sure the unnest function is registered with SQLAlchemy, which + # happens when sqlalchemy_bigquery is imported. + import sqlalchemy_bigquery # noqa + + f = sqlalchemy.func.unnest(*args, **kw) + assert isinstance(f.type, sqlalchemy.String) + if packaging.version.parse(sqlalchemy.__version__) >= packaging.version.parse( + "1.4" + ): + assert isinstance( + sqlalchemy.select([f]).subquery().c.unnest.type, sqlalchemy.String + )