diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index e4f86e7b..e03d074e 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -35,6 +35,8 @@ from google.api_core.exceptions import NotFound import sqlalchemy +import sqlalchemy.sql.expression +import sqlalchemy.sql.functions import sqlalchemy.sql.sqltypes import sqlalchemy.sql.type_api from sqlalchemy.exc import NoSuchTableError @@ -1092,6 +1094,21 @@ def get_view_definition(self, connection, view_name, schema=None, **kw): return view.view_query +class unnest(sqlalchemy.sql.functions.GenericFunction): + def __init__(self, *args, **kwargs): + expr = kwargs.pop("expr", None) + if expr is not None: + args = (expr,) + args + if len(args) != 1: + raise TypeError("The unnest function requires a single argument.") + arg = args[0] + if isinstance(arg, sqlalchemy.sql.expression.ColumnElement): + if not isinstance(arg.type, sqlalchemy.sql.sqltypes.ARRAY): + raise TypeError("The argument to unnest must have an ARRAY type.") + self.type = arg.type.item_type + super().__init__(*args, **kwargs) + + dialect = BigQueryDialect try: diff --git a/tests/unit/test_sqlalchemy_bigquery.py b/tests/unit/test_sqlalchemy_bigquery.py index a4c81367..75cbec42 100644 --- a/tests/unit/test_sqlalchemy_bigquery.py +++ b/tests/unit/test_sqlalchemy_bigquery.py @@ -10,6 +10,7 @@ from google.cloud import bigquery from google.cloud.bigquery.dataset import DatasetListItem from google.cloud.bigquery.table import TableListItem +import packaging.version import pytest import sqlalchemy @@ -178,3 +179,53 @@ def test_follow_dialect_attribute_convention(): assert sqlalchemy_bigquery.dialect is sqlalchemy_bigquery.BigQueryDialect assert sqlalchemy_bigquery.base.dialect is sqlalchemy_bigquery.BigQueryDialect + + +@pytest.mark.parametrize( + "args,kw,error", + [ + ((), {}, "The unnest function requires a single argument."), + ((1, 1), {}, "The unnest function requires a single argument."), + ((1,), {"expr": 1}, "The unnest function requires a single argument."), + ((1, 1), {"expr": 1}, "The unnest function requires a single argument."), + ( + (), + {"expr": sqlalchemy.Column("x", sqlalchemy.String)}, + "The argument to unnest must have an ARRAY type.", + ), + ( + (sqlalchemy.Column("x", sqlalchemy.String),), + {}, + "The argument to unnest must have an ARRAY type.", + ), + ], +) +def test_unnest_function_errors(args, kw, error): + # Make sure the unnest function is registered with SQLAlchemy, which + # happens when sqlalchemy_bigquery is imported. + import sqlalchemy_bigquery # noqa + + with pytest.raises(TypeError, match=error): + sqlalchemy.func.unnest(*args, **kw) + + +@pytest.mark.parametrize( + "args,kw", + [ + ((), {"expr": sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String))}), + ((sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String)),), {}), + ], +) +def test_unnest_function(args, kw): + # Make sure the unnest function is registered with SQLAlchemy, which + # happens when sqlalchemy_bigquery is imported. + import sqlalchemy_bigquery # noqa + + f = sqlalchemy.func.unnest(*args, **kw) + assert isinstance(f.type, sqlalchemy.String) + if packaging.version.parse(sqlalchemy.__version__) >= packaging.version.parse( + "1.4" + ): + assert isinstance( + sqlalchemy.select([f]).subquery().c.unnest.type, sqlalchemy.String + )