Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: the unnest function lost needed type information #298

Merged
merged 17 commits into from Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 17 additions & 0 deletions sqlalchemy_bigquery/base.py
Expand Up @@ -35,6 +35,8 @@
from google.api_core.exceptions import NotFound

import sqlalchemy
import sqlalchemy.sql.expression
import sqlalchemy.sql.functions
import sqlalchemy.sql.sqltypes
import sqlalchemy.sql.type_api
from sqlalchemy.exc import NoSuchTableError
Expand Down Expand Up @@ -1092,6 +1094,21 @@ def get_view_definition(self, connection, view_name, schema=None, **kw):
return view.view_query


class unnest(sqlalchemy.sql.functions.GenericFunction):
def __init__(self, *args, **kwargs):
expr = kwargs.pop("expr", None)
if expr is not None:
args = (expr,) + args
if len(args) != 1:
raise TypeError("The unnest function requires a single argument.")
arg = args[0]
if isinstance(arg, sqlalchemy.sql.expression.ColumnElement):
if not isinstance(arg.type, sqlalchemy.sql.sqltypes.ARRAY):
raise TypeError("The argument to unnest must have an ARRAY type.")
self.type = arg.type.item_type
super().__init__(*args, **kwargs)


dialect = BigQueryDialect

try:
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/test_sqlalchemy_bigquery.py
Expand Up @@ -10,6 +10,7 @@
from google.cloud import bigquery
from google.cloud.bigquery.dataset import DatasetListItem
from google.cloud.bigquery.table import TableListItem
import packaging.version
import pytest
import sqlalchemy

Expand Down Expand Up @@ -178,3 +179,49 @@ def test_follow_dialect_attribute_convention():

assert sqlalchemy_bigquery.dialect is sqlalchemy_bigquery.BigQueryDialect
assert sqlalchemy_bigquery.base.dialect is sqlalchemy_bigquery.BigQueryDialect


@pytest.mark.parametrize(
"args,kw,error",
[
((), {}, "The unnest function requires a single argument."),
((1, 1), {}, "The unnest function requires a single argument."),
((1,), {"expr": 1}, "The unnest function requires a single argument."),
((1, 1), {"expr": 1}, "The unnest function requires a single argument."),
(
(),
{"expr": sqlalchemy.Column("x", sqlalchemy.String)},
"The argument to unnest must have an ARRAY type.",
),
(
(sqlalchemy.Column("x", sqlalchemy.String),),
{},
"The argument to unnest must have an ARRAY type.",
),
],
)
def test_unnest_function_errors(args, kw, error):
import sqlalchemy_bigquery # noqa

with pytest.raises(TypeError, match=error):
sqlalchemy.func.unnest(*args, **kw)


@pytest.mark.parametrize(
"args,kw",
[
((), {"expr": sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String))}),
((sqlalchemy.Column("x", sqlalchemy.ARRAY(sqlalchemy.String)),), {}),
],
)
def test_unnest_function(args, kw):
import sqlalchemy_bigquery # noqa
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this import to get background registrations done?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup. I'll add a comment to that effect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


f = sqlalchemy.func.unnest(*args, **kw)
assert isinstance(f.type, sqlalchemy.String)
if packaging.version.parse(sqlalchemy.__version__) >= packaging.version.parse(
"1.4"
):
assert isinstance(
sqlalchemy.select([f]).subquery().c.unnest.type, sqlalchemy.String
)