From b7b60007c966cd548448d1d6fd5a14d1f89480cd Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Fri, 21 May 2021 11:33:46 -0600 Subject: [PATCH] feat: Add support for SQLAlchemy 1.4 (#177) --- noxfile.py | 5 +- pybigquery/_helpers.py | 22 +++ pybigquery/parse_url.py | 2 +- pybigquery/requirements.py | 16 +- pybigquery/sqlalchemy_bigquery.py | 110 +++++++---- setup.py | 6 +- testing/constraints-3.6.txt | 2 +- testing/constraints-3.8.txt | 1 + testing/constraints-3.9.txt | 1 + tests/conftest.py | 5 + .../sqlalchemy_dialect_compliance/conftest.py | 19 +- .../test_dialect_compliance.py | 171 +++++++++++++++--- tests/unit/conftest.py | 6 + tests/unit/fauxdbi.py | 113 ++++++------ tests/unit/test_compliance.py | 17 +- tests/unit/test_helpers.py | 56 ++++++ tests/unit/test_parse_url.py | 6 +- tests/unit/test_select.py | 68 ++++++- tests/unit/test_sqlalchemy_bigquery.py | 21 +++ 19 files changed, 499 insertions(+), 148 deletions(-) diff --git a/noxfile.py b/noxfile.py index 75a550c4..3a0007ba 100644 --- a/noxfile.py +++ b/noxfile.py @@ -28,7 +28,9 @@ BLACK_PATHS = ["docs", "pybigquery", "tests", "noxfile.py", "setup.py"] DEFAULT_PYTHON_VERSION = "3.8" -SYSTEM_TEST_PYTHON_VERSIONS = ["3.9"] + +# We're using two Python versions to test with sqlalchemy 1.3 and 1.4. +SYSTEM_TEST_PYTHON_VERSIONS = ["3.8", "3.9"] UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() @@ -47,6 +49,7 @@ # Error if a python version is missing nox.options.error_on_missing_interpreters = True +nox.options.stop_on_first_error = True @nox.session(python=DEFAULT_PYTHON_VERSION) diff --git a/pybigquery/_helpers.py b/pybigquery/_helpers.py index fc48144c..a93e0c54 100644 --- a/pybigquery/_helpers.py +++ b/pybigquery/_helpers.py @@ -4,6 +4,9 @@ # license that can be found in the LICENSE file or at # https://opensource.org/licenses/MIT. +import functools +import re + from google.api_core import client_info import google.auth from google.cloud import bigquery @@ -58,3 +61,22 @@ def create_bigquery_client( location=location, default_query_job_config=default_query_job_config, ) + + +def substitute_re_method(r, flags=0, repl=None): + if repl is None: + return lambda f: substitute_re_method(r, flags, f) + + r = re.compile(r, flags) + + if isinstance(repl, str): + return lambda self, s: r.sub(repl, s) + + @functools.wraps(repl) + def sub(self, s, *args, **kw): + def repl_(m): + return repl(self, m, *args, **kw) + + return r.sub(repl_, s) + + return sub diff --git a/pybigquery/parse_url.py b/pybigquery/parse_url.py index 391ff2f1..13dda364 100644 --- a/pybigquery/parse_url.py +++ b/pybigquery/parse_url.py @@ -44,7 +44,7 @@ def parse_boolean(bool_string): def parse_url(url): # noqa: C901 - query = url.query + query = dict(url.query) # need mutable query. # use_legacy_sql (legacy) if "use_legacy_sql" in query: diff --git a/pybigquery/requirements.py b/pybigquery/requirements.py index 7621cdea..0be21a85 100644 --- a/pybigquery/requirements.py +++ b/pybigquery/requirements.py @@ -154,8 +154,14 @@ def comment_reflection(self): def unicode_ddl(self): """Target driver must support some degree of non-ascii symbol names. 
+ + However: + + Must contain only letters (a-z, A-Z), numbers (0-9), or underscores (_) + + https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#column_name_and_column_schema """ - return supported() + return unsupported() @property def datetime_literals(self): @@ -219,6 +225,14 @@ def order_by_label_with_expression(self): """ return supported() + @property + def sql_expression_limit_offset(self): + """target database can render LIMIT and/or OFFSET with a complete + SQL expression, such as one that uses the addition operator. + parameter + """ + return unsupported() + class WithSchemas(Requirements): """ diff --git a/pybigquery/sqlalchemy_bigquery.py b/pybigquery/sqlalchemy_bigquery.py index 764c3fc0..7ef2d725 100644 --- a/pybigquery/sqlalchemy_bigquery.py +++ b/pybigquery/sqlalchemy_bigquery.py @@ -34,6 +34,7 @@ from google.cloud.bigquery.table import TableReference from google.api_core.exceptions import NotFound +import sqlalchemy import sqlalchemy.sql.sqltypes import sqlalchemy.sql.type_api from sqlalchemy.exc import NoSuchTableError @@ -57,6 +58,11 @@ FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+") +def assert_(cond, message="Assertion failed"): # pragma: NO COVER + if not cond: + raise AssertionError(message) + + class BigQueryIdentifierPreparer(IdentifierPreparer): """ Set containing everything @@ -152,15 +158,25 @@ def get_insert_default(self, column): # pragma: NO COVER elif isinstance(column.type, String): return str(uuid.uuid4()) - def pre_exec( - self, - in_sub=re.compile( - r" IN UNNEST\(\[ " - r"(%\([^)]+_\d+\)s(?:, %\([^)]+_\d+\)s)*)?" # Placeholders. See below. - r":([A-Z0-9]+)" # Type - r" \]\)" - ).sub, - ): + __remove_type_from_empty_in = _helpers.substitute_re_method( + r" IN UNNEST\(\[ (" + r"(?:NULL|\(NULL(?:, NULL)+\))\)" + r" (?:AND|OR) \(1 !?= 1" + r")" + r"(?:[:][A-Z0-9]+)?" + r" \]\)", + re.IGNORECASE, + r" IN(\1)", + ) + + @_helpers.substitute_re_method( + r" IN UNNEST\(\[ " + r"(%\([^)]+_\d+\)s(?:, %\([^)]+_\d+\)s)*)?" # Placeholders. See below. + r":([A-Z0-9]+)" # Type + r" \]\)", + re.IGNORECASE, + ) + def __distribute_types_to_expanded_placeholders(self, m): # If we have an in parameter, it sometimes gets expaned to 0 or more # parameters and we need to move the type marker to each # parameter. @@ -171,29 +187,29 @@ def pre_exec( # suffixes refect that when an array parameter is expanded, # numeric suffixes are added. For example, a placeholder like # `%(foo)s` gets expaneded to `%(foo_0)s, `%(foo_1)s, ...`. 
+ placeholders, type_ = m.groups() + if placeholders: + placeholders = placeholders.replace(")", f":{type_})") + else: + placeholders = "" + return f" IN UNNEST([ {placeholders} ])" - def repl(m): - placeholders, type_ = m.groups() - if placeholders: - placeholders = placeholders.replace(")", f":{type_})") - else: - placeholders = "" - return f" IN UNNEST([ {placeholders} ])" - - self.statement = in_sub(repl, self.statement) + def pre_exec(self): + self.statement = self.__distribute_types_to_expanded_placeholders( + self.__remove_type_from_empty_in(self.statement) + ) class BigQueryCompiler(SQLCompiler): compound_keywords = SQLCompiler.compound_keywords.copy() - compound_keywords[selectable.CompoundSelect.UNION] = "UNION ALL" + compound_keywords[selectable.CompoundSelect.UNION] = "UNION DISTINCT" + compound_keywords[selectable.CompoundSelect.UNION_ALL] = "UNION ALL" - def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): + def __init__(self, dialect, statement, *args, **kwargs): if isinstance(statement, Column): kwargs["compile_kwargs"] = util.immutabledict({"include_table": False}) - super(BigQueryCompiler, self).__init__( - dialect, statement, column_keys, inline, **kwargs - ) + super(BigQueryCompiler, self).__init__(dialect, statement, *args, **kwargs) def visit_insert(self, insert_stmt, asfrom=False, **kw): # The (internal) documentation for `inline` is confusing, but @@ -260,24 +276,37 @@ def group_by_clause(self, select, **kw): # no way to tell sqlalchemy that, so it works harder than # necessary and makes us do the same. - _in_expanding_bind = re.compile(r" IN \((\[EXPANDING_\w+\](:[A-Z0-9]+)?)\)$") + __sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split("."))) - def _unnestify_in_expanding_bind(self, in_text): - return self._in_expanding_bind.sub(r" IN UNNEST([ \1 ])", in_text) + __expandng_text = ( + "EXPANDING" if __sqlalchemy_version_info < (1, 4) else "POSTCOMPILE" + ) + + __in_expanding_bind = _helpers.substitute_re_method( + fr" IN \((\[" fr"{__expandng_text}" fr"_[^\]]+\](:[A-Z0-9]+)?)\)$", + re.IGNORECASE, + r" IN UNNEST([ \1 ])", + ) def visit_in_op_binary(self, binary, operator_, **kw): - return self._unnestify_in_expanding_bind( + return self.__in_expanding_bind( self._generate_generic_binary(binary, " IN ", **kw) ) def visit_empty_set_expr(self, element_types): return "" - def visit_notin_op_binary(self, binary, operator, **kw): - return self._unnestify_in_expanding_bind( - self._generate_generic_binary(binary, " NOT IN ", **kw) + def visit_not_in_op_binary(self, binary, operator, **kw): + return ( + "(" + + self.__in_expanding_bind( + self._generate_generic_binary(binary, " NOT IN ", **kw) + ) + + ")" ) + visit_notin_op_binary = visit_not_in_op_binary # before 1.4 + ############################################################################ ############################################################################ @@ -327,6 +356,10 @@ def visit_notendswith_op_binary(self, binary, operator, **kw): ############################################################################ + __placeholder = re.compile(r"%\(([^\]:]+)(:[^\]:]+)?\)s$").match + + __expanded_param = re.compile(fr"\(\[" fr"{__expandng_text}" fr"_[^\]]+\]\)$").match + def visit_bindparam( self, bindparam, @@ -365,8 +398,20 @@ def visit_bindparam( # Values get arrayified at a lower level. 
bq_type = bq_type[6:-1] - assert param != "%s" - return param.replace(")", f":{bq_type})") + assert_(param != "%s", f"Unexpected param: {param}") + + if bindparam.expanding: + assert_(self.__expanded_param(param), f"Unexpected param: {param}") + param = param.replace(")", f":{bq_type})") + + else: + m = self.__placeholder(param) + if m: + name, type_ = m.groups() + assert_(type_ is None) + param = f"%({name}:{bq_type})s" + + return param class BigQueryTypeCompiler(GenericTypeCompiler): @@ -541,7 +586,6 @@ class BigQueryDialect(DefaultDialect): supports_unicode_statements = True supports_unicode_binds = True supports_native_decimal = True - returns_unicode_strings = True description_encoding = None supports_native_boolean = True supports_simple_order_by_label = True diff --git a/setup.py b/setup.py index a417129d..ac0bcb91 100644 --- a/setup.py +++ b/setup.py @@ -65,10 +65,10 @@ def readme(): ], platforms="Posix; MacOS X; Windows", install_requires=[ - "sqlalchemy>=1.2.0,<1.4.0dev", - "google-auth>=1.24.0,<2.0dev", # Work around pip wack. - "google-cloud-bigquery>=2.15.0", "google-api-core>=1.23.0", # Work-around bug in cloud core deps. + "google-auth>=1.24.0,<2.0dev", # Work around pip wack. + "google-cloud-bigquery>=2.16.1", + "sqlalchemy>=1.2.0,<1.5.0dev", "future", ], python_requires=">=3.6, <3.10", diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt index 5bc8ccf5..b975c9ea 100644 --- a/testing/constraints-3.6.txt +++ b/testing/constraints-3.6.txt @@ -6,5 +6,5 @@ # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev", sqlalchemy==1.2.0 google-auth==1.24.0 -google-cloud-bigquery==2.15.0 +google-cloud-bigquery==2.16.1 google-api-core==1.23.0 diff --git a/testing/constraints-3.8.txt b/testing/constraints-3.8.txt index e69de29b..4884f96a 100644 --- a/testing/constraints-3.8.txt +++ b/testing/constraints-3.8.txt @@ -0,0 +1 @@ +sqlalchemy==1.3.24 diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index e69de29b..eebb9da6 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -0,0 +1 @@ +sqlalchemy>=1.4.13 diff --git a/tests/conftest.py b/tests/conftest.py index 2a7dcc4c..3ddf981e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,3 +20,8 @@ from sqlalchemy.dialects import registry registry.register("bigquery", "pybigquery.sqlalchemy_bigquery", "BigQueryDialect") + +# sqlalchemy's dialect-testing machinery wants an entry like this. It is wack. :( +registry.register( + "bigquery.bigquery", "pybigquery.sqlalchemy_bigquery", "BigQueryDialect" +) diff --git a/tests/sqlalchemy_dialect_compliance/conftest.py b/tests/sqlalchemy_dialect_compliance/conftest.py index 47752dde..a0fa5e62 100644 --- a/tests/sqlalchemy_dialect_compliance/conftest.py +++ b/tests/sqlalchemy_dialect_compliance/conftest.py @@ -19,9 +19,9 @@ import contextlib import random +import re import traceback -import sqlalchemy from sqlalchemy.testing import config from sqlalchemy.testing.plugin.pytestplugin import * # noqa from sqlalchemy.testing.plugin.pytestplugin import ( @@ -35,6 +35,7 @@ pybigquery.sqlalchemy_bigquery.BigQueryDialect.preexecute_autoincrement_sequences = True google.cloud.bigquery.dbapi.connection.Connection.rollback = lambda self: None +_where = re.compile(r"\s+WHERE\s+", re.IGNORECASE).search # BigQuery requires delete statements to have where clauses. Other # databases don't and sqlalchemy doesn't include where clauses when @@ -42,16 +43,20 @@ # where clause when tearing down tests. 
We only do this during tear # down, by inspecting the stack, because we don't want to hide bugs # outside of test house-keeping. -def visit_delete(self, delete_stmt, *args, **kw): - if delete_stmt._whereclause is None and "teardown" in set( - f.name for f in traceback.extract_stack() - ): - delete_stmt._whereclause = sqlalchemy.true() - return super(pybigquery.sqlalchemy_bigquery.BigQueryCompiler, self).visit_delete( + +def visit_delete(self, delete_stmt, *args, **kw): + text = super(pybigquery.sqlalchemy_bigquery.BigQueryCompiler, self).visit_delete( delete_stmt, *args, **kw ) + if not _where(text) and any( + "teardown" in f.name.lower() for f in traceback.extract_stack() + ): + text += " WHERE true" + + return text + pybigquery.sqlalchemy_bigquery.BigQueryCompiler.visit_delete = visit_delete diff --git a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py index 259a78ec..e03e0b22 100644 --- a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py +++ b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py @@ -21,7 +21,11 @@ import mock import pytest import pytz +import sqlalchemy from sqlalchemy import and_ + +import sqlalchemy.testing.suite.test_types +from sqlalchemy.testing import util from sqlalchemy.testing.assertions import eq_ from sqlalchemy.testing.suite import config, select, exists from sqlalchemy.testing.suite import * # noqa @@ -30,21 +34,154 @@ CTETest as _CTETest, ExistsTest as _ExistsTest, InsertBehaviorTest as _InsertBehaviorTest, - LimitOffsetTest as _LimitOffsetTest, LongNameBlowoutTest, QuotedNameArgumentTest, SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest, TimestampMicrosecondsTest as _TimestampMicrosecondsTest, ) + +if sqlalchemy.__version__ < "1.4": + from sqlalchemy.testing.suite import LimitOffsetTest as _LimitOffsetTest + + class LimitOffsetTest(_LimitOffsetTest): + @pytest.mark.skip("BigQuery doesn't allow an offset without a limit.") + def test_simple_offset(self): + pass + + test_bound_offset = test_simple_offset + + class TimestampMicrosecondsTest(_TimestampMicrosecondsTest): + + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396, tzinfo=pytz.UTC) + + def test_literal(self): + # The base tests doesn't set up the literal properly, because + # it doesn't pass its datatype to `literal`. + + def literal(value): + assert value == self.data + import sqlalchemy.sql.sqltypes + + return sqlalchemy.sql.elements.literal(value, self.datatype) + + with mock.patch("sqlalchemy.testing.suite.test_types.literal", literal): + super(TimestampMicrosecondsTest, self).test_literal() + + +else: + from sqlalchemy.testing.suite import ( + ComponentReflectionTestExtra as _ComponentReflectionTestExtra, + FetchLimitOffsetTest as _FetchLimitOffsetTest, + RowCountTest as _RowCountTest, + ) + + class FetchLimitOffsetTest(_FetchLimitOffsetTest): + @pytest.mark.skip("BigQuery doesn't allow an offset without a limit.") + def test_simple_offset(self): + pass + + test_bound_offset = test_simple_offset + test_expr_offset = test_simple_offset_zero = test_simple_offset + + # The original test is missing an order by. + + # Also, note that sqlalchemy union is a union distinct, not a + # union all. This test caught that were were getting that wrong. 
+ def test_limit_render_multiple_times(self, connection): + table = self.tables.some_table + stmt = select(table.c.id).order_by(table.c.id).limit(1).scalar_subquery() + + u = sqlalchemy.union(select(stmt), select(stmt)).subquery().select() + + self._assert_result( + connection, u, [(1,)], + ) + + del DifficultParametersTest # exercises column names illegal in BQ + del DistinctOnTest # expects unquoted table names. + del HasIndexTest # BQ doesn't do the indexes that SQLA is loooking for. + del IdentityAutoincrementTest # BQ doesn't do autoincrement + + # This test makes makes assertions about generated sql and trips + # over the backquotes that we add everywhere. XXX Why do we do that? + del PostCompileParamsTest + + class ComponentReflectionTestExtra(_ComponentReflectionTestExtra): + @pytest.mark.skip("BQ types don't have parameters like precision and length") + def test_numeric_reflection(self): + pass + + test_varchar_reflection = test_numeric_reflection + + class TimestampMicrosecondsTest(_TimestampMicrosecondsTest): + + data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396, tzinfo=pytz.UTC) + + def test_literal(self, literal_round_trip): + # The base tests doesn't set up the literal properly, because + # it doesn't pass its datatype to `literal`. + + def literal(value): + assert value == self.data + import sqlalchemy.sql.sqltypes + + return sqlalchemy.sql.elements.literal(value, self.datatype) + + with mock.patch("sqlalchemy.testing.suite.test_types.literal", literal): + super(TimestampMicrosecondsTest, self).test_literal(literal_round_trip) + + def test_round_trip_executemany(self, connection): + unicode_table = self.tables.unicode_table + connection.execute( + unicode_table.insert(), + [{"id": i, "unicode_data": self.data} for i in range(3)], + ) + + rows = connection.execute(select(unicode_table.c.unicode_data)).fetchall() + eq_(rows, [(self.data,) for i in range(3)]) + for row in rows: + assert isinstance(row[0], util.text_type) + + sqlalchemy.testing.suite.test_types._UnicodeFixture.test_round_trip_executemany = ( + test_round_trip_executemany + ) + + class RowCountTest(_RowCountTest): + @classmethod + def insert_data(cls, connection): + cls.data = data = [ + ("Angela", "A"), + ("Andrew", "A"), + ("Anand", "A"), + ("Bob", "B"), + ("Bobette", "B"), + ("Buffy", "B"), + ("Charlie", "C"), + ("Cynthia", "C"), + ("Chris", "C"), + ] + + employees_table = cls.tables.employees + connection.execute( + employees_table.insert(), + [ + {"employee_id": i, "name": n, "department": d} + for i, (n, d) in enumerate(data) + ], + ) + + # Quotes aren't allowed in BigQuery table names. del QuotedNameArgumentTest class InsertBehaviorTest(_InsertBehaviorTest): - @pytest.mark.skip() + @pytest.mark.skip( + "BQ has no autoinc and client-side defaults can't work for select." + ) def test_insert_from_select_autoinc(cls): - """BQ has no autoinc and client-side defaults can't work for select.""" + pass class ExistsTest(_ExistsTest): @@ -76,14 +213,6 @@ def test_select_exists_false(self, connection): ) -class LimitOffsetTest(_LimitOffsetTest): - @pytest.mark.skip() - def test_simple_offset(self): - """BigQuery doesn't allow an offset without a limit.""" - - test_bound_offset = test_simple_offset - - # This test requires features (indexes, primary keys, etc., that BigQuery doesn't have. 
del LongNameBlowoutTest @@ -130,20 +259,6 @@ def course_grained_types(): test_numeric_reflection = test_varchar_reflection = course_grained_types - -class TimestampMicrosecondsTest(_TimestampMicrosecondsTest): - - data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396, tzinfo=pytz.UTC) - - def test_literal(self): - # The base tests doesn't set up the literal properly, because - # it doesn't pass its datatype to `literal`. - - def literal(value): - assert value == self.data - import sqlalchemy.sql.sqltypes - - return sqlalchemy.sql.elements.literal(value, self.datatype) - - with mock.patch("sqlalchemy.testing.suite.test_types.literal", literal): - super(TimestampMicrosecondsTest, self).test_literal() + @pytest.mark.skip("BQ doesn't have indexes (in the way these tests expect).") + def test_get_indexes(self): + pass diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 801e84a9..aa23fe22 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -30,6 +30,12 @@ sqlalchemy_1_3_or_higher = pytest.mark.skipif( sqlalchemy_version_info < (1, 3), reason="requires sqlalchemy 1.3 or higher" ) +sqlalchemy_1_4_or_higher = pytest.mark.skipif( + sqlalchemy_version_info < (1, 4), reason="requires sqlalchemy 1.4 or higher" +) +sqlalchemy_before_1_4 = pytest.mark.skipif( + sqlalchemy_version_info >= (1, 4), reason="requires sqlalchemy 1.3 or lower" +) @pytest.fixture() diff --git a/tests/unit/fauxdbi.py b/tests/unit/fauxdbi.py index 70cbb8aa..56652cd5 100644 --- a/tests/unit/fauxdbi.py +++ b/tests/unit/fauxdbi.py @@ -30,6 +30,8 @@ import google.cloud.bigquery.table import google.cloud.bigquery.dbapi.cursor +from pybigquery._helpers import substitute_re_method + class Connection: def __init__(self, connection, test_data, client, *args, **kw): @@ -85,23 +87,18 @@ def arraysize(self, v): datetime.time, ) - def __convert_params( - self, - operation, - parameters, - placeholder=re.compile(r"%\((\w+)\)s", re.IGNORECASE), - ): - ordered_parameters = [] - - def repl(m): - name = m.group(1) - value = parameters[name] - if isinstance(value, self._need_to_be_pickled): - value = pickle.dumps(value, 4).decode("latin1") - ordered_parameters.append(value) - return "?" + @substitute_re_method(r"%\((\w+)\)s", re.IGNORECASE) + def __pyformat_to_qmark(self, m, parameters, ordered_parameters): + name = m.group(1) + value = parameters[name] + if isinstance(value, self._need_to_be_pickled): + value = pickle.dumps(value, 4).decode("latin1") + ordered_parameters.append(value) + return "?" 
- operation = placeholder.sub(repl, operation) + def __convert_params(self, operation, parameters): + ordered_parameters = [] + operation = self.__pyformat_to_qmark(operation, parameters, ordered_parameters) return operation, ordered_parameters def __update_comment(self, table, col, comment): @@ -113,6 +110,23 @@ def __update_comment(self, table, col, comment): r"\s*create\s+table\s+`(?P\w+)`", re.IGNORECASE ).match + @substitute_re_method( + r"(?P`(?P\w+)`\s+\w+|\))" r"\s+options\((?P[^)]+)\)", + re.IGNORECASE, + ) + def __handle_column_comments(self, m, table_name): + col = m.group("col") or "" + options = { + name.strip().lower(): value.strip() + for name, value in (o.split("=") for o in m.group("options").split(",")) + } + + comment = options.get("description") + if comment: + self.__update_comment(table_name, col, comment) + + return m.group("prefix") + def __handle_comments( self, operation, @@ -121,31 +135,10 @@ def __handle_comments( r"SET\s+OPTIONS\(description=(?P[^)]+)\)", re.IGNORECASE, ).match, - options=re.compile( - r"(?P`(?P\w+)`\s+\w+|\))" r"\s+options\((?P[^)]+)\)", - re.IGNORECASE, - ), ): m = self.__create_table(operation) if m: - table_name = m.group("table") - - def repl(m): - col = m.group("col") or "" - options = { - name.strip().lower(): value.strip() - for name, value in ( - o.split("=") for o in m.group("options").split(",") - ) - } - - comment = options.get("description") - if comment: - self.__update_comment(table_name, col, comment) - - return m.group("prefix") - - return options.sub(repl, operation) + return self.__handle_column_comments(operation, m.group("table")) m = alter_table(operation) if m: @@ -156,19 +149,17 @@ def repl(m): return operation + @substitute_re_method( + r"(?<=[(,])" r"\s*`\w+`\s+\w+<\w+>\s*" r"(?=[,)])", re.IGNORECASE + ) + def __normalize_array_types(self, m): + return m.group(0).replace("<", "_").replace(">", "_") + def __handle_array_types( - self, - operation, - array_type=re.compile( - r"(?<=[(,])" r"\s*`\w+`\s+\w+<\w+>\s*" r"(?=[,)])", re.IGNORECASE - ), + self, operation, ): if self.__create_table(operation): - - def repl(m): - return m.group(0).replace("<", "_").replace(">", "_") - - return array_type.sub(repl, operation) + return self.__normalize_array_types(operation) else: return operation @@ -195,18 +186,20 @@ def __parse_dateish(type_, value): else: raise AssertionError(type_) # pragma: NO COVER + __normalize_bq_datish = substitute_re_method( + r"(?<=[[(,])\s*" + r"(?Pdate(?:time)?|time(?:stamp)?) (?P'[^']+')" + r"\s*(?=[]),])", + re.IGNORECASE, + r"parse_datish('\1', \2)", + ) + def __handle_problematic_literal_inserts( self, operation, literal_insert_values=re.compile( r"\s*(insert\s+into\s+.+\s+values\s*)" r"(\([^)]+\))" r"\s*$", re.IGNORECASE ).match, - bq_dateish=re.compile( - r"(?<=[[(,])\s*" - r"(?Pdate(?:time)?|time(?:stamp)?) (?P'[^']+')" - r"\s*(?=[]),])", - re.IGNORECASE, - ), need_to_be_pickled_literal=_need_to_be_pickled + (bytes,), ): if "?" in operation: @@ -222,7 +215,7 @@ def __handle_problematic_literal_inserts( } } - values = bq_dateish.sub(r"parse_datish('\1', \2)", values) + values = self.__normalize_bq_datish(values) values = eval(values[:-1] + ",)", safe_globals) values = ",".join( map( @@ -241,10 +234,9 @@ def __handle_problematic_literal_inserts( else: return operation - def __handle_unnest( - self, operation, unnest=re.compile(r"UNNEST\(\[ ([^\]]+)? \]\)", re.IGNORECASE), - ): - return unnest.sub(r"(\1)", operation) + __handle_unnest = substitute_re_method( + r"UNNEST\(\[ ([^\]]+)? 
\]\)", re.IGNORECASE, r"(\1)" + ) def __handle_true_false(self, operation): # Older sqlite versions, like those used on the CI servers @@ -264,6 +256,7 @@ def execute(self, operation, parameters=()): operation = self.__handle_problematic_literal_inserts(operation) operation = self.__handle_unnest(operation) operation = self.__handle_true_false(operation) + operation = operation.replace(" UNION DISTINCT ", " UNION ") if operation: try: @@ -306,7 +299,7 @@ def fetchone(self): return self._fix_pickled(self.cursor.fetchone()) def fetchall(self): - return map(self._fix_pickled, self.cursor) + return list(map(self._fix_pickled, self.cursor)) class attrdict(dict): diff --git a/tests/unit/test_compliance.py b/tests/unit/test_compliance.py index da2390f6..cbf40cfc 100644 --- a/tests/unit/test_compliance.py +++ b/tests/unit/test_compliance.py @@ -30,8 +30,8 @@ from conftest import setup_table, sqlalchemy_1_3_or_higher -def assert_result(connection, sel, expected): - eq_(connection.execute(sel).fetchall(), expected) +def assert_result(connection, sel, expected, params=()): + eq_(connection.execute(sel, params).fetchall(), expected) def some_table(connection): @@ -108,6 +108,19 @@ def test_percent_sign_round_trip(faux_conn, metadata): ) +@sqlalchemy_1_3_or_higher +def test_empty_set_against_integer(faux_conn): + table = some_table(faux_conn) + + stmt = ( + select([table.c.id]) + .where(table.c.x.in_(sqlalchemy.bindparam("q", expanding=True))) + .order_by(table.c.id) + ) + + assert_result(faux_conn, stmt, [], params={"q": []}) + + @sqlalchemy_1_3_or_higher def test_null_in_empty_set_is_false(faux_conn): stmt = select( diff --git a/tests/unit/test_helpers.py b/tests/unit/test_helpers.py index 1a3acc85..93965318 100644 --- a/tests/unit/test_helpers.py +++ b/tests/unit/test_helpers.py @@ -136,3 +136,59 @@ def mock_default_credentials(*args, **kwargs): ) assert bqclient.project == "connection-url-project" + + +def test_substitute_re_string(module_under_test): + import re + + foo_to_baz = module_under_test.substitute_re_method("foo", re.IGNORECASE, "baz") + assert ( + foo_to_baz(object(), "some foo and FOO is good") == "some baz and baz is good" + ) + + +def test_substitute_re_func(module_under_test): + import re + + @module_under_test.substitute_re_method("foo", re.IGNORECASE) + def Foo_to_bar(self, m): + return "bar" + + assert ( + Foo_to_bar(object(), "some foo and FOO is good") == "some bar and bar is good" + ) + + @module_under_test.substitute_re_method("foo") + def foo_to_bar(self, m, x="bar"): + return x + + assert ( + foo_to_bar(object(), "some foo and FOO is good") == "some bar and FOO is good" + ) + + assert ( + foo_to_bar(object(), "some foo and FOO is good", "hah") + == "some hah and FOO is good" + ) + + assert ( + foo_to_bar(object(), "some foo and FOO is good", x="hah") + == "some hah and FOO is good" + ) + + assert foo_to_bar.__name__ == "foo_to_bar" + + +def test_substitute_re_func_self(module_under_test): + class Replacer: + def __init__(self, x): + self.x = x + + @module_under_test.substitute_re_method("foo") + def foo_to_bar(self, m): + return self.x + + assert ( + Replacer("hah").foo_to_bar("some foo and FOO is good") + == "some hah and FOO is good" + ) diff --git a/tests/unit/test_parse_url.py b/tests/unit/test_parse_url.py index bf9f8855..3da0546d 100644 --- a/tests/unit/test_parse_url.py +++ b/tests/unit/test_parse_url.py @@ -114,8 +114,10 @@ def test_basic(url_with_everything): ], ) def test_all_values(url_with_everything, param, value, default): - url_with_this_one = 
make_url("bigquery://some-project/some-dataset") - url_with_this_one.query[param] = url_with_everything.query[param] + url_with_this_one = make_url( + f"bigquery://some-project/some-dataset" + f"?{param}={url_with_everything.query[param]}" + ) for url in url_with_everything, url_with_this_one: job_config = parse_url(url)[5] diff --git a/tests/unit/test_select.py b/tests/unit/test_select.py index 9cfb5b8b..45872a81 100644 --- a/tests/unit/test_select.py +++ b/tests/unit/test_select.py @@ -25,7 +25,12 @@ import pybigquery.sqlalchemy_bigquery -from conftest import setup_table, sqlalchemy_1_3_or_higher +from conftest import ( + setup_table, + sqlalchemy_1_3_or_higher, + sqlalchemy_1_4_or_higher, + sqlalchemy_before_1_4, +) def test_labels_not_forced(faux_conn): @@ -203,7 +208,20 @@ def test_disable_quote(faux_conn): assert faux_conn.test_data["execute"][-1][0] == ("SELECT `t`.foo \nFROM `t`") -def test_select_in_lit(faux_conn): +def _normalize_in_params(query, params): + # We have to normalize parameter names, because they + # change with sqlalchemy versions. + newnames = sorted( + ((p, f"p_{i}") for i, p in enumerate(sorted(params))), key=lambda i: -len(i[0]) + ) + for old, new in newnames: + query = query.replace(old, new) + + return query, {new: params[old] for old, new in newnames} + + +@sqlalchemy_before_1_4 +def test_select_in_lit_13(faux_conn): [[isin]] = faux_conn.execute( sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])]) ) @@ -215,6 +233,19 @@ def test_select_in_lit(faux_conn): ) +@sqlalchemy_1_4_or_higher +def test_select_in_lit(faux_conn): + [[isin]] = faux_conn.execute( + sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])]) + ) + assert isin + assert _normalize_in_params(*faux_conn.test_data["execute"][-1]) == ( + "SELECT %(p_0:INT64)s IN " + "UNNEST([ %(p_1:INT64)s, %(p_2:INT64)s, %(p_3:INT64)s ]) AS `anon_1`", + {"p_1": 1, "p_2": 2, "p_3": 3, "p_0": 1}, + ) + + def test_select_in_param(faux_conn): [[isin]] = faux_conn.execute( sqlalchemy.select( @@ -255,23 +286,40 @@ def test_select_in_param_empty(faux_conn): ) assert not isin assert faux_conn.test_data["execute"][-1] == ( - "SELECT %(param_1:INT64)s IN UNNEST(" "[ ]" ") AS `anon_1`", + "SELECT %(param_1:INT64)s IN(NULL) AND (1 != 1) AS `anon_1`" + if sqlalchemy.__version__ >= "1.4" + else "SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`", {"param_1": 1}, ) -def test_select_notin_lit(faux_conn): +@sqlalchemy_before_1_4 +def test_select_notin_lit13(faux_conn): [[isnotin]] = faux_conn.execute( sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])]) ) assert isnotin assert faux_conn.test_data["execute"][-1] == ( - "SELECT %(param_1:INT64)s NOT IN " - "(%(param_2:INT64)s, %(param_3:INT64)s, %(param_4:INT64)s) AS `anon_1`", + "SELECT (%(param_1:INT64)s NOT IN " + "(%(param_2:INT64)s, %(param_3:INT64)s, %(param_4:INT64)s)) AS `anon_1`", {"param_1": 0, "param_2": 1, "param_3": 2, "param_4": 3}, ) +@sqlalchemy_1_4_or_higher +def test_select_notin_lit(faux_conn): + [[isnotin]] = faux_conn.execute( + sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])]) + ) + assert isnotin + + assert _normalize_in_params(*faux_conn.test_data["execute"][-1]) == ( + "SELECT (%(p_0:INT64)s NOT IN " + "UNNEST([ %(p_1:INT64)s, %(p_2:INT64)s, %(p_3:INT64)s ])) AS `anon_1`", + {"p_0": 0, "p_1": 1, "p_2": 2, "p_3": 3}, + ) + + def test_select_notin_param(faux_conn): [[isnotin]] = faux_conn.execute( sqlalchemy.select( @@ -281,9 +329,9 @@ def test_select_notin_param(faux_conn): ) assert not isnotin assert 
faux_conn.test_data["execute"][-1] == ( - "SELECT %(param_1:INT64)s NOT IN UNNEST(" + "SELECT (%(param_1:INT64)s NOT IN UNNEST(" "[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]" - ") AS `anon_1`", + ")) AS `anon_1`", {"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3}, ) @@ -298,6 +346,8 @@ def test_select_notin_param_empty(faux_conn): ) assert isnotin assert faux_conn.test_data["execute"][-1] == ( - "SELECT %(param_1:INT64)s NOT IN UNNEST(" "[ ]" ") AS `anon_1`", + "SELECT (%(param_1:INT64)s NOT IN(NULL) OR (1 = 1)) AS `anon_1`" + if sqlalchemy.__version__ >= "1.4" + else "SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`", {"param_1": 1}, ) diff --git a/tests/unit/test_sqlalchemy_bigquery.py b/tests/unit/test_sqlalchemy_bigquery.py index dc65d513..2cad9c82 100644 --- a/tests/unit/test_sqlalchemy_bigquery.py +++ b/tests/unit/test_sqlalchemy_bigquery.py @@ -137,3 +137,24 @@ def test_get_view_names( mock_bigquery_client.list_datasets.assert_called_once() assert mock_bigquery_client.list_tables.call_count == len(datasets_list) assert list(sorted(view_names)) == list(sorted(expected)) + + +@pytest.mark.parametrize( + "inp, outp", + [ + ("(NULL IN UNNEST([ NULL) AND (1 != 1 ]))", "(NULL IN(NULL) AND (1 != 1))"), + ( + "(NULL IN UNNEST([ NULL) AND (1 != 1:INT64 ]))", + "(NULL IN(NULL) AND (1 != 1))", + ), + ( + "(NULL IN UNNEST([ (NULL, NULL)) AND (1 != 1:INT64 ]))", + "(NULL IN((NULL, NULL)) AND (1 != 1))", + ), + ], +) +def test__remove_type_from_empty_in(inp, outp): + from pybigquery.sqlalchemy_bigquery import BigQueryExecutionContext + + r = BigQueryExecutionContext._BigQueryExecutionContext__remove_type_from_empty_in + assert r(None, inp) == outp
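
A note on the new substitute_re_method helper added in pybigquery/_helpers.py above: it is the workhorse behind the compiler and fauxdbi rewrites in this patch. Given a regex and either a replacement string or a replacement function, it builds a method that performs the substitution, passing any extra call-site arguments through to the replacement function. A minimal usage sketch (assuming pybigquery with this patch applied is importable; the Replacer class and its method names are illustrative only, loosely mirroring tests/unit/test_helpers.py and tests/unit/fauxdbi.py):

    from pybigquery._helpers import substitute_re_method


    class Replacer:
        # Replacement given as a string: the generated method simply runs
        # re.sub() with that string over its argument.
        strip_backticks = substitute_re_method(r"`", repl="")

        # Replacement given as a function: the decorated method is called with
        # each match object plus any extra arguments passed at the call site.
        @substitute_re_method(r"%\((\w+)\)s")
        def pyformat_to_qmark(self, m, parameters, ordered):
            ordered.append(parameters[m.group(1)])
            return "?"


    r = Replacer()
    print(r.strip_backticks("SELECT `t`.`x` FROM `t`"))  # SELECT t.x FROM t

    ordered = []
    sql = r.pyformat_to_qmark("x = %(x)s AND y = %(y)s", {"x": 1, "y": 2}, ordered)
    print(sql, ordered)  # x = ? AND y = ?  [1, 2]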
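
One quick way to see the IN / NOT IN rewriting done by BigQueryCompiler in this patch is to compile a small statement against the dialect directly. The exact rendering differs between SQLAlchemy 1.3 and 1.4 (EXPANDING vs. POSTCOMPILE placeholders, and the empty-set handling covered by the updated tests in tests/unit/test_select.py), so this sketch, which assumes pybigquery and sqlalchemy are installed, only prints the compiled text rather than asserting it:

    import sqlalchemy
    from pybigquery.sqlalchemy_bigquery import BigQueryDialect

    dialect = BigQueryDialect()

    # Non-empty IN: rewritten toward IN UNNEST([ ... ]) with type-annotated params.
    stmt = sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])])
    print(stmt.compile(dialect=dialect))

    # NOT IN: parenthesized by visit_not_in_op_binary / visit_notin_op_binary.
    stmt = sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])])
    print(stmt.compile(dialect=dialect))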