diff --git a/sqlalchemy_bigquery/base.py b/sqlalchemy_bigquery/base.py index 98edfb9e..e4f86e7b 100644 --- a/sqlalchemy_bigquery/base.py +++ b/sqlalchemy_bigquery/base.py @@ -483,6 +483,39 @@ def visit_bindparam( skip_bind_expression=False, **kwargs, ): + type_ = bindparam.type + unnest = False + if ( + bindparam.expanding + and not isinstance(type_, NullType) + and not literal_binds + ): + # Normally, when performing an IN operation, like: + # + # foo IN (some_sequence) + # + # SQLAlchemy passes `foo` as a parameter and unpacks + # `some_sequence` and passes each element as a parameter. + # This mechanism is referred to as "expanding". It's + # inefficient and can't handle large arrays. (It's also + # very complicated, but that's not the issue we care about + # here. :) ) BigQuery lets us use arrays directly in this + # context; we just need to call UNNEST on an array when + # it's used in IN. + # + # So, if we get an `expanding` flag, and if we have a known type + # (and don't have literal binds, which are implemented in-line in + # the SQL), we turn off expanding and we set an unnest flag + # so that we add an UNNEST() call (below). + # + # The NullType/known-type check has to do with some extreme + # edge cases having to do with empty in-lists that get special + # hijinks from SQLAlchemy that we don't want to disturb. :) + if getattr(bindparam, "expand_op", None) is not None: + assert bindparam.expand_op.__name__.endswith("in_op") # IN or NOT IN + bindparam.expanding = False + unnest = True + param = super(BigQueryCompiler, self).visit_bindparam( bindparam, within_columns_clause, @@ -491,7 +524,6 @@ def visit_bindparam( **kwargs, ) - type_ = bindparam.type if literal_binds or isinstance(type_, NullType): return param @@ -512,7 +544,6 @@ def visit_bindparam( if bq_type[-1] == ">" and bq_type.startswith("ARRAY<"): # Values get arrayified at a lower level. 
bq_type = bq_type[6:-1] - bq_type = self.__remove_type_parameter(bq_type) assert_(param != "%s", f"Unexpected param: {param}") @@ -528,6 +559,9 @@ def visit_bindparam( assert_(type_ is None) param = f"%({name}:{bq_type})s" + if unnest: + param = f"UNNEST({param})" + return param diff --git a/tests/system/test_sqlalchemy_bigquery.py b/tests/system/test_sqlalchemy_bigquery.py index 63dc220b..d8622020 100644 --- a/tests/system/test_sqlalchemy_bigquery.py +++ b/tests/system/test_sqlalchemy_bigquery.py @@ -727,6 +727,27 @@ class MyTable(Base): assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected +@pytest.mark.skipif( + packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"), + reason="requires sqlalchemy 1.4 or higher", +) +def test_huge_in(): + engine = sqlalchemy.create_engine("bigquery://") + conn = engine.connect() + try: + assert list( + conn.execute( + sqlalchemy.select([sqlalchemy.literal(-1).in_(list(range(99999)))]) + ) + ) == [(False,)] + except Exception: + error = True + else: + error = False + + assert not error, "execution failed" + + @pytest.mark.skipif( packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"), reason="unnest (and other table-valued-function) support required version 1.4", diff --git a/tests/unit/fauxdbi.py b/tests/unit/fauxdbi.py index 56c44e0f..631996af 100644 --- a/tests/unit/fauxdbi.py +++ b/tests/unit/fauxdbi.py @@ -261,11 +261,20 @@ def __handle_problematic_literal_inserts( else: return operation - __handle_unnest = substitute_string_re_method( - r"UNNEST\(\[ ([^\]]+)? \]\)", # UNNEST([ ... ]) - flags=re.IGNORECASE, - repl=r"(\1)", + @substitute_re_method( + r""" + UNNEST\( + ( + \[ (?P<exp>[^\]]+)? \] # UNNEST([ ... ]) + | + ([?]) # UNNEST(?) 
+ ) + \) + """, + flags=re.IGNORECASE | re.VERBOSE, ) + def __handle_unnest(self, m): + return "(" + (m.group("exp") or "?") + ")" def __handle_true_false(self, operation): # Older sqlite versions, like those used on the CI servers diff --git a/tests/unit/test_select.py b/tests/unit/test_select.py index 10669864..474fc9d9 100644 --- a/tests/unit/test_select.py +++ b/tests/unit/test_select.py @@ -28,6 +28,7 @@ from conftest import ( setup_table, + sqlalchemy_version, sqlalchemy_1_3_or_higher, sqlalchemy_1_4_or_higher, sqlalchemy_before_1_4, @@ -214,18 +215,6 @@ def test_disable_quote(faux_conn): assert faux_conn.test_data["execute"][-1][0] == ("SELECT `t`.foo \nFROM `t`") -def _normalize_in_params(query, params): - # We have to normalize parameter names, because they - # change with sqlalchemy versions. - newnames = sorted( - ((p, f"p_{i}") for i, p in enumerate(sorted(params))), key=lambda i: -len(i[0]) - ) - for old, new in newnames: - query = query.replace(old, new) - - return query, {new: params[old] for old, new in newnames} - - @sqlalchemy_before_1_4 def test_select_in_lit_13(faux_conn): [[isin]] = faux_conn.execute( @@ -240,66 +229,74 @@ def test_select_in_lit_13(faux_conn): @sqlalchemy_1_4_or_higher -def test_select_in_lit(faux_conn): - [[isin]] = faux_conn.execute( - sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])]) - ) - assert isin - assert _normalize_in_params(*faux_conn.test_data["execute"][-1]) == ( - "SELECT %(p_0:INT64)s IN " - "UNNEST([ %(p_1:INT64)s, %(p_2:INT64)s, %(p_3:INT64)s ]) AS `anon_1`", - {"p_1": 1, "p_2": 2, "p_3": 3, "p_0": 1}, +def test_select_in_lit(faux_conn, last_query): + faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])])) + last_query( + "SELECT %(param_1:INT64)s IN UNNEST(%(param_2:INT64)s) AS `anon_1`", + {"param_1": 1, "param_2": [1, 2, 3]}, ) -def test_select_in_param(faux_conn): +def test_select_in_param(faux_conn, last_query): [[isin]] = faux_conn.execute( sqlalchemy.select( 
[sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] ), dict(q=[1, 2, 3]), ) - assert isin - assert faux_conn.test_data["execute"][-1] == ( - "SELECT %(param_1:INT64)s IN UNNEST(" - "[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]" - ") AS `anon_1`", - {"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3}, - ) + if sqlalchemy_version >= packaging.version.parse("1.4"): + last_query( + "SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`", + {"param_1": 1, "q": [1, 2, 3]}, + ) + else: + assert isin + last_query( + "SELECT %(param_1:INT64)s IN UNNEST(" + "[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]" + ") AS `anon_1`", + {"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3}, + ) -def test_select_in_param1(faux_conn): +def test_select_in_param1(faux_conn, last_query): [[isin]] = faux_conn.execute( sqlalchemy.select( [sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] ), dict(q=[1]), ) - assert isin - assert faux_conn.test_data["execute"][-1] == ( - "SELECT %(param_1:INT64)s IN UNNEST(" "[ %(q_1:INT64)s ]" ") AS `anon_1`", - {"param_1": 1, "q_1": 1}, - ) + if sqlalchemy_version >= packaging.version.parse("1.4"): + last_query( + "SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`", + {"param_1": 1, "q": [1]}, + ) + else: + assert isin + last_query( + "SELECT %(param_1:INT64)s IN UNNEST(" "[ %(q_1:INT64)s ]" ") AS `anon_1`", + {"param_1": 1, "q_1": 1}, + ) @sqlalchemy_1_3_or_higher -def test_select_in_param_empty(faux_conn): +def test_select_in_param_empty(faux_conn, last_query): [[isin]] = faux_conn.execute( sqlalchemy.select( [sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] ), dict(q=[]), ) - assert not isin - assert faux_conn.test_data["execute"][-1] == ( - "SELECT %(param_1:INT64)s IN(NULL) AND (1 != 1) AS `anon_1`" - if ( - packaging.version.parse(sqlalchemy.__version__) - >= packaging.version.parse("1.4") + if sqlalchemy_version >= packaging.version.parse("1.4"): + last_query( + "SELECT %(param_1:INT64)s 
IN UNNEST(%(q:INT64)s) AS `anon_1`", + {"param_1": 1, "q": []}, + ) + else: + assert not isin + last_query( + "SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`", {"param_1": 1} ) - else "SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`", - {"param_1": 1}, - ) @sqlalchemy_before_1_4 @@ -316,53 +313,54 @@ def test_select_notin_lit13(faux_conn): @sqlalchemy_1_4_or_higher -def test_select_notin_lit(faux_conn): - [[isnotin]] = faux_conn.execute( - sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])]) +def test_select_notin_lit(faux_conn, last_query): + faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])])) + last_query( + "SELECT (%(param_1:INT64)s NOT IN UNNEST(%(param_2:INT64)s)) AS `anon_1`", + {"param_1": 0, "param_2": [1, 2, 3]}, ) - assert isnotin - assert _normalize_in_params(*faux_conn.test_data["execute"][-1]) == ( - "SELECT (%(p_0:INT64)s NOT IN " - "UNNEST([ %(p_1:INT64)s, %(p_2:INT64)s, %(p_3:INT64)s ])) AS `anon_1`", - {"p_0": 0, "p_1": 1, "p_2": 2, "p_3": 3}, - ) - -def test_select_notin_param(faux_conn): +def test_select_notin_param(faux_conn, last_query): [[isnotin]] = faux_conn.execute( sqlalchemy.select( [sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))] ), dict(q=[1, 2, 3]), ) - assert not isnotin - assert faux_conn.test_data["execute"][-1] == ( - "SELECT (%(param_1:INT64)s NOT IN UNNEST(" - "[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]" - ")) AS `anon_1`", - {"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3}, - ) + if sqlalchemy_version >= packaging.version.parse("1.4"): + last_query( + "SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`", + {"param_1": 1, "q": [1, 2, 3]}, + ) + else: + assert not isnotin + last_query( + "SELECT (%(param_1:INT64)s NOT IN UNNEST(" + "[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]" + ")) AS `anon_1`", + {"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3}, + ) @sqlalchemy_1_3_or_higher -def test_select_notin_param_empty(faux_conn): +def 
test_select_notin_param_empty(faux_conn, last_query): [[isnotin]] = faux_conn.execute( sqlalchemy.select( [sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))] ), dict(q=[]), ) - assert isnotin - assert faux_conn.test_data["execute"][-1] == ( - "SELECT (%(param_1:INT64)s NOT IN(NULL) OR (1 = 1)) AS `anon_1`" - if ( - packaging.version.parse(sqlalchemy.__version__) - >= packaging.version.parse("1.4") + if sqlalchemy_version >= packaging.version.parse("1.4"): + last_query( + "SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`", + {"param_1": 1, "q": []}, + ) + else: + assert isnotin + last_query( + "SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`", {"param_1": 1} ) - else "SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`", - {"param_1": 1}, - ) def test_literal_binds_kwarg_with_an_IN_operator_252(faux_conn):