From fb7c188fd1d61f2bb2b99742f62042576bff02a9 Mon Sep 17 00:00:00 2001 From: Jim Fulton Date: Wed, 12 May 2021 06:51:02 -0600 Subject: [PATCH] feat: Comment/description support, bug fixes and better test coverage (#138) - Runs SQLAlchemy dialect-compliance tests (as system tests). - 100% unit-test coverage. - Support for table and column comments/descriptions (requiring SQLAlchemy 1.2 or higher). - Fixes bugs found while debugging tests, including: - Handling of `in` queries. - String literals with special characters. - Use BIGNUMERIC when necessary. - Missing types: BIGINT, SMALLINT, Boolean, REAL, CHAR, NCHAR, VARCHAR, NVARCHAR, TEXT, VARBINARY, DECIMAL - Literal bytes, dates, times, datetimes, timestamps, and arrays. - Get view definitions. - When executing parameterized queries, the new BigQuery DB API parameter syntax is used to pass type information. This is helpful when the DB API can't determine type information from values, or can't determine it correctly. --- .coveragerc | 4 +- noxfile.py | 61 ++- pybigquery/requirements.py | 220 +++++++++ pybigquery/sqlalchemy_bigquery.py | 382 ++++++++++++--- setup.cfg | 9 +- setup.py | 8 +- testing/constraints-3.6.txt | 7 +- .../sqlalchemy_dialect_compliance/README.rst | 12 + .../sqlalchemy_dialect_compliance/conftest.py | 68 +++ .../test_dialect_compliance.py | 149 ++++++ tests/unit/conftest.py | 77 ++- tests/unit/fauxdbi.py | 444 +++++++++++++++--- tests/unit/test_api.py | 37 ++ tests/unit/test_catalog_functions.py | 259 ++++++++++ tests/unit/test_comments.py | 100 ++++ tests/unit/test_compiler.py | 55 +++ tests/unit/test_compliance.py | 189 ++++++++ tests/unit/test_engine.py | 54 +++ tests/unit/test_like_reescape.py | 43 ++ tests/unit/test_parse_url.py | 60 ++- tests/unit/test_select.py | 302 +++++++++++- tests/unit/test_view.py | 35 ++ 22 files changed, 2414 insertions(+), 161 deletions(-) create mode 100644 pybigquery/requirements.py create mode 100644 tests/sqlalchemy_dialect_compliance/README.rst create mode 100644 tests/sqlalchemy_dialect_compliance/conftest.py create mode 100644 tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py create mode 100644 tests/unit/test_api.py create mode 100644 tests/unit/test_catalog_functions.py create mode 100644 tests/unit/test_comments.py create mode 100644 tests/unit/test_compiler.py create mode 100644 tests/unit/test_compliance.py create mode 100644 tests/unit/test_engine.py create mode 100644 tests/unit/test_like_reescape.py create mode 100644 tests/unit/test_view.py diff --git a/.coveragerc b/.coveragerc index 0d8e6297..d5e3f1dc 100644 --- a/.coveragerc +++ b/.coveragerc @@ -17,8 +17,6 @@ # Generated by synthtool. DO NOT EDIT! 
[run]
 branch = True
-omit =
-  google/cloud/__init__.py
 
 [report]
 fail_under = 100
@@ -35,4 +33,4 @@ omit =
   */proto/*.py
   */core/*.py
   */site-packages/*.py
-  google/cloud/__init__.py
+  pybigquery/requirements.py
diff --git a/noxfile.py b/noxfile.py
index 3ccaff8a..ec7c1e7e 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -28,17 +28,18 @@
 BLACK_PATHS = ["docs", "pybigquery", "tests", "noxfile.py", "setup.py"]
 
 DEFAULT_PYTHON_VERSION = "3.8"
-SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"]
+SYSTEM_TEST_PYTHON_VERSIONS = ["3.9"]
 UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]
 
 CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
 
 # 'docfx' is excluded since it only needs to run in 'docs-presubmit'
 nox.options.sessions = [
+    "lint",
     "unit",
-    "system",
     "cover",
-    "lint",
+    "system",
+    "compliance",
     "lint_setup_py",
     "blacken",
     "docs",
@@ -169,6 +170,58 @@ def system(session):
     )
 
 
+@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
+def compliance(session):
+    """Run the SQLAlchemy dialect-compliance test suite."""
+    constraints_path = str(
+        CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
+    )
+    system_test_folder_path = os.path.join("tests", "sqlalchemy_dialect_compliance")
+
+    # Check the value of the `RUN_COMPLIANCE_TESTS` env var. It defaults to true.
+    if os.environ.get("RUN_COMPLIANCE_TESTS", "true") == "false":
+        session.skip("RUN_COMPLIANCE_TESTS is set to false, skipping")
+    # Sanity check: Only run tests if the credentials environment variable is set.
+    if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""):
+        session.skip("Credentials must be set via environment variable")
+    # Install pyopenssl for mTLS testing.
+    if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true":
+        session.install("pyopenssl")
+    # Sanity check: only run tests if found.
+    if not os.path.exists(system_test_folder_path):
+        session.skip("Compliance tests were not found")
+
+    # Use pre-release gRPC for compliance tests.
+    session.install("--pre", "grpcio")
+
+    # Install all test dependencies, then install this package into the
+    # virtualenv's dist-packages.
+    session.install(
+        "mock",
+        "pytest",
+        "pytest-rerunfailures",
+        "google-cloud-testutils",
+        "-c",
+        constraints_path,
+    )
+    session.install("-e", ".", "-c", constraints_path)
+
+    session.run(
+        "py.test",
+        "-vv",
+        f"--junitxml=compliance_{session.python}_sponge_log.xml",
+        "--reruns=3",
+        "--reruns-delay=60",
+        "--only-rerun="
+        "403 Exceeded rate limits|"
+        "409 Already Exists|"
+        "404 Not found|"
+        "400 Cannot execute DML over a non-existent table",
+        system_test_folder_path,
+        *session.posargs,
+    )
+
+
 @nox.session(python=DEFAULT_PYTHON_VERSION)
 def cover(session):
     """Run the final coverage report.
@@ -177,7 +230,7 @@ def cover(session):
     test runs (not system test runs), and then erases coverage data.
""" session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=50") + session.run("coverage", "report", "--show-missing", "--fail-under=100") session.run("coverage", "erase") diff --git a/pybigquery/requirements.py b/pybigquery/requirements.py new file mode 100644 index 00000000..77726faf --- /dev/null +++ b/pybigquery/requirements.py @@ -0,0 +1,220 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" +This module is used by the compliance tests to control which tests are run + +based on database capabilities. +""" + +import sqlalchemy.testing.requirements +import sqlalchemy.testing.exclusions + +supported = sqlalchemy.testing.exclusions.open +unsupported = sqlalchemy.testing.exclusions.closed + + +class Requirements(sqlalchemy.testing.requirements.SuiteRequirements): + @property + def index_reflection(self): + return unsupported() + + @property + def indexes_with_ascdesc(self): + """target database supports CREATE INDEX with per-column ASC/DESC.""" + return unsupported() + + @property + def unique_constraint_reflection(self): + """target dialect supports reflection of unique constraints""" + return unsupported() + + @property + def autoincrement_insert(self): + """target platform generates new surrogate integer primary key values + when insert() is executed, excluding the pk column.""" + return unsupported() + + @property + def primary_key_constraint_reflection(self): + return unsupported() + + @property + def foreign_keys(self): + """Target database must support foreign keys.""" + + return unsupported() + + @property + def foreign_key_constraint_reflection(self): + return unsupported() + + @property + def on_update_cascade(self): + """target database must support ON UPDATE..CASCADE behavior in + foreign keys.""" + + return unsupported() + + @property + def named_constraints(self): + """target database must support names for constraints.""" + + return unsupported() + + @property + def temp_table_reflection(self): + return unsupported() + + @property + def temporary_tables(self): + """target database supports temporary tables""" + return unsupported() # Temporary tables require use of scripts. + + @property + def duplicate_key_raises_integrity_error(self): + """target dialect raises IntegrityError when reporting an INSERT + with a primary key violation. 
(hint: it should) + + """ + return unsupported() + + @property + def precision_numerics_many_significant_digits(self): + """target backend supports values with many digits on both sides, + such as 319438950232418390.273596, 87673.594069654243 + + """ + return supported() + + @property + def date_coerces_from_datetime(self): + """target dialect accepts a datetime object as the target + of a date column.""" + + # BigQuery doesn't allow saving a datetime in a date: + # `TYPE_DATE`, Invalid date: '2012-10-15T12:57:18' + + return unsupported() + + @property + def window_functions(self): + """Target database must support window functions.""" + return supported() # There are no tests for this. + + @property + def ctes(self): + """Target database supports CTEs""" + + return supported() + + @property + def views(self): + """Target database must support VIEWs.""" + + return supported() + + @property + def schemas(self): + """Target database must support external schemas, and have one + named 'test_schema'.""" + + return supported() + + @property + def implicit_default_schema(self): + """target system has a strong concept of 'default' schema that can + be referred to implicitly. + + basically, PostgreSQL. + + """ + return supported() + + @property + def comment_reflection(self): + return supported() # Well, probably not, but we'll try. :) + + @property + def unicode_ddl(self): + """Target driver must support some degree of non-ascii symbol + names. + """ + return supported() + + @property + def datetime_literals(self): + """target dialect supports rendering of a date, time, or datetime as a + literal string, e.g. via the TypeEngine.literal_processor() method. + + """ + + return supported() + + @property + def timestamp_microseconds(self): + """target dialect supports representation of Python + datetime.datetime() with microsecond objects but only + if TIMESTAMP is used.""" + return supported() + + @property + def datetime_historic(self): + """target dialect supports representation of Python + datetime.datetime() objects with historic (pre 1970) values.""" + + return supported() + + @property + def date_historic(self): + """target dialect supports representation of Python + datetime.datetime() objects with historic (pre 1970) values.""" + + return supported() + + @property + def precision_numerics_enotation_small(self): + """target backend supports Decimal() objects using E notation + to represent very small values.""" + return supported() + + @property + def precision_numerics_enotation_large(self): + """target backend supports Decimal() objects using E notation + to represent very large values.""" + return supported() + + @property + def update_from(self): + """Target must support UPDATE..FROM syntax""" + return supported() + + @property + def order_by_label_with_expression(self): + """target backend supports ORDER BY a column label within an + expression. + + Basically this:: + + select data as foo from test order by foo || 'bar' + + Lots of databases including PostgreSQL don't support this, + so this is off by default. 
+
+        """
+        return supported()
diff --git a/pybigquery/sqlalchemy_bigquery.py b/pybigquery/sqlalchemy_bigquery.py
index 5a6ad105..764c3fc0 100644
--- a/pybigquery/sqlalchemy_bigquery.py
+++ b/pybigquery/sqlalchemy_bigquery.py
@@ -22,7 +22,10 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
+from decimal import Decimal
+import random
 import operator
+import uuid
 
 from google import auth
 import google.api_core.exceptions
@@ -30,6 +33,9 @@
 from google.cloud.bigquery.schema import SchemaField
 from google.cloud.bigquery.table import TableReference
 from google.api_core.exceptions import NotFound
+
+import sqlalchemy.sql.sqltypes
+import sqlalchemy.sql.type_api
 from sqlalchemy.exc import NoSuchTableError
 from sqlalchemy import types, util
 from sqlalchemy.sql.compiler import (
@@ -38,10 +44,11 @@
     DDLCompiler,
     IdentifierPreparer,
 )
+from sqlalchemy.sql.sqltypes import Integer, String, NullType, Numeric
 from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.sql.schema import Column
-from sqlalchemy.sql import elements
+from sqlalchemy.sql import elements, selectable
 import re
 
 from .parse_url import parse_url
@@ -50,24 +57,12 @@
 FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+")
 
 
-class UniversalSet(object):
-    """
-    Set containing everything
-    https://github.com/dropbox/PyHive/blob/master/pyhive/common.py
-    """
-
-    def __contains__(self, item):
-        return True
-
-
 class BigQueryIdentifierPreparer(IdentifierPreparer):
     """
     Set containing everything
     https://github.com/dropbox/PyHive/blob/master/pyhive/sqlalchemy_presto.py
     """
 
-    reserved_words = UniversalSet()
-
     def __init__(self, dialect):
         super(BigQueryIdentifierPreparer, self).__init__(
             dialect, initial_quote="`",
@@ -88,21 +83,7 @@ def quote(self, ident, force=None, column=False):
         """
         force = getattr(ident, "quote", None)
-
-        if force is None:
-            if ident in self._strings:
-                return self._strings[ident]
-            else:
-                if self._requires_quotes(ident):
-                    self._strings[ident] = (
-                        self.quote_column(ident)
-                        if column
-                        else self.quote_identifier(ident)
-                    )
-                else:
-                    self._strings[ident] = ident
-                return self._strings[ident]
-        elif force:
+        if force is None or force:
             return self.quote_column(ident) if column else self.quote_identifier(ident)
         else:
             return ident
@@ -123,8 +104,11 @@ def format_label(self, label, name=None):
 _type_map = {
     "STRING": types.String,
+    "BOOL": types.Boolean,
     "BOOLEAN": types.Boolean,
+    "INT64": types.Integer,
     "INTEGER": types.Integer,
+    "FLOAT64": types.Float,
     "FLOAT": types.Float,
     "TIMESTAMP": types.TIMESTAMP,
     "DATETIME": types.DATETIME,
@@ -133,11 +117,15 @@
     "TIME": types.TIME,
     "RECORD": types.JSON,
     "NUMERIC": types.DECIMAL,
+    "BIGNUMERIC": types.DECIMAL,
 }
 
 STRING = _type_map["STRING"]
+BOOL = _type_map["BOOL"]
 BOOLEAN = _type_map["BOOLEAN"]
+INT64 = _type_map["INT64"]
 INTEGER = _type_map["INTEGER"]
+FLOAT64 = _type_map["FLOAT64"]
 FLOAT = _type_map["FLOAT"]
 TIMESTAMP = _type_map["TIMESTAMP"]
 DATETIME = _type_map["DATETIME"]
@@ -146,6 +134,7 @@
 TIME = _type_map["TIME"]
 RECORD = _type_map["RECORD"]
 NUMERIC = _type_map["NUMERIC"]
+BIGNUMERIC = _type_map["BIGNUMERIC"]
 
 
 class BigQueryExecutionContext(DefaultExecutionContext):
@@ -156,8 +145,49 @@
         c.arraysize = self.dialect.arraysize
         return c
 
+    def get_insert_default(self, column):  # pragma: NO COVER
+        # Only used by compliance tests
+        if isinstance(column.type, Integer):
+            return
random.randint(-9223372036854775808, 9223372036854775808)  # 1<<63
+        elif isinstance(column.type, String):
+            return str(uuid.uuid4())
+
+    def pre_exec(
+        self,
+        in_sub=re.compile(
+            r" IN UNNEST\(\[ "
+            r"(%\([^)]+_\d+\)s(?:, %\([^)]+_\d+\)s)*)?"  # Placeholders. See below.
+            r":([A-Z0-9]+)"  # Type
+            r" \]\)"
+        ).sub,
+    ):
+        # If we have an `in` parameter, it sometimes gets expanded to 0 or more
+        # parameters and we need to move the type marker to each
+        # parameter.
+        # (The way SQLAlchemy handles this is a bit awkward for our
+        # purposes.)
+
+        # In the placeholder part of the regex above, the `_\d+`
+        # suffixes reflect that when an array parameter is expanded,
+        # numeric suffixes are added. For example, a placeholder like
+        # `%(foo)s` gets expanded to `%(foo_0)s`, `%(foo_1)s`, ...
+
+        def repl(m):
+            placeholders, type_ = m.groups()
+            if placeholders:
+                placeholders = placeholders.replace(")", f":{type_})")
+            else:
+                placeholders = ""
+            return f" IN UNNEST([ {placeholders} ])"
+
+        self.statement = in_sub(repl, self.statement)
+
 
 class BigQueryCompiler(SQLCompiler):
+
+    compound_keywords = SQLCompiler.compound_keywords.copy()
+    compound_keywords[selectable.CompoundSelect.UNION] = "UNION ALL"
+
     def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
         if isinstance(statement, Column):
             kwargs["compile_kwargs"] = util.immutabledict({"include_table": False})
@@ -165,10 +195,23 @@ def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs)
             dialect, statement, column_keys, inline, **kwargs
         )
 
+    def visit_insert(self, insert_stmt, asfrom=False, **kw):
+        # The (internal) documentation for `inline` is confusing, but
+        # having `inline` be true prevents us from generating default
+        # primary-key values when we're doing executemany, which seems broken.
+
+        # We can probably do this in the constructor, but I want to
+        # make sure this only affects insert, because I'm paranoid. :)
+
+        self.inline = False
+
+        return super(BigQueryCompiler, self).visit_insert(
+            insert_stmt, asfrom=False, **kw
+        )
+
     def visit_column(
         self, column, add_to_result_map=None, include_table=True, **kwargs
     ):
-
         name = orig_name = column.name
         if name is None:
             name = self._fallback_column_name(column)
@@ -188,16 +231,10 @@
         if table is None or not include_table or not table.named_with_column:
             return name
         else:
-            effective_schema = self.preparer.schema_for_object(table)
-
-            if effective_schema:
-                schema_prefix = self.preparer.quote_schema(effective_schema) + "."
-            else:
-                schema_prefix = ""
             tablename = table.name
             if isinstance(tablename, elements._truncated_label):
                 tablename = self._truncated_identifier("alias", tablename)
-            return schema_prefix + self.preparer.quote(tablename) + "." + name
+            return self.preparer.quote(tablename) + "." + name
 
     def visit_label(self, *args, within_group_by=False, **kwargs):
         # Use labels in GROUP BY clause.
@@ -213,31 +250,162 @@ def group_by_clause(self, select, **kw):
             select, **kw, within_group_by=True
         )
 
+    ############################################################################
+    # Handle parameters in `IN`
+
+    # Due to details in the way SQLAlchemy arranges the compilation we
+    # expect the bind parameter as an array and unnest it.
+
+    # As it happens, BigQuery can handle arrays directly, but there's
+    # no way to tell SQLAlchemy that, so it works harder than
+    # necessary and makes us do the same.
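+
+    # For example (hypothetical placeholder names `x_1`/`x_2`, illustrative
+    # only), a statement compiled and expanded as
+    #
+    #     ... WHERE x IN UNNEST([ %(x_1)s, %(x_2)s:INT64 ])
+    #
+    # is rewritten by BigQueryExecutionContext.pre_exec above into
+    #
+    #     ... WHERE x IN UNNEST([ %(x_1:INT64)s, %(x_2:INT64)s ])
+    #
+    # so each expanded placeholder carries its own type marker.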
+
+    _in_expanding_bind = re.compile(r" IN \((\[EXPANDING_\w+\](:[A-Z0-9]+)?)\)$")
+
+    def _unnestify_in_expanding_bind(self, in_text):
+        return self._in_expanding_bind.sub(r" IN UNNEST([ \1 ])", in_text)
+
+    def visit_in_op_binary(self, binary, operator_, **kw):
+        return self._unnestify_in_expanding_bind(
+            self._generate_generic_binary(binary, " IN ", **kw)
+        )
+
+    def visit_empty_set_expr(self, element_types):
+        return ""
+
+    def visit_notin_op_binary(self, binary, operator, **kw):
+        return self._unnestify_in_expanding_bind(
+            self._generate_generic_binary(binary, " NOT IN ", **kw)
+        )
+
+    ############################################################################
+
+    ############################################################################
+    # Correct for differences in the way that SQLAlchemy escapes % and _
+    # (with /) and BigQuery does (with \\).
+
+    @staticmethod
+    def _maybe_reescape(binary):
+        binary = binary._clone()
+        escape = binary.modifiers.pop("escape", None)
+        if escape and escape != "\\":
+            binary.right.value = escape.join(
+                v.replace(escape, "\\")
+                for v in binary.right.value.split(escape + escape)
+            )
+        return binary
+
+    def visit_contains_op_binary(self, binary, operator, **kw):
+        return super(BigQueryCompiler, self).visit_contains_op_binary(
+            self._maybe_reescape(binary), operator, **kw
+        )
+
+    def visit_notcontains_op_binary(self, binary, operator, **kw):
+        return super(BigQueryCompiler, self).visit_notcontains_op_binary(
+            self._maybe_reescape(binary), operator, **kw
+        )
+
+    def visit_startswith_op_binary(self, binary, operator, **kw):
+        return super(BigQueryCompiler, self).visit_startswith_op_binary(
+            self._maybe_reescape(binary), operator, **kw
+        )
+
+    def visit_notstartswith_op_binary(self, binary, operator, **kw):
+        return super(BigQueryCompiler, self).visit_notstartswith_op_binary(
+            self._maybe_reescape(binary), operator, **kw
+        )
+
+    def visit_endswith_op_binary(self, binary, operator, **kw):
+        return super(BigQueryCompiler, self).visit_endswith_op_binary(
+            self._maybe_reescape(binary), operator, **kw
+        )
+
+    def visit_notendswith_op_binary(self, binary, operator, **kw):
+        return super(BigQueryCompiler, self).visit_notendswith_op_binary(
+            self._maybe_reescape(binary), operator, **kw
+        )
+
+    ############################################################################
+
+    def visit_bindparam(
+        self,
+        bindparam,
+        within_columns_clause=False,
+        literal_binds=False,
+        skip_bind_expression=False,
+        **kwargs,
+    ):
+        param = super(BigQueryCompiler, self).visit_bindparam(
+            bindparam,
+            within_columns_clause,
+            literal_binds,
+            skip_bind_expression,
+            **kwargs,
+        )
+
+        type_ = bindparam.type
+        if isinstance(type_, NullType):
+            return param
+
+        if (
+            isinstance(type_, Numeric)
+            and (type_.precision is None or type_.scale is None)
+            and isinstance(bindparam.value, Decimal)
+        ):
+            t = bindparam.value.as_tuple()
+
+            if type_.precision is None:
+                type_.precision = len(t.digits)
+
+            if type_.scale is None and t.exponent < 0:
+                type_.scale = -t.exponent
+
+        bq_type = self.dialect.type_compiler.process(type_)
+        if bq_type[-1] == ">" and bq_type.startswith("ARRAY<"):
+            # Values get arrayified at a lower level.
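+            # (For example, a hypothetical parameter typed ARRAY<INT64> is
+            # tagged INT64 here; the DB API re-wraps list values itself.)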
+ bq_type = bq_type[6:-1] + + assert param != "%s" + return param.replace(")", f":{bq_type})") + class BigQueryTypeCompiler(GenericTypeCompiler): - def visit_integer(self, type_, **kw): + def visit_INTEGER(self, type_, **kw): return "INT64" - def visit_float(self, type_, **kw): + visit_BIGINT = visit_SMALLINT = visit_INTEGER + + def visit_BOOLEAN(self, type_, **kw): + return "BOOL" + + def visit_FLOAT(self, type_, **kw): return "FLOAT64" - def visit_text(self, type_, **kw): - return "STRING" + visit_REAL = visit_FLOAT - def visit_string(self, type_, **kw): + def visit_STRING(self, type_, **kw): return "STRING" + visit_CHAR = visit_NCHAR = visit_STRING + visit_VARCHAR = visit_NVARCHAR = visit_TEXT = visit_STRING + def visit_ARRAY(self, type_, **kw): return "ARRAY<{}>".format(self.process(type_.item_type, **kw)) def visit_BINARY(self, type_, **kw): return "BYTES" + visit_VARBINARY = visit_BINARY + def visit_NUMERIC(self, type_, **kw): - return "NUMERIC" + if (type_.precision is not None and type_.precision > 38) or ( + type_.scale is not None and type_.scale > 9 + ): + return "BIGNUMERIC" + else: + return "NUMERIC" - def visit_DECIMAL(self, type_, **kw): - return "NUMERIC" + visit_DECIMAL = visit_NUMERIC class BigQueryDDLCompiler(DDLCompiler): @@ -250,31 +418,110 @@ def visit_foreign_key_constraint(self, constraint): def visit_primary_key_constraint(self, constraint): return None + # BigQuery has no support for unique constraints. + def visit_unique_constraint(self, constraint): + return None + def get_column_specification(self, column, **kwargs): colspec = super(BigQueryDDLCompiler, self).get_column_specification( column, **kwargs ) - if column.doc is not None: + if column.comment is not None: colspec = "{} OPTIONS(description={})".format( - colspec, self.preparer.quote(column.doc) + colspec, process_string_literal(column.comment) ) return colspec def post_create_table(self, table): bq_opts = table.dialect_options["bigquery"] opts = [] - if "description" in bq_opts: - opts.append( - "description={}".format(self.preparer.quote(bq_opts["description"])) + + if ("description" in bq_opts) or table.comment: + description = process_string_literal( + bq_opts.get("description", table.comment) ) + opts.append(f"description={description}") + if "friendly_name" in bq_opts: opts.append( - "friendly_name={}".format(self.preparer.quote(bq_opts["friendly_name"])) + "friendly_name={}".format( + process_string_literal(bq_opts["friendly_name"]) + ) ) + if opts: return "\nOPTIONS({})".format(", ".join(opts)) + return "" + def visit_set_table_comment(self, create): + table_name = self.preparer.format_table(create.element) + description = self.sql_compiler.render_literal_value( + create.element.comment, sqlalchemy.sql.sqltypes.String() + ) + return f"ALTER TABLE {table_name} SET OPTIONS(description={description})" + + def visit_drop_table_comment(self, drop): + table_name = self.preparer.format_table(drop.element) + return f"ALTER TABLE {table_name} SET OPTIONS(description=null)" + + +def process_string_literal(value): + return repr(value.replace("%", "%%")) + + +class BQString(String): + def literal_processor(self, dialect): + return process_string_literal + + +class BQBinary(sqlalchemy.sql.sqltypes._Binary): + @staticmethod + def __process_bytes_literal(value): + return repr(value.replace(b"%", b"%%")) + + def literal_processor(self, dialect): + return self.__process_bytes_literal + + +class BQClassTaggedStr(sqlalchemy.sql.type_api.TypeEngine): + """Type that can get literals via str + """ + + @staticmethod + 
def process_literal_as_class_tagged_str(value): + return f"{value.__class__.__name__.upper()} {repr(str(value))}" + + def literal_processor(self, dialect): + return self.process_literal_as_class_tagged_str + + +class BQTimestamp(sqlalchemy.sql.type_api.TypeEngine): + """Type that can get literals via str + """ + + @staticmethod + def process_timestamp_literal(value): + return f"TIMESTAMP {process_string_literal(str(value))}" + + def literal_processor(self, dialect): + return self.process_timestamp_literal + + +class BQArray(sqlalchemy.sql.sqltypes.ARRAY): + def literal_processor(self, dialect): + + item_processor = self.item_type._cached_literal_processor(dialect) + if not item_processor: + raise NotImplementedError( + f"Don't know how to literal-quote values of type {self.item_type}" + ) + + def process_array_literal(value): + return "[" + ", ".join(item_processor(v) for v in value) + "]" + + return process_array_literal + class BigQueryDialect(DefaultDialect): name = "bigquery" @@ -285,6 +532,8 @@ class BigQueryDialect(DefaultDialect): ddl_compiler = BigQueryDDLCompiler execution_ctx_cls = BigQueryExecutionContext supports_alter = False + supports_comments = True + inline_comments = True supports_pk_autoincrement = False supports_default_values = False supports_empty_insert = False @@ -297,6 +546,17 @@ class BigQueryDialect(DefaultDialect): supports_native_boolean = True supports_simple_order_by_label = True postfetch_lastrowid = False + preexecute_autoincrement_sequences = False + + colspecs = { + String: BQString, + sqlalchemy.sql.sqltypes._Binary: BQBinary, + sqlalchemy.sql.sqltypes.Date: BQClassTaggedStr, + sqlalchemy.sql.sqltypes.DateTime: BQClassTaggedStr, + sqlalchemy.sql.sqltypes.Time: BQClassTaggedStr, + sqlalchemy.sql.sqltypes.TIMESTAMP: BQTimestamp, + sqlalchemy.sql.sqltypes.ARRAY: BQArray, + } def __init__( self, @@ -305,7 +565,7 @@ def __init__( location=None, credentials_info=None, *args, - **kwargs + **kwargs, ): super(BigQueryDialect, self).__init__(*args, **kwargs) self.arraysize = arraysize @@ -427,9 +687,7 @@ def _table_reference( dataset_id_from_schema = None if provided_schema_name is not None: provided_schema_name_split = provided_schema_name.split(".") - if len(provided_schema_name_split) == 0: - pass - elif len(provided_schema_name_split) == 1: + if len(provided_schema_name_split) == 1: if dataset_id_from_table: project_id_from_schema = provided_schema_name_split[0] else: @@ -579,10 +837,7 @@ def get_schema_names(self, connection, **kw): connection = connection.connect() datasets = connection.connection._client.list_datasets() - if self.dataset_id is not None: - return [d.dataset_id for d in datasets if d.dataset_id == self.dataset_id] - else: - return [d.dataset_id for d in datasets] + return [d.dataset_id for d in datasets] def get_table_names(self, connection, schema=None, **kw): if isinstance(connection, Engine): @@ -604,6 +859,11 @@ def _check_unicode_returns(self, connection, additional_tests=None): # requests gives back Unicode strings return True - def _check_unicode_description(self, connection): - # requests gives back Unicode strings - return True + def get_view_definition(self, connection, view_name, schema=None, **kw): + if isinstance(connection, Engine): + connection = connection.connect() + client = connection.connection._client + if self.dataset_id: + view_name = f"{self.dataset_id}.{view_name}" + view = client.get_table(view_name) + return view.view_query diff --git a/setup.cfg b/setup.cfg index 95ac0e28..91fcadc7 100644 --- a/setup.cfg +++ 
b/setup.cfg
@@ -20,8 +20,11 @@ universal = 1
 
 [sqla_testing]
 requirement_cls=pybigquery.requirements:Requirements
-profile_file=.profiles.txt
+profile_file=sqlalchemy_dialect_compliance-profiles.txt
 
 [db]
-default=bigquery://
-bigquery=bigquery://
+default=bigquery:///test_pybigquery_sqla
+
+[tool:pytest]
+addopts= --tb native -v -r fxX -p no:warnings
+python_files=tests/*test_*.py
diff --git a/setup.py b/setup.py
index 97475956..d93f2225 100644
--- a/setup.py
+++ b/setup.py
@@ -65,10 +65,10 @@ def readme():
     ],
     platforms="Posix; MacOS X; Windows",
     install_requires=[
-        "sqlalchemy>=1.1.9,<1.4.0dev",
-        "google-auth>=1.14.0,<2.0dev",  # Work around pip wack.
-        "google-cloud-bigquery>=1.12.0",
-        "google-api-core>=1.19.1",  # Work-around bug in cloud core deps.
+        "sqlalchemy>=1.2.0,<1.4.0dev",
+        "google-auth>=1.24.0,<2.0dev",  # Work around pip wack.
+        "google-cloud-bigquery>=2.15.0",
+        "google-api-core>=1.23.0",  # Work-around bug in cloud core deps.
         "future",
     ],
     python_requires=">=3.6, <3.10",
diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt
index 34cbdb7a..5bc8ccf5 100644
--- a/testing/constraints-3.6.txt
+++ b/testing/constraints-3.6.txt
@@ -4,6 +4,7 @@
 # Pin the version to the lower bound.
 #
 # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
-sqlalchemy==1.1.9
-google-auth==1.14.0
-google-cloud-bigquery==1.12.0
+sqlalchemy==1.2.0
+google-auth==1.24.0
+google-cloud-bigquery==2.15.0
+google-api-core==1.23.0
diff --git a/tests/sqlalchemy_dialect_compliance/README.rst b/tests/sqlalchemy_dialect_compliance/README.rst
new file mode 100644
index 00000000..7947ec26
--- /dev/null
+++ b/tests/sqlalchemy_dialect_compliance/README.rst
@@ -0,0 +1,12 @@
+SQLAlchemy Dialect Compliance Tests
+===================================
+
+SQLAlchemy provides reusable tests for verifying that SQLAlchemy
+dialects work properly. This directory applies these tests to the
+BigQuery SQLAlchemy dialect.
+
+These are "system" tests, meaning that they run against a real
+BigQuery account. To run the tests, you need a BigQuery account with
+empty `test_pybigquery_sqla` and `test_schema` schemas. You need to
+have the `GOOGLE_APPLICATION_CREDENTIALS` environment variable set to
+the path of a Google Cloud authentication file.
diff --git a/tests/sqlalchemy_dialect_compliance/conftest.py b/tests/sqlalchemy_dialect_compliance/conftest.py
new file mode 100644
index 00000000..eefd3f07
--- /dev/null
+++ b/tests/sqlalchemy_dialect_compliance/conftest.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2021 The PyBigQuery Authors
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +from sqlalchemy.testing.plugin.pytestplugin import * # noqa +from sqlalchemy.testing.plugin.pytestplugin import ( + pytest_sessionstart as _pytest_sessionstart, +) + +import google.cloud.bigquery.dbapi.connection +import pybigquery.sqlalchemy_bigquery +import sqlalchemy +import traceback + +pybigquery.sqlalchemy_bigquery.BigQueryDialect.preexecute_autoincrement_sequences = True +google.cloud.bigquery.dbapi.connection.Connection.rollback = lambda self: None + + +# BigQuery requires delete statements to have where clauses. Other +# databases don't and sqlalchemy doesn't include where clauses when +# cleaning up test data. So we add one when we see a delete without a +# where clause when tearing down tests. We only do this during tear +# down, by inspecting the stack, because we don't want to hide bugs +# outside of test house-keeping. +def visit_delete(self, delete_stmt, *args, **kw): + if delete_stmt._whereclause is None and "teardown" in set( + f.name for f in traceback.extract_stack() + ): + delete_stmt._whereclause = sqlalchemy.true() + + return super(pybigquery.sqlalchemy_bigquery.BigQueryCompiler, self).visit_delete( + delete_stmt, *args, **kw + ) + + +pybigquery.sqlalchemy_bigquery.BigQueryCompiler.visit_delete = visit_delete + + +# Clean up test schemas so we don't get spurious errors when the tests +# try to create tables that already exist. +def pytest_sessionstart(session): + client = google.cloud.bigquery.Client() + for schema in "test_schema", "test_pybigquery_sqla": + for table_item in client.list_tables(f"{client.project}.{schema}"): + table_id = table_item.table_id + list( + client.query( + f"drop {'view' if table_id.endswith('_v') else 'table'}" + f" {schema}.{table_id}" + ).result() + ) + client.close() + _pytest_sessionstart(session) diff --git a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py new file mode 100644 index 00000000..259a78ec --- /dev/null +++ b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+
+import datetime
+import mock
+import pytest
+import pytz
+from sqlalchemy import and_
+from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.suite import config, select, exists
+from sqlalchemy.testing.suite import *  # noqa
+from sqlalchemy.testing.suite import (
+    ComponentReflectionTest as _ComponentReflectionTest,
+    CTETest as _CTETest,
+    ExistsTest as _ExistsTest,
+    InsertBehaviorTest as _InsertBehaviorTest,
+    LimitOffsetTest as _LimitOffsetTest,
+    LongNameBlowoutTest,
+    QuotedNameArgumentTest,
+    SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest,
+    TimestampMicrosecondsTest as _TimestampMicrosecondsTest,
+)
+
+# Quotes aren't allowed in BigQuery table names.
+del QuotedNameArgumentTest
+
+
+class InsertBehaviorTest(_InsertBehaviorTest):
+    @pytest.mark.skip()
+    def test_insert_from_select_autoinc(cls):
+        """BQ has no autoinc and client-side defaults can't work for select."""
+
+
+class ExistsTest(_ExistsTest):
+    """
+    Override
+
+    Because BigQuery requires a FROM clause when there's a WHERE and
+    the base tests didn't provide a FROM.
+    """
+
+    def test_select_exists(self, connection):
+        stuff = self.tables.stuff
+        eq_(
+            connection.execute(
+                select([stuff.c.id]).where(
+                    and_(stuff.c.id == 1, exists().where(stuff.c.data == "some data"),)
+                )
+            ).fetchall(),
+            [(1,)],
+        )
+
+    def test_select_exists_false(self, connection):
+        stuff = self.tables.stuff
+        eq_(
+            connection.execute(
+                select([stuff.c.id]).where(exists().where(stuff.c.data == "no data"))
+            ).fetchall(),
+            [],
+        )
+
+
+class LimitOffsetTest(_LimitOffsetTest):
+    @pytest.mark.skip()
+    def test_simple_offset(self):
+        """BigQuery doesn't allow an offset without a limit."""
+
+    test_bound_offset = test_simple_offset
+
+
+# This test requires features (indexes, primary keys, etc.) that BigQuery doesn't have.
+del LongNameBlowoutTest
+
+
+class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest):
+    """The base tests fail if operations return rows for some reason."""
+
+    def test_update(self):
+        t = self.tables.plain_pk
+        r = config.db.execute(t.update().where(t.c.id == 2), data="d2_new")
+        assert not r.is_insert
+        # assert not r.returns_rows
+
+        eq_(
+            config.db.execute(t.select().order_by(t.c.id)).fetchall(),
+            [(1, "d1"), (2, "d2_new"), (3, "d3")],
+        )
+
+    def test_delete(self):
+        t = self.tables.plain_pk
+        r = config.db.execute(t.delete().where(t.c.id == 2))
+        assert not r.is_insert
+        # assert not r.returns_rows
+        eq_(
+            config.db.execute(t.select().order_by(t.c.id)).fetchall(),
+            [(1, "d1"), (3, "d3")],
+        )
+
+
+class CTETest(_CTETest):
+    @pytest.mark.skip("Can't use CTEs with insert")
+    def test_insert_from_select_round_trip(self):
+        pass
+
+    @pytest.mark.skip("Recursive CTEs aren't supported.")
+    def test_select_recursive_round_trip(self):
+        pass
+
+
+class ComponentReflectionTest(_ComponentReflectionTest):
+    @pytest.mark.skip("BigQuery types don't track precision, length, etc.")
+    def coarse_grained_types():
+        pass
+
+    test_numeric_reflection = test_varchar_reflection = coarse_grained_types
+
+
+class TimestampMicrosecondsTest(_TimestampMicrosecondsTest):
+
+    data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396, tzinfo=pytz.UTC)
+
+    def test_literal(self):
+        # The base test doesn't set up the literal properly, because
+        # it doesn't pass its datatype to `literal`.
+ + def literal(value): + assert value == self.data + import sqlalchemy.sql.sqltypes + + return sqlalchemy.sql.elements.literal(value, self.datatype) + + with mock.patch("sqlalchemy.testing.suite.test_types.literal", literal): + super(TimestampMicrosecondsTest, self).test_literal() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 22def748..801e84a9 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,16 +1,79 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import contextlib import mock +import sqlite3 + import pytest import sqlalchemy import fauxdbi +sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split("."))) +sqlalchemy_1_3_or_higher = pytest.mark.skipif( + sqlalchemy_version_info < (1, 3), reason="requires sqlalchemy 1.3 or higher" +) + @pytest.fixture() def faux_conn(): - with mock.patch( - "google.cloud.bigquery.dbapi.connection.Connection", fauxdbi.Connection - ): - engine = sqlalchemy.create_engine("bigquery://myproject/mydataset") - conn = engine.connect() - yield conn - conn.close() + test_data = dict(execute=[]) + connection = sqlite3.connect(":memory:") + + def factory(*args, **kw): + conn = fauxdbi.Connection(connection, test_data, *args, **kw) + return conn + + with mock.patch("google.cloud.bigquery.dbapi.connection.Connection", factory): + # We want to bypass client creation. We don't need it and it requires creds. 
+ with mock.patch( + "pybigquery._helpers.create_bigquery_client", fauxdbi.FauxClient + ): + with mock.patch("google.auth.default", return_value=("authdb", "authproj")): + engine = sqlalchemy.create_engine("bigquery://myproject/mydataset") + conn = engine.connect() + conn.test_data = test_data + + def ex(sql, *args, **kw): + with contextlib.closing( + conn.connection.connection.connection.cursor() + ) as cursor: + cursor.execute(sql, *args, **kw) + + conn.ex = ex + + ex("create table comments" " (key string primary key, comment string)") + + yield conn + conn.close() + + +@pytest.fixture() +def metadata(): + return sqlalchemy.MetaData() + + +def setup_table(connection, name, *columns, initial_data=(), **kw): + metadata = sqlalchemy.MetaData() + table = sqlalchemy.Table(name, metadata, *columns, **kw) + metadata.create_all(connection.engine) + if initial_data: + connection.execute(table.insert(), initial_data) + return table diff --git a/tests/unit/fauxdbi.py b/tests/unit/fauxdbi.py index 44c4edae..70cbb8aa 100644 --- a/tests/unit/fauxdbi.py +++ b/tests/unit/fauxdbi.py @@ -1,49 +1,280 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
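+
+# A hedged usage sketch (not part of this change; the wiring below is an
+# assumption based on the faux_conn fixture in tests/unit/conftest.py):
+#
+#     sqlite_conn = sqlite3.connect(":memory:")
+#     conn = Connection(sqlite_conn, dict(execute=[]), FauxClient())
+#     conn.cursor().execute("select 1")
+#
+# so the dialect can be exercised without a real BigQuery backend.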
+
+import base64
+import contextlib
+import datetime
+import decimal
+import pickle
+import re
+import sqlite3
+
 import google.api_core.exceptions
 import google.cloud.bigquery.schema
 import google.cloud.bigquery.table
-import contextlib
-import sqlite3
+import google.cloud.bigquery.dbapi.cursor
 
 
 class Connection:
-
-    connection = None
-
-    def __init__(self, client=None, bqstorage_client=None):
-        # share a single connection:
-        if self.connection is None:
-            self.__class__.connection = sqlite3.connect(":memory:")
-        self._client = FauxClient(client, self.connection)
+    def __init__(self, connection, test_data, client, *args, **kw):
+        self.connection = connection
+        self.test_data = test_data
+        self._client = client
+        client.connection = self
 
     def cursor(self):
-        return Cursor(self.connection)
+        return Cursor(self)
 
     def commit(self):
         pass
 
-    def rollback(self):
-        pass
-
     def close(self):
-        self.connection.close()
+
+class Cursor:
+    def __init__(self, connection):
+        self.connection = connection
+        self.cursor = connection.connection.cursor()
+        assert self.arraysize == 1
 
+    __arraysize = 1
 
-class Cursor:
+    @property
+    def arraysize(self):
+        return self.__arraysize
 
-    arraysize = 1
+    @arraysize.setter
+    def arraysize(self, v):
+        self.__arraysize = v
+        self.connection.test_data["arraysize"] = v
 
-    def __init__(self, connection):
-        self.connection = connection
-        self.cursor = connection.cursor()
+    # A Note on the use of pickle here
+    # ================================
+    #
+    # BigQuery supports types that sqlite doesn't. We compensate by
+    # pickling unhandled types and saving the pickles as
+    # strings. Bonus: literals require extra handling.
+    #
+    # Note that this only needs to be robust enough for tests. :) So
+    # when reading data, we simply look for pickle protocol 4
+    # prefixes, because we don't have to worry about people providing
+    # non-pickle string values with those prefixes, because we control
+    # the inputs in the tests and we choose not to do that.
 
-    def execute(self, operation, parameters=None):
-        if parameters:
-            parameters = {
-                name: "null" if value is None else repr(value)
-                for name, value in parameters.items()
-            }
-            operation %= parameters
-        self.cursor.execute(operation, parameters)
+    _need_to_be_pickled = (
+        list,
+        dict,
+        decimal.Decimal,
+        bool,
+        datetime.datetime,
+        datetime.date,
+        datetime.time,
+    )
+
+    def __convert_params(
+        self,
+        operation,
+        parameters,
+        placeholder=re.compile(r"%\((\w+)\)s", re.IGNORECASE),
+    ):
+        ordered_parameters = []
+
+        def repl(m):
+            name = m.group(1)
+            value = parameters[name]
+            if isinstance(value, self._need_to_be_pickled):
+                value = pickle.dumps(value, 4).decode("latin1")
+            ordered_parameters.append(value)
+            return "?"
+
+        operation = placeholder.sub(repl, operation)
+        return operation, ordered_parameters
+
+    def __update_comment(self, table, col, comment):
+        key = table + "," + col
+        self.cursor.execute("delete from comments where key=?", [key])
+        self.cursor.execute(f"insert into comments values(?, {comment})", [key])
+
+    __create_table = re.compile(
+        r"\s*create\s+table\s+`(?P<table>\w+)`", re.IGNORECASE
+    ).match
+
+    def __handle_comments(
+        self,
+        operation,
+        alter_table=re.compile(
+            r"\s*ALTER\s+TABLE\s+`(?P<table>\w+)`\s+"
+            r"SET\s+OPTIONS\(description=(?P<comment>[^)]+)\)",
+            re.IGNORECASE,
+        ).match,
+        options=re.compile(
+            r"(?P<prefix>`(?P<col>\w+)`\s+\w+|\))" r"\s+options\((?P<options>[^)]+)\)",
+            re.IGNORECASE,
+        ),
+    ):
+        m = self.__create_table(operation)
+        if m:
+            table_name = m.group("table")
+
+            def repl(m):
+                col = m.group("col") or ""
+                options = {
+                    name.strip().lower(): value.strip()
+                    for name, value in (
+                        o.split("=") for o in m.group("options").split(",")
+                    )
+                }
+
+                comment = options.get("description")
+                if comment:
+                    self.__update_comment(table_name, col, comment)
+
+                return m.group("prefix")
+
+            return options.sub(repl, operation)
+
+        m = alter_table(operation)
+        if m:
+            table_name = m.group("table")
+            comment = m.group("comment")
+            self.__update_comment(table_name, "", comment)
+            return ""
+
+        return operation
+
+    def __handle_array_types(
+        self,
+        operation,
+        array_type=re.compile(
+            r"(?<=[(,])" r"\s*`\w+`\s+\w+<\w+>\s*" r"(?=[,)])", re.IGNORECASE
+        ),
+    ):
+        if self.__create_table(operation):
+
+            def repl(m):
+                return m.group(0).replace("<", "_").replace(">", "_")
+
+            return array_type.sub(repl, operation)
+        else:
+            return operation
+
+    @staticmethod
+    def __parse_dateish(type_, value):
+        type_ = type_.lower()
+        if type_ == "timestamp":
+            type_ = "datetime"
+
+        if type_ == "datetime":
+            return datetime.datetime.strptime(
+                value, "%Y-%m-%d %H:%M:%S.%f" if "." in value else "%Y-%m-%d %H:%M:%S",
+            )
+        elif type_ == "date":
+            return datetime.date(*map(int, value.split("-")))
+        elif type_ == "time":
+            if "." in value:
+                value, micro = value.split(".")
+                micro = [micro]
+            else:
+                micro = []
+
+            return datetime.time(*map(int, value.split(":") + micro))
+        else:
+            raise AssertionError(type_)  # pragma: NO COVER
+
+    def __handle_problematic_literal_inserts(
+        self,
+        operation,
+        literal_insert_values=re.compile(
+            r"\s*(insert\s+into\s+.+\s+values\s*)" r"(\([^)]+\))" r"\s*$", re.IGNORECASE
+        ).match,
+        bq_dateish=re.compile(
+            r"(?<=[[(,])\s*"
+            r"(?P<type>date(?:time)?|time(?:stamp)?) (?P<val>'[^']+')"
+            r"\s*(?=[]),])",
+            re.IGNORECASE,
+        ),
+        need_to_be_pickled_literal=_need_to_be_pickled + (bytes,),
+    ):
+        if "?" in operation:
+            return operation
+        m = literal_insert_values(operation)
+        if m:
+            prefix, values = m.groups()
+            safe_globals = {
+                "__builtins__": {
+                    "parse_datish": self.__parse_dateish,
+                    "true": True,
+                    "false": False,
+                }
+            }
+
+            values = bq_dateish.sub(r"parse_datish('\1', \2)", values)
+            values = eval(values[:-1] + ",)", safe_globals)
+            values = ",".join(
+                map(
+                    repr,
+                    (
+                        (
+                            base64.b16encode(pickle.dumps(v, 4)).decode()
+                            if isinstance(v, need_to_be_pickled_literal)
+                            else v
+                        )
+                        for v in values
+                    ),
+                )
+            )
+            return f"{prefix}({values})"
+        else:
+            return operation
+
+    def __handle_unnest(
+        self, operation, unnest=re.compile(r"UNNEST\(\[ ([^\]]+)? \]\)", re.IGNORECASE),
+    ):
+        return unnest.sub(r"(\1)", operation)
+
+    def __handle_true_false(self, operation):
+        # Older sqlite versions, like those used on the CI servers,
+        # don't support true and false (as aliases for 1 and 0).
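+        # For example (illustrative only):
+        #     " where b = true"  ->  " where b = 1"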
+ return operation.replace(" true", " 1").replace(" false", " 0") + + def execute(self, operation, parameters=()): + self.connection.test_data["execute"].append((operation, parameters)) + operation, types_ = google.cloud.bigquery.dbapi.cursor._extract_types(operation) + if parameters: + operation, parameters = self.__convert_params(operation, parameters) + else: + operation = operation.replace("%%", "%") + + operation = self.__handle_comments(operation) + operation = self.__handle_array_types(operation) + operation = self.__handle_problematic_literal_inserts(operation) + operation = self.__handle_unnest(operation) + operation = self.__handle_true_false(operation) + + if operation: + try: + self.cursor.execute(operation, parameters) + except sqlite3.OperationalError as e: # pragma: NO COVER + # Help diagnose errors that shouldn't happen. + # When they do, it's likely due to sqlite versions (environment). + raise sqlite3.OperationalError( + *((operation,) + e.args + (sqlite3.sqlite_version,)) + ) + self.description = self.cursor.description self.rowcount = self.cursor.rowcount @@ -54,45 +285,150 @@ def executemany(self, operation, parameters_list): def close(self): self.cursor.close() - def fetchone(self): - return self.cursor.fetchone() + def _fix_pickled(self, row): + if row is None: + return row - def fetchmany(self, size=None): - self.cursor.fetchmany(size or self.arraysize) + return [ + ( + pickle.loads(v.encode("latin1")) + # \x80\x04 is latin-1 encoded prefix for Pickle protocol 4. + if isinstance(v, str) and v[:2] == "\x80\x04" and v[-1] == "." + else pickle.loads(base64.b16decode(v)) + # 8004 is base64 encoded prefix for Pickle protocol 4. + if isinstance(v, str) and v[:4] == "8004" and v[-2:] == "2E" + else v + ) + for d, v in zip(self.description, row) + ] + + def fetchone(self): + return self._fix_pickled(self.cursor.fetchone()) def fetchall(self): - return self.cursor.fetchall() + return map(self._fix_pickled, self.cursor) - def setinputsizes(self, sizes): - pass - def setoutputsize(self, size, column=None): - pass +class attrdict(dict): + def __setattr__(self, name, val): + self[name] = val + + def __getattr__(self, name): + if name not in self: + self[name] = attrdict() + return self[name] class FauxClient: - def __init__(self, client, connection): - self._client = client - self.project = client.project - self.connection = connection + def __init__(self, project_id=None, default_query_job_config=None, *args, **kw): + + if project_id is None: + if default_query_job_config is not None: + project_id = default_query_job_config.default_dataset.project + else: + project_id = "authproj" # we would still have gotten it from auth. + + self.project = project_id + self.tables = attrdict() + + @staticmethod + def _row_dict(row, cursor): + result = {d[0]: value for d, value in zip(cursor.description, row)} + return result + + def _get_field( + self, + type, + name=None, + notnull=None, + mode=None, + description=None, + fields=(), + columns=None, # Custom column data provided by tests. + **_, # Ignore sqlite PRAGMA data we don't care about. 
+ ): + if columns: + custom = columns.get(name) + if custom: + return self._get_field(name=name, type=type, notnull=notnull, **custom) + + if not mode: + mode = "REQUIRED" if notnull else "NULLABLE" + + field = google.cloud.bigquery.schema.SchemaField( + name=name, + field_type=type, + mode=mode, + description=description, + fields=tuple(self._get_field(**f) for f in fields), + ) + + return field + + def __get_comments(self, cursor, table_name): + cursor.execute( + f"select key, comment" + f" from comments where key like {repr(table_name + '%')}" + ) + + return {key.split(",")[1]: comment for key, comment in cursor} def get_table(self, table_ref): + table_ref = google.cloud.bigquery.table._table_arg_to_table_ref( + table_ref, self.project + ) table_name = table_ref.table_id - with contextlib.closing(self.connection.cursor()) as cursor: - cursor.execute( - f"select name from sqlite_master" - f" where type='table' and name='{table_name}'" - ) - if list(cursor): - cursor.execute("PRAGMA table_info('{table_name}')") + with contextlib.closing(self.connection.connection.cursor()) as cursor: + cursor.execute(f"select * from sqlite_master where name='{table_name}'") + rows = list(cursor) + if rows: + table_data = self._row_dict(rows[0], cursor) + + comments = self.__get_comments(cursor, table_name) + table_comment = comments.pop("", None) + columns = getattr(self.tables, table_name).columns + for col, comment in comments.items(): + getattr(columns, col).description = comment + + cursor.execute(f"PRAGMA table_info('{table_name}')") schema = [ - google.cloud.bigquery.schema.SchemaField( - name=name, - field_type=type_, - mode="REQUIRED" if notnull else "NULLABLE", - ) - for cid, name, type_, notnull, dflt_value, pk in cursor + self._get_field(columns=columns, **self._row_dict(row, cursor)) + for row in cursor ] - return google.cloud.bigquery.table.Table(table_ref, schema) + table = google.cloud.bigquery.table.Table(table_ref, schema) + table.description = table_comment + if table_data["type"] == "view" and table_data["sql"]: + table.view_query = table_data["sql"][ + table_data["sql"].lower().index("select") : + ] + + for aname, value in self.tables.get(table_name, {}).items(): + setattr(table, aname, value) + + return table else: raise google.api_core.exceptions.NotFound(table_ref) + + def list_datasets(self): + return [ + google.cloud.bigquery.Dataset("myproject.mydataset"), + google.cloud.bigquery.Dataset("myproject.yourdataset"), + ] + + def list_tables(self, dataset): + with contextlib.closing(self.connection.connection.cursor()) as cursor: + cursor.execute("select * from sqlite_master") + return [ + google.cloud.bigquery.table.TableListItem( + dict( + tableReference=dict( + projectId=dataset.project, + datasetId=dataset.dataset_id, + tableId=row["name"], + ), + type=row["type"].upper(), + ) + ) + for row in (self._row_dict(row, cursor) for row in cursor) + if row["name"] != "comments" + ] diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py new file mode 100644 index 00000000..61190e7f --- /dev/null +++ b/tests/unit/test_api.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished 
to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import mock
+
+
+def test_dry_run():
+
+    with mock.patch("pybigquery._helpers.create_bigquery_client") as create_client:
+        import pybigquery.api
+
+        client = pybigquery.api.ApiClient("/my/creds", "mars")
+        create_client.assert_called_once_with(
+            credentials_path="/my/creds", location="mars"
+        )
+        client.dry_run_query("select 42")
+        [(name, args, kwargs)] = create_client.return_value.query.mock_calls
+        job_config = kwargs.pop("job_config")
+        assert (name, args, kwargs) == ("", (), {"query": "select 42"})
+        assert job_config.dry_run
+        assert not job_config.use_query_cache
diff --git a/tests/unit/test_catalog_functions.py b/tests/unit/test_catalog_functions.py
new file mode 100644
index 00000000..0bbfad75
--- /dev/null
+++ b/tests/unit/test_catalog_functions.py
@@ -0,0 +1,259 @@
+# Copyright (c) 2021 The PyBigQuery Authors
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import pytest
+import sqlalchemy.types
+
+
+@pytest.mark.parametrize(
+    "table,schema,expect",
+    [
+        ("p.s.t", None, "p.s.t"),
+        ("p.s.t", "p.s", "p.s.t"),
+        # Why is a single schema name a project name when a table
+        # dataset id is given? I guess to provide a missing default.
+ ("p.s.t", "p", "p.s.t"), + ("s.t", "p", "p.s.t"), + ("s.t", "p.s", "p.s.t"), + ("s.t", None, "myproject.s.t"), + ("t", None, "myproject.mydataset.t"), + ("t", "s", "myproject.s.t"), + ("t", "q.s", "q.s.t"), + ], +) +def test__table_reference(faux_conn, table, schema, expect): + assert ( + str( + faux_conn.dialect._table_reference( + schema, table, faux_conn.connection._client.project + ) + ) + == expect + ) + + +@pytest.mark.parametrize( + "table,table_project,schema,schema_project", + [("p.s.t", "p", "q.s", "q"), ("p.s.t", "p", "q", "q")], +) +def test__table_reference_inconsistent_project( + faux_conn, table, table_project, schema, schema_project +): + with pytest.raises( + ValueError, + match=( + f"project_id specified in schema and table_name disagree: " + f"got {schema_project} in schema, and {table_project} in table_name" + ), + ): + faux_conn.dialect._table_reference( + schema, table, faux_conn.connection._client.project + ) + + +@pytest.mark.parametrize( + "table,table_dataset,schema,schema_dataset", + [("s.t", "s", "p.q", "q"), ("p.s.t", "s", "p.q", "q")], +) +def test__table_reference_inconsistent_dataset_id( + faux_conn, table, table_dataset, schema, schema_dataset +): + with pytest.raises( + ValueError, + match=( + f"dataset_id specified in schema and table_name disagree: " + f"got {schema_dataset} in schema, and {table_dataset} in table_name" + ), + ): + faux_conn.dialect._table_reference( + schema, table, faux_conn.connection._client.project + ) + + +@pytest.mark.parametrize("type_", ["view", "table"]) +def test_get_table_names(faux_conn, type_): + cursor = faux_conn.connection.cursor() + cursor.execute("create view view1 as select 1") + cursor.execute("create view view2 as select 2") + cursor.execute("create table table1 (x INT64)") + cursor.execute("create table table2 (x INT64)") + assert sorted(getattr(faux_conn.dialect, f"get_{type_}_names")(faux_conn)) == [ + f"{type_}{d}" for d in "12" + ] + + # once more with engine: + assert sorted( + getattr(faux_conn.dialect, f"get_{type_}_names")(faux_conn.engine) + ) == [f"{type_}{d}" for d in "12"] + + +def test_get_schema_names(faux_conn): + assert list(faux_conn.dialect.get_schema_names(faux_conn)) == [ + "mydataset", + "yourdataset", + ] + # once more with engine: + assert list(faux_conn.dialect.get_schema_names(faux_conn.engine)) == [ + "mydataset", + "yourdataset", + ] + + +def test_get_indexes(faux_conn): + from google.cloud.bigquery.table import TimePartitioning + + cursor = faux_conn.connection.cursor() + cursor.execute("create table foo (x INT64)") + assert faux_conn.dialect.get_indexes(faux_conn, "foo") == [] + + client = faux_conn.connection._client + client.tables.foo.time_partitioning = TimePartitioning(field="tm") + client.tables.foo.clustering_fields = ["user_email", "store_code"] + + assert faux_conn.dialect.get_indexes(faux_conn, "foo") == [ + dict(name="partition", column_names=["tm"], unique=False,), + dict( + name="clustering", column_names=["user_email", "store_code"], unique=False, + ), + ] + + +def test_no_table_pk_constraint(faux_conn): + # BigQuery doesn't do that. + assert faux_conn.dialect.get_pk_constraint(faux_conn, "foo") == ( + dict(constrained_columns=[]) + ) + + +def test_no_table_foreign_keys(faux_conn): + # BigQuery doesn't do that. 
+ assert faux_conn.dialect.get_foreign_keys(faux_conn, "foo") == [] + + +def test_get_table_comment(faux_conn): + cursor = faux_conn.connection.cursor() + cursor.execute("create table foo (x INT64)") + assert faux_conn.dialect.get_table_comment(faux_conn, "foo") == (dict(text=None)) + + client = faux_conn.connection._client + client.tables.foo.description = "special table" + assert faux_conn.dialect.get_table_comment(faux_conn, "foo") == ( + dict(text="special table") + ) + + +@pytest.mark.parametrize( + "btype,atype", + [ + ("STRING", sqlalchemy.types.String), + ("BYTES", sqlalchemy.types.BINARY), + ("INT64", sqlalchemy.types.Integer), + ("FLOAT64", sqlalchemy.types.Float), + ("NUMERIC", sqlalchemy.types.DECIMAL), + ("BIGNUMERIC", sqlalchemy.types.DECIMAL), + ("BOOL", sqlalchemy.types.Boolean), + ("TIMESTAMP", sqlalchemy.types.TIMESTAMP), + ("DATE", sqlalchemy.types.DATE), + ("TIME", sqlalchemy.types.TIME), + ("DATETIME", sqlalchemy.types.DATETIME), + ("THURSDAY", sqlalchemy.types.NullType), + ], +) +def test_get_table_columns(faux_conn, btype, atype): + cursor = faux_conn.connection.cursor() + cursor.execute(f"create table foo (x {btype})") + + assert faux_conn.dialect.get_columns(faux_conn, "foo") == [ + { + "comment": None, + "default": None, + "name": "x", + "nullable": True, + "type": atype, + } + ] + + +def test_get_table_columns_special_cases(faux_conn): + cursor = faux_conn.connection.cursor() + cursor.execute("create table foo (s STRING, n INT64 not null, r RECORD)") + client = faux_conn.connection._client + client.tables.foo.columns.s.description = "a fine column" + client.tables.foo.columns.s.mode = "REPEATED" + client.tables.foo.columns.r.fields = ( + dict(name="i", type="INT64"), + dict(name="f", type="FLOAT64"), + ) + + actual = faux_conn.dialect.get_columns(faux_conn, "foo") + stype = actual[0].pop("type") + assert isinstance(stype, sqlalchemy.types.ARRAY) + assert isinstance(stype.item_type, sqlalchemy.types.String) + assert actual == [ + {"comment": "a fine column", "default": None, "name": "s", "nullable": True}, + { + "comment": None, + "default": None, + "name": "n", + "nullable": False, + "type": sqlalchemy.types.Integer, + }, + { + "comment": None, + "default": None, + "name": "r", + "nullable": True, + "type": sqlalchemy.types.JSON, + }, + { + "comment": None, + "default": None, + "name": "r.i", + "nullable": True, + "type": sqlalchemy.types.Integer, + }, + { + "comment": None, + "default": None, + "name": "r.f", + "nullable": True, + "type": sqlalchemy.types.Float, + }, + ] + + +def test_has_table(faux_conn): + cursor = faux_conn.connection.cursor() + assert not faux_conn.dialect.has_table(faux_conn, "foo") + cursor.execute("create table foo (s STRING)") + assert faux_conn.dialect.has_table(faux_conn, "foo") + # once more with engine: + assert faux_conn.dialect.has_table(faux_conn.engine, "foo") + + +def test_bad_schema_argument(faux_conn): + # with goofy schema name, to exercise some error handling + with pytest.raises(ValueError, match=r"Did not understand schema: a\.b\.c"): + faux_conn.dialect.has_table(faux_conn.engine, "foo", "a.b.c") + + +def test_bad_table_argument(faux_conn): + # with goofy table name, to exercise some error handling + with pytest.raises(ValueError, match=r"Did not understand table_name: a\.b\.c\.d"): + faux_conn.dialect.has_table(faux_conn.engine, "a.b.c.d") diff --git a/tests/unit/test_comments.py b/tests/unit/test_comments.py new file mode 100644 index 00000000..280ce989 --- /dev/null +++ b/tests/unit/test_comments.py @@ -0,0 +1,100 @@ 
+# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import sqlalchemy + +from conftest import setup_table + + +def test_inline_comments(faux_conn): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer, comment="identifier"), + comment="a fine table", + ) + + dialect = faux_conn.dialect + assert dialect.get_table_comment(faux_conn, "some_table") == { + "text": "a fine table" + } + assert dialect.get_columns(faux_conn, "some_table")[0]["comment"] == "identifier" + + +def test_set_drop_table_comment(faux_conn): + table = setup_table( + faux_conn, "some_table", sqlalchemy.Column("id", sqlalchemy.Integer), + ) + + dialect = faux_conn.dialect + assert dialect.get_table_comment(faux_conn, "some_table") == {"text": None} + + table.comment = "a fine table" + faux_conn.execute(sqlalchemy.schema.SetTableComment(table)) + assert dialect.get_table_comment(faux_conn, "some_table") == { + "text": "a fine table" + } + + faux_conn.execute(sqlalchemy.schema.DropTableComment(table)) + assert dialect.get_table_comment(faux_conn, "some_table") == {"text": None} + + +def test_table_description_dialect_option(faux_conn): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + bigquery_description="a fine table", + ) + dialect = faux_conn.dialect + assert dialect.get_table_comment(faux_conn, "some_table") == { + "text": "a fine table" + } + + +def test_table_friendly_name_dialect_option(faux_conn): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + bigquery_friendly_name="bob", + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64 )" " OPTIONS(friendly_name='bob')" + ) + + +def test_table_friendly_name_description_dialect_option(faux_conn): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + bigquery_friendly_name="bob", + bigquery_description="a fine table", + ) + + dialect = faux_conn.dialect + assert dialect.get_table_comment(faux_conn, "some_table") == { + "text": "a fine table" + } + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64 )" + " OPTIONS(description='a fine table', friendly_name='bob')" + ) diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py new file mode 100644 index 00000000..f4114022 --- /dev/null +++ b/tests/unit/test_compiler.py @@ -0,0 
+1,55 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import pytest +import sqlalchemy.exc + +from conftest import setup_table + + +def test_constraints_are_ignored(faux_conn, metadata): + sqlalchemy.Table( + "ref", metadata, sqlalchemy.Column("id", sqlalchemy.Integer), + ) + sqlalchemy.Table( + "some_table", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column( + "ref_id", sqlalchemy.Integer, sqlalchemy.ForeignKey("ref.id") + ), + sqlalchemy.UniqueConstraint("id", "ref_id", name="uix_1"), + ) + metadata.create_all(faux_conn.engine) + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table`" " ( `id` INT64 NOT NULL, `ref_id` INT64 )" + ) + + +def test_compile_column(faux_conn): + table = setup_table(faux_conn, "t", sqlalchemy.Column("c", sqlalchemy.Integer)) + assert table.c.c.compile(faux_conn).string == "`c`" + + +def test_cant_compile_unnamed_column(faux_conn, metadata): + with pytest.raises( + sqlalchemy.exc.CompileError, + match="Cannot compile Column object until its 'name' is assigned.", + ): + sqlalchemy.Column(sqlalchemy.Integer).compile(faux_conn) diff --git a/tests/unit/test_compliance.py b/tests/unit/test_compliance.py new file mode 100644 index 00000000..da2390f6 --- /dev/null +++ b/tests/unit/test_compliance.py @@ -0,0 +1,189 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
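+
+# These tests lean on the `setup_table` helper imported from conftest
+# below.  A minimal sketch of what it is assumed to do (the real helper
+# lives in tests/unit/conftest.py and may differ in detail):
+#
+#     def setup_table(connection, name, *columns, initial_data=None, **kw):
+#         metadata = sqlalchemy.MetaData()
+#         table = sqlalchemy.Table(name, metadata, *columns, **kw)
+#         metadata.create_all(connection.engine)
+#         if initial_data:
+#             connection.execute(table.insert(), initial_data)
+#         return table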
+""" +Ported compliance tests. + +Mainly to get better unit test coverage. +""" + +import pytest +import sqlalchemy +from sqlalchemy import Column, Integer, literal_column, select, String, Table, union +from sqlalchemy.testing.assertions import eq_, in_ + +from conftest import setup_table, sqlalchemy_1_3_or_higher + + +def assert_result(connection, sel, expected): + eq_(connection.execute(sel).fetchall(), expected) + + +def some_table(connection): + return setup_table( + connection, + "some_table", + Column("id", Integer), + Column("x", Integer), + Column("y", Integer), + initial_data=[ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ], + ) + + +def test_distinct_selectable_in_unions(faux_conn): + table = some_table(faux_conn) + s1 = select([table]).where(table.c.id == 2).distinct() + s2 = select([table]).where(table.c.id == 3).distinct() + + u1 = union(s1, s2).limit(2) + assert_result(faux_conn, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) + + +def test_limit_offset_aliased_selectable_in_unions(faux_conn): + table = some_table(faux_conn) + s1 = ( + select([table]) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select([table]) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + + u1 = union(s1, s2).limit(2) + assert_result(faux_conn, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) + + +def test_percent_sign_round_trip(faux_conn, metadata): + """test that the DBAPI accommodates for escaped / nonescaped + percent signs in a way that matches the compiler + + """ + t = Table("t", metadata, Column("data", String(50))) + t.create(faux_conn.engine) + faux_conn.execute(t.insert(), dict(data="some % value")) + faux_conn.execute(t.insert(), dict(data="some %% other value")) + eq_( + faux_conn.scalar( + select([t.c.data]).where(t.c.data == literal_column("'some % value'")) + ), + "some % value", + ) + + eq_( + faux_conn.scalar( + select([t.c.data]).where( + t.c.data == literal_column("'some %% other value'") + ) + ), + "some %% other value", + ) + + +@sqlalchemy_1_3_or_higher +def test_null_in_empty_set_is_false(faux_conn): + stmt = select( + [ + sqlalchemy.case( + [ + ( + sqlalchemy.null().in_( + sqlalchemy.bindparam("foo", value=(), expanding=True) + ), + sqlalchemy.true(), + ) + ], + else_=sqlalchemy.false(), + ) + ] + ) + in_(faux_conn.execute(stmt).fetchone()[0], (False, 0)) + + +@pytest.mark.parametrize( + "meth,arg,expected", + [ + ("contains", "b%cde", {1, 2, 3, 4, 5, 6, 7, 8, 9}), + ("startswith", "ab%c", {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + ("endswith", "e%fg", {1, 2, 3, 4, 5, 6, 7, 8, 9}), + ], +) +def test_likish(faux_conn, meth, arg, expected): + # See sqlalchemy.testing.suite.test_select.LikeFunctionsTest + table = setup_table( + faux_conn, + "t", + Column("id", Integer, primary_key=True), + Column("data", String(50)), + initial_data=[ + {"id": 1, "data": "abcdefg"}, + {"id": 2, "data": "ab/cdefg"}, + {"id": 3, "data": "ab%cdefg"}, + {"id": 4, "data": "ab_cdefg"}, + {"id": 5, "data": "abcde/fg"}, + {"id": 6, "data": "abcde%fg"}, + {"id": 7, "data": "ab#cdefg"}, + {"id": 8, "data": "ab9cdefg"}, + {"id": 9, "data": "abcde#fg"}, + {"id": 10, "data": "abcd9fg"}, + ], + ) + expr = getattr(table.c.data, meth)(arg) + rows = {value for value, in faux_conn.execute(select([table.c.id]).where(expr))} + eq_(rows, expected) + + all = {i for i in range(1, 11)} + expr = sqlalchemy.not_(expr) + rows = {value for value, in 
faux_conn.execute(select([table.c.id]).where(expr))} + eq_(rows, all - expected) + + +def test_group_by_composed(faux_conn): + table = setup_table( + faux_conn, + "t", + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("q", String(50)), + Column("p", String(50)), + initial_data=[ + {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"}, + {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"}, + {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"}, + ], + ) + + expr = (table.c.x + table.c.y).label("lx") + stmt = ( + select([sqlalchemy.func.count(table.c.id), expr]).group_by(expr).order_by(expr) + ) + assert_result(faux_conn, stmt, [(1, 3), (1, 5), (1, 7)]) diff --git a/tests/unit/test_engine.py b/tests/unit/test_engine.py new file mode 100644 index 00000000..ad34ca08 --- /dev/null +++ b/tests/unit/test_engine.py @@ -0,0 +1,54 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
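+
+# A note on the `arraysize` tests below: per PEP 249, `cursor.arraysize`
+# is the number of rows fetchmany() returns by default.  The engine-level
+# `arraysize` argument is expected to be copied onto each new cursor when
+# it is truthy (a sketch of the assumed plumbing, not the actual
+# implementation):
+#
+#     if arraysize:
+#         cursor.arraysize = arraysize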
+
+import pytest
+import sqlalchemy
+
+
+def test_engine_dataset_but_no_project(faux_conn):
+    engine = sqlalchemy.create_engine("bigquery:///foo")
+    conn = engine.connect()
+    assert conn.connection._client.project == "authproj"
+
+
+def test_engine_no_dataset_no_project(faux_conn):
+    engine = sqlalchemy.create_engine("bigquery://")
+    conn = engine.connect()
+    assert conn.connection._client.project == "authproj"
+
+
+@pytest.mark.parametrize("arraysize", [0, None])
+def test_set_arraysize_not_set_if_false(faux_conn, metadata, arraysize):
+    engine = sqlalchemy.create_engine("bigquery://", arraysize=arraysize)
+    sqlalchemy.Table("t", metadata, sqlalchemy.Column("c", sqlalchemy.Integer))
+    conn = engine.connect()
+    metadata.create_all(engine)
+
+    # Because we gave a false array size, the array size wasn't set on the cursor:
+    assert "arraysize" not in conn.connection.test_data
+
+
+def test_set_arraysize(faux_conn, metadata):
+    engine = sqlalchemy.create_engine("bigquery://", arraysize=42)
+    sqlalchemy.Table("t", metadata, sqlalchemy.Column("c", sqlalchemy.Integer))
+    conn = engine.connect()
+    metadata.create_all(engine)
+
+    # Because we gave a true array size, the array size was set on the cursor:
+    assert conn.connection.test_data["arraysize"] == 42
diff --git a/tests/unit/test_like_reescape.py b/tests/unit/test_like_reescape.py
new file mode 100644
index 00000000..2c9bf304
--- /dev/null
+++ b/tests/unit/test_like_reescape.py
@@ -0,0 +1,43 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""SQLAlchemy and BigQuery escape % and _ differently in like expressions.
+
+We need to correct for the autoescape option in various string
+functions.
+""" + +import sqlalchemy.sql.operators +import sqlalchemy.sql.schema +import pybigquery.sqlalchemy_bigquery + + +def _check(raw, escaped, escape=None, autoescape=True): + + col = sqlalchemy.sql.schema.Column() + op = col.contains(raw, escape=escape, autoescape=autoescape) + o2 = pybigquery.sqlalchemy_bigquery.BigQueryCompiler._maybe_reescape(op) + assert o2.left.__dict__ == op.left.__dict__ + assert not o2.modifiers.get("escape") + + assert o2.right.value == escaped + + +def test_like_autoescape_reescape(): + + _check("ab%cd", "ab\\%cd") + _check("ab%c_d", "ab\\%c\\_d") + _check("ab%cd", "ab%cd", autoescape=False) + _check("ab%c_d", "ab\\%c\\_d", escape="\\") + _check("ab/%c/_/d", "ab/\\%c/\\_/d") diff --git a/tests/unit/test_parse_url.py b/tests/unit/test_parse_url.py index f50e8676..bf9f8855 100644 --- a/tests/unit/test_parse_url.py +++ b/tests/unit/test_parse_url.py @@ -83,38 +83,51 @@ def test_basic(url_with_everything): @pytest.mark.parametrize( - "param, value", + "param, value, default", [ - ("clustering_fields", ["a", "b", "c"]), - ("create_disposition", "CREATE_IF_NEEDED"), + ("clustering_fields", ["a", "b", "c"], None), + ("create_disposition", "CREATE_IF_NEEDED", None), ( "destination", TableReference( DatasetReference("different-project", "different-dataset"), "table" ), + None, ), ( "destination_encryption_configuration", lambda enc: enc.kms_key_name == EncryptionConfiguration("some-configuration").kms_key_name, + None, + ), + ("dry_run", True, None), + ("labels", {"a": "b", "c": "d"}, {}), + ("maximum_bytes_billed", 1000, None), + ("priority", "INTERACTIVE", None), + ( + "schema_update_options", + ["ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"], + None, ), - ("dry_run", True), - ("labels", {"a": "b", "c": "d"}), - ("maximum_bytes_billed", 1000), - ("priority", "INTERACTIVE"), - ("schema_update_options", ["ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"]), - ("use_query_cache", True), - ("write_disposition", "WRITE_APPEND"), + ("use_query_cache", True, None), + ("write_disposition", "WRITE_APPEND", None), ], ) -def test_all_values(url_with_everything, param, value): - job_config = parse_url(url_with_everything)[5] +def test_all_values(url_with_everything, param, value, default): + url_with_this_one = make_url("bigquery://some-project/some-dataset") + url_with_this_one.query[param] = url_with_everything.query[param] + + for url in url_with_everything, url_with_this_one: + job_config = parse_url(url)[5] + config_value = getattr(job_config, param) + if callable(value): + assert value(config_value) + else: + assert config_value == value - config_value = getattr(job_config, param) - if callable(value): - assert value(config_value) - else: - assert config_value == value + url_with_nothing = make_url("bigquery://some-project/some-dataset") + job_config = parse_url(url_with_nothing)[5] + assert getattr(job_config, param) == default @pytest.mark.parametrize( @@ -209,3 +222,16 @@ def test_not_implemented(not_implemented_arg): ) with pytest.raises(NotImplementedError): parse_url(url) + + +def test_parse_boolean(): + from pybigquery.parse_url import parse_boolean + + assert parse_boolean("true") + assert parse_boolean("True") + assert parse_boolean("TRUE") + assert not parse_boolean("false") + assert not parse_boolean("False") + assert not parse_boolean("FALSE") + with pytest.raises(ValueError): + parse_boolean("Thursday") diff --git a/tests/unit/test_select.py b/tests/unit/test_select.py index f1c9cb09..9cfb5b8b 100644 --- a/tests/unit/test_select.py +++ b/tests/unit/test_select.py 
@@ -1,11 +1,303 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import datetime +from decimal import Decimal + +import pytest import sqlalchemy +import pybigquery.sqlalchemy_bigquery + +from conftest import setup_table, sqlalchemy_1_3_or_higher + def test_labels_not_forced(faux_conn): - metadata = sqlalchemy.MetaData() - table = sqlalchemy.Table( - "some_table", metadata, sqlalchemy.Column("id", sqlalchemy.Integer) - ) - metadata.create_all(faux_conn.engine) + table = setup_table(faux_conn, "t", sqlalchemy.Column("id", sqlalchemy.Integer)) result = faux_conn.execute(sqlalchemy.select([table.c.id])) assert result.keys() == ["id"] # Look! Just the column name! + + +def dtrepr(v): + return f"{v.__class__.__name__.upper()} {repr(str(v))}" + + +@pytest.mark.parametrize( + "type_,val,btype,vrep", + [ + (sqlalchemy.String, "myString", "STRING", repr), + (sqlalchemy.Text, "myText", "STRING", repr), + (sqlalchemy.Unicode, "myUnicode", "STRING", repr), + (sqlalchemy.UnicodeText, "myUnicodeText", "STRING", repr), + (sqlalchemy.Integer, 424242, "INT64", repr), + (sqlalchemy.SmallInteger, 42, "INT64", repr), + (sqlalchemy.BigInteger, 1 << 60, "INT64", repr), + (sqlalchemy.Numeric, Decimal(42), "NUMERIC", str), + (sqlalchemy.Float, 4.2, "FLOAT64", repr), + ( + sqlalchemy.DateTime, + datetime.datetime(2021, 2, 3, 4, 5, 6, 123456), + "DATETIME", + dtrepr, + ), + (sqlalchemy.Date, datetime.date(2021, 2, 3), "DATE", dtrepr), + (sqlalchemy.Time, datetime.time(4, 5, 6, 123456), "TIME", dtrepr), + (sqlalchemy.Boolean, True, "BOOL", "true"), + (sqlalchemy.REAL, 1.42, "FLOAT64", repr), + (sqlalchemy.FLOAT, 0.42, "FLOAT64", repr), + (sqlalchemy.NUMERIC, Decimal(4.25), "NUMERIC", str), + (sqlalchemy.NUMERIC(39), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.NUMERIC(30, 10), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.NUMERIC(39, 10), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.DECIMAL, Decimal(0.25), "NUMERIC", str), + (sqlalchemy.DECIMAL(39), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.DECIMAL(30, 10), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.DECIMAL(39, 10), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.INTEGER, 434343, "INT64", repr), + (sqlalchemy.INT, 444444, "INT64", repr), + (sqlalchemy.SMALLINT, 43, "INT64", repr), + (sqlalchemy.BIGINT, 1 << 61, "INT64", repr), + ( + sqlalchemy.TIMESTAMP, + datetime.datetime(2021, 2, 3, 4, 5, 7, 123456), + "TIMESTAMP", + lambda v: f"TIMESTAMP {repr(str(v))}", + ), + ( + sqlalchemy.DATETIME, 
+            datetime.datetime(2021, 2, 3, 4, 5, 8, 123456),
+            "DATETIME",
+            dtrepr,
+        ),
+        (sqlalchemy.DATE, datetime.date(2021, 2, 4), "DATE", dtrepr),
+        (sqlalchemy.TIME, datetime.time(4, 5, 7, 123456), "TIME", dtrepr),
+        (sqlalchemy.TIME, datetime.time(4, 5, 7), "TIME", dtrepr),
+        (sqlalchemy.TEXT, "myTEXT", "STRING", repr),
+        (sqlalchemy.VARCHAR, "myVARCHAR", "STRING", repr),
+        (sqlalchemy.NVARCHAR, "myNVARCHAR", "STRING", repr),
+        (sqlalchemy.CHAR, "myCHAR", "STRING", repr),
+        (sqlalchemy.NCHAR, "myNCHAR", "STRING", repr),
+        (sqlalchemy.BINARY, b"myBINARY", "BYTES", repr),
+        (sqlalchemy.VARBINARY, b"myVARBINARY", "BYTES", repr),
+        (sqlalchemy.BOOLEAN, False, "BOOL", "false"),
+        (sqlalchemy.ARRAY(sqlalchemy.Integer), [1, 2, 3], "ARRAY<INT64>", repr),
+        (
+            sqlalchemy.ARRAY(sqlalchemy.DATETIME),
+            [
+                datetime.datetime(2021, 2, 3, 4, 5, 6),
+                datetime.datetime(2021, 2, 3, 4, 5, 7, 123456),
+                datetime.datetime(2021, 2, 3, 4, 5, 8, 123456),
+            ],
+            "ARRAY<DATETIME>",
+            lambda a: "[" + ", ".join(dtrepr(v) for v in a) + "]",
+        ),
+    ],
+)
+def test_typed_parameters(faux_conn, type_, val, btype, vrep):
+    col_name = "foo"
+    table = setup_table(faux_conn, "t", sqlalchemy.Column(col_name, type_))
+
+    assert faux_conn.test_data["execute"].pop()[0].strip() == (
+        f"CREATE TABLE `t` (\n" f"\t`{col_name}` {btype}\n" f")"
+    )
+
+    faux_conn.execute(table.insert().values(**{col_name: val}))
+
+    if btype.startswith("ARRAY<"):
+        btype = btype[6:-1]
+
+    assert faux_conn.test_data["execute"][-1] == (
+        f"INSERT INTO `t` (`{col_name}`) VALUES (%({col_name}:{btype})s)",
+        {col_name: val},
+    )
+
+    faux_conn.execute(
+        table.insert()
+        .values(**{col_name: sqlalchemy.literal(val, type_)})
+        .compile(
+            dialect=pybigquery.sqlalchemy_bigquery.BigQueryDialect(),
+            compile_kwargs=dict(literal_binds=True),
+        )
+    )
+
+    if not isinstance(vrep, str):
+        vrep = vrep(val)
+
+    assert faux_conn.test_data["execute"][-1] == (
+        f"INSERT INTO `t` (`{col_name}`) VALUES ({vrep})",
+        {},
+    )
+
+    assert list(map(list, faux_conn.execute(sqlalchemy.select([table])))) == [[val]] * 2
+    assert faux_conn.test_data["execute"][-1][0] == "SELECT `t`.`foo` \nFROM `t`"
+
+    assert (
+        list(
+            map(
+                list,
+                faux_conn.execute(sqlalchemy.select([table.c.foo], use_labels=True)),
+            )
+        )
+        == [[val]] * 2
+    )
+    assert faux_conn.test_data["execute"][-1][0] == (
+        "SELECT `t`.`foo` AS `t_foo` \nFROM `t`"
+    )
+
+
+def test_select_json(faux_conn, metadata):
+    table = sqlalchemy.Table("t", metadata, sqlalchemy.Column("x", sqlalchemy.JSON))
+
+    faux_conn.ex("create table t (x RECORD)")
+    faux_conn.ex("""insert into t values ('{"y": 1}')""")
+
+    row = list(faux_conn.execute(sqlalchemy.select([table])))[0]
+    # We expect the raw string, because sqlite3, unlike BigQuery,
+    # doesn't deserialize for us.
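+    # (The dialect maps BigQuery's RECORD type to sqlalchemy.JSON, as
+    # test_get_table_columns_special_cases shows, which is why a JSON
+    # column is queried against a RECORD column here.)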
+ assert row.x == '{"y": 1}' + + +def test_select_label_starts_w_digit(faux_conn): + # Make sure label names are legal identifiers + faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(1).label("2foo")])) + assert ( + faux_conn.test_data["execute"][-1][0] == "SELECT %(param_1:INT64)s AS `_2foo`" + ) + + +def test_force_quote(faux_conn): + from sqlalchemy.sql.elements import quoted_name + + table = setup_table( + faux_conn, "t", sqlalchemy.Column(quoted_name("foo", True), sqlalchemy.Integer), + ) + faux_conn.execute(sqlalchemy.select([table])) + assert faux_conn.test_data["execute"][-1][0] == ("SELECT `t`.`foo` \nFROM `t`") + + +def test_disable_quote(faux_conn): + from sqlalchemy.sql.elements import quoted_name + + table = setup_table( + faux_conn, + "t", + sqlalchemy.Column(quoted_name("foo", False), sqlalchemy.Integer), + ) + faux_conn.execute(sqlalchemy.select([table])) + assert faux_conn.test_data["execute"][-1][0] == ("SELECT `t`.foo \nFROM `t`") + + +def test_select_in_lit(faux_conn): + [[isin]] = faux_conn.execute( + sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])]) + ) + assert isin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s IN " + "(%(param_2:INT64)s, %(param_3:INT64)s, %(param_4:INT64)s) AS `anon_1`", + {"param_1": 1, "param_2": 1, "param_3": 2, "param_4": 3}, + ) + + +def test_select_in_param(faux_conn): + [[isin]] = faux_conn.execute( + sqlalchemy.select( + [sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] + ), + dict(q=[1, 2, 3]), + ) + assert isin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s IN UNNEST(" + "[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]" + ") AS `anon_1`", + {"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3}, + ) + + +def test_select_in_param1(faux_conn): + [[isin]] = faux_conn.execute( + sqlalchemy.select( + [sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] + ), + dict(q=[1]), + ) + assert isin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s IN UNNEST(" "[ %(q_1:INT64)s ]" ") AS `anon_1`", + {"param_1": 1, "q_1": 1}, + ) + + +@sqlalchemy_1_3_or_higher +def test_select_in_param_empty(faux_conn): + [[isin]] = faux_conn.execute( + sqlalchemy.select( + [sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] + ), + dict(q=[]), + ) + assert not isin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s IN UNNEST(" "[ ]" ") AS `anon_1`", + {"param_1": 1}, + ) + + +def test_select_notin_lit(faux_conn): + [[isnotin]] = faux_conn.execute( + sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])]) + ) + assert isnotin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s NOT IN " + "(%(param_2:INT64)s, %(param_3:INT64)s, %(param_4:INT64)s) AS `anon_1`", + {"param_1": 0, "param_2": 1, "param_3": 2, "param_4": 3}, + ) + + +def test_select_notin_param(faux_conn): + [[isnotin]] = faux_conn.execute( + sqlalchemy.select( + [sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))] + ), + dict(q=[1, 2, 3]), + ) + assert not isnotin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s NOT IN UNNEST(" + "[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]" + ") AS `anon_1`", + {"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3}, + ) + + +@sqlalchemy_1_3_or_higher +def test_select_notin_param_empty(faux_conn): + [[isnotin]] = faux_conn.execute( + sqlalchemy.select( + [sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))] + ), + 
dict(q=[]), + ) + assert isnotin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s NOT IN UNNEST(" "[ ]" ") AS `anon_1`", + {"param_1": 1}, + ) diff --git a/tests/unit/test_view.py b/tests/unit/test_view.py new file mode 100644 index 00000000..0ea943bc --- /dev/null +++ b/tests/unit/test_view.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +def test_view_definition(faux_conn): + cursor = faux_conn.connection.cursor() + cursor.execute("create view foo as select 1") + + # pass the connection: + assert faux_conn.dialect.get_view_definition(faux_conn, "foo") == "select 1" + + # pass the engine: + assert faux_conn.dialect.get_view_definition(faux_conn.engine, "foo") == "select 1" + + # remove dataset id from dialect: + faux_conn.dialect.dataset_id = None + assert ( + faux_conn.dialect.get_view_definition(faux_conn, "mydataset.foo") == "select 1" + )
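
A note on the `%(name:TYPE)s` placeholders asserted throughout
tests/unit/test_select.py: this is the new typed pyformat syntax, where
the compiler annotates each parameter with its BigQuery type so the DB
API layer can build typed query parameters even when the value alone is
ambiguous.  A minimal sketch of how it surfaces, using the dialect class
these tests exercise:

    import sqlalchemy
    from pybigquery.sqlalchemy_bigquery import BigQueryDialect

    # Compiling a bound literal against the BigQuery dialect renders a
    # typed placeholder rather than a plain %(param_1)s.
    stmt = sqlalchemy.select([sqlalchemy.literal(1).label("x")])
    print(stmt.compile(dialect=BigQueryDialect()))
    # Expected form, per the assertions above: SELECT %(param_1:INT64)s AS `x`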