diff --git a/.coveragerc b/.coveragerc
index 0d8e6297..d5e3f1dc 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -17,8 +17,6 @@
 # Generated by synthtool. DO NOT EDIT!
 [run]
 branch = True
-omit =
-  google/cloud/__init__.py
 
 [report]
 fail_under = 100
@@ -35,4 +33,4 @@ omit =
   */proto/*.py
   */core/*.py
   */site-packages/*.py
-  google/cloud/__init__.py
+  pybigquery/requirements.py
diff --git a/noxfile.py b/noxfile.py
index 3ccaff8a..ec7c1e7e 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -28,17 +28,18 @@
 BLACK_PATHS = ["docs", "pybigquery", "tests", "noxfile.py", "setup.py"]
 
 DEFAULT_PYTHON_VERSION = "3.8"
-SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"]
+SYSTEM_TEST_PYTHON_VERSIONS = ["3.9"]
 UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]
 
 CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
 
 # 'docfx' is excluded since it only needs to run in 'docs-presubmit'
 nox.options.sessions = [
+    "lint",
     "unit",
-    "system",
     "cover",
-    "lint",
+    "system",
+    "compliance",
     "lint_setup_py",
     "blacken",
    "docs",
@@ -169,6 +170,58 @@ def system(session):
     )
 
 
+@nox.session(python=SYSTEM_TEST_PYTHON_VERSIONS)
+def compliance(session):
+    """Run the SQLAlchemy dialect compliance test suite."""
+    constraints_path = str(
+        CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt"
+    )
+    system_test_folder_path = os.path.join("tests", "sqlalchemy_dialect_compliance")
+
+    # Check the value of `RUN_COMPLIANCE_TESTS` env var. It defaults to true.
+    if os.environ.get("RUN_COMPLIANCE_TESTS", "true") == "false":
+        session.skip("RUN_COMPLIANCE_TESTS is set to false, skipping")
+    # Sanity check: Only run tests if the environment variable is set.
+    if not os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", ""):
+        session.skip("Credentials must be set via environment variable")
+    # Install pyopenssl for mTLS testing.
+    if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true":
+        session.install("pyopenssl")
+    # Sanity check: only run tests if found.
+    if not os.path.exists(system_test_folder_path):
+        session.skip("Compliance tests were not found")
+
+    # Use pre-release gRPC for compliance tests.
+    session.install("--pre", "grpcio")
+
+    # Install all test dependencies, then install this package into the
+    # virtualenv's dist-packages.
+    session.install(
+        "mock",
+        "pytest",
+        "pytest-rerunfailures",
+        "google-cloud-testutils",
+        "-c",
+        constraints_path,
+    )
+    session.install("-e", ".", "-c", constraints_path)
+
+    session.run(
+        "py.test",
+        "-vv",
+        f"--junitxml=compliance_{session.python}_sponge_log.xml",
+        "--reruns=3",
+        "--reruns-delay=60",
+        "--only-rerun="
+        "403 Exceeded rate limits|"
+        "409 Already Exists|"
+        "404 Not found|"
+        "400 Cannot execute DML over a non-existent table",
+        system_test_folder_path,
+        *session.posargs,
+    )
+
+
 @nox.session(python=DEFAULT_PYTHON_VERSION)
 def cover(session):
     """Run the final coverage report.
@@ -177,7 +230,7 @@ def cover(session):
     test runs (not system test runs), and then erases coverage data.
""" session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=50") + session.run("coverage", "report", "--show-missing", "--fail-under=100") session.run("coverage", "erase") diff --git a/pybigquery/requirements.py b/pybigquery/requirements.py new file mode 100644 index 00000000..77726faf --- /dev/null +++ b/pybigquery/requirements.py @@ -0,0 +1,220 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" +This module is used by the compliance tests to control which tests are run + +based on database capabilities. +""" + +import sqlalchemy.testing.requirements +import sqlalchemy.testing.exclusions + +supported = sqlalchemy.testing.exclusions.open +unsupported = sqlalchemy.testing.exclusions.closed + + +class Requirements(sqlalchemy.testing.requirements.SuiteRequirements): + @property + def index_reflection(self): + return unsupported() + + @property + def indexes_with_ascdesc(self): + """target database supports CREATE INDEX with per-column ASC/DESC.""" + return unsupported() + + @property + def unique_constraint_reflection(self): + """target dialect supports reflection of unique constraints""" + return unsupported() + + @property + def autoincrement_insert(self): + """target platform generates new surrogate integer primary key values + when insert() is executed, excluding the pk column.""" + return unsupported() + + @property + def primary_key_constraint_reflection(self): + return unsupported() + + @property + def foreign_keys(self): + """Target database must support foreign keys.""" + + return unsupported() + + @property + def foreign_key_constraint_reflection(self): + return unsupported() + + @property + def on_update_cascade(self): + """target database must support ON UPDATE..CASCADE behavior in + foreign keys.""" + + return unsupported() + + @property + def named_constraints(self): + """target database must support names for constraints.""" + + return unsupported() + + @property + def temp_table_reflection(self): + return unsupported() + + @property + def temporary_tables(self): + """target database supports temporary tables""" + return unsupported() # Temporary tables require use of scripts. + + @property + def duplicate_key_raises_integrity_error(self): + """target dialect raises IntegrityError when reporting an INSERT + with a primary key violation. 
(hint: it should) + + """ + return unsupported() + + @property + def precision_numerics_many_significant_digits(self): + """target backend supports values with many digits on both sides, + such as 319438950232418390.273596, 87673.594069654243 + + """ + return supported() + + @property + def date_coerces_from_datetime(self): + """target dialect accepts a datetime object as the target + of a date column.""" + + # BigQuery doesn't allow saving a datetime in a date: + # `TYPE_DATE`, Invalid date: '2012-10-15T12:57:18' + + return unsupported() + + @property + def window_functions(self): + """Target database must support window functions.""" + return supported() # There are no tests for this. + + @property + def ctes(self): + """Target database supports CTEs""" + + return supported() + + @property + def views(self): + """Target database must support VIEWs.""" + + return supported() + + @property + def schemas(self): + """Target database must support external schemas, and have one + named 'test_schema'.""" + + return supported() + + @property + def implicit_default_schema(self): + """target system has a strong concept of 'default' schema that can + be referred to implicitly. + + basically, PostgreSQL. + + """ + return supported() + + @property + def comment_reflection(self): + return supported() # Well, probably not, but we'll try. :) + + @property + def unicode_ddl(self): + """Target driver must support some degree of non-ascii symbol + names. + """ + return supported() + + @property + def datetime_literals(self): + """target dialect supports rendering of a date, time, or datetime as a + literal string, e.g. via the TypeEngine.literal_processor() method. + + """ + + return supported() + + @property + def timestamp_microseconds(self): + """target dialect supports representation of Python + datetime.datetime() with microsecond objects but only + if TIMESTAMP is used.""" + return supported() + + @property + def datetime_historic(self): + """target dialect supports representation of Python + datetime.datetime() objects with historic (pre 1970) values.""" + + return supported() + + @property + def date_historic(self): + """target dialect supports representation of Python + datetime.datetime() objects with historic (pre 1970) values.""" + + return supported() + + @property + def precision_numerics_enotation_small(self): + """target backend supports Decimal() objects using E notation + to represent very small values.""" + return supported() + + @property + def precision_numerics_enotation_large(self): + """target backend supports Decimal() objects using E notation + to represent very large values.""" + return supported() + + @property + def update_from(self): + """Target must support UPDATE..FROM syntax""" + return supported() + + @property + def order_by_label_with_expression(self): + """target backend supports ORDER BY a column label within an + expression. + + Basically this:: + + select data as foo from test order by foo || 'bar' + + Lots of databases including PostgreSQL don't support this, + so this is off by default. 
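+        BigQuery does support it, though, so it is enabled for this dialect.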
+
+        """
+        return supported()
diff --git a/pybigquery/sqlalchemy_bigquery.py b/pybigquery/sqlalchemy_bigquery.py
index 5a6ad105..764c3fc0 100644
--- a/pybigquery/sqlalchemy_bigquery.py
+++ b/pybigquery/sqlalchemy_bigquery.py
@@ -22,7 +22,10 @@
 from __future__ import absolute_import
 from __future__ import unicode_literals
 
+from decimal import Decimal
+import random
 import operator
+import uuid
 
 from google import auth
 import google.api_core.exceptions
@@ -30,6 +33,9 @@
 from google.cloud.bigquery.schema import SchemaField
 from google.cloud.bigquery.table import TableReference
 from google.api_core.exceptions import NotFound
+
+import sqlalchemy.sql.sqltypes
+import sqlalchemy.sql.type_api
 from sqlalchemy.exc import NoSuchTableError
 from sqlalchemy import types, util
 from sqlalchemy.sql.compiler import (
@@ -38,10 +44,11 @@
     DDLCompiler,
     IdentifierPreparer,
 )
+from sqlalchemy.sql.sqltypes import Integer, String, NullType, Numeric
 from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.sql.schema import Column
-from sqlalchemy.sql import elements
+from sqlalchemy.sql import elements, selectable
 import re
 
 from .parse_url import parse_url
@@ -50,24 +57,12 @@
 FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+")
 
 
-class UniversalSet(object):
-    """
-    Set containing everything
-    https://github.com/dropbox/PyHive/blob/master/pyhive/common.py
-    """
-
-    def __contains__(self, item):
-        return True
-
-
 class BigQueryIdentifierPreparer(IdentifierPreparer):
     """
     Set containing everything
     https://github.com/dropbox/PyHive/blob/master/pyhive/sqlalchemy_presto.py
     """
 
-    reserved_words = UniversalSet()
-
     def __init__(self, dialect):
         super(BigQueryIdentifierPreparer, self).__init__(
             dialect, initial_quote="`",
@@ -88,21 +83,7 @@ def quote(self, ident, force=None, column=False):
         """
         force = getattr(ident, "quote", None)
-
-        if force is None:
-            if ident in self._strings:
-                return self._strings[ident]
-            else:
-                if self._requires_quotes(ident):
-                    self._strings[ident] = (
-                        self.quote_column(ident)
-                        if column
-                        else self.quote_identifier(ident)
-                    )
-                else:
-                    self._strings[ident] = ident
-                return self._strings[ident]
-        elif force:
+        if force is None or force:
             return self.quote_column(ident) if column else self.quote_identifier(ident)
         else:
             return ident
@@ -123,8 +104,11 @@ def format_label(self, label, name=None):
 _type_map = {
     "STRING": types.String,
+    "BOOL": types.Boolean,
     "BOOLEAN": types.Boolean,
+    "INT64": types.Integer,
     "INTEGER": types.Integer,
+    "FLOAT64": types.Float,
     "FLOAT": types.Float,
     "TIMESTAMP": types.TIMESTAMP,
     "DATETIME": types.DATETIME,
@@ -133,11 +117,15 @@ def format_label(self, label, name=None):
     "TIME": types.TIME,
     "RECORD": types.JSON,
     "NUMERIC": types.DECIMAL,
+    "BIGNUMERIC": types.DECIMAL,
 }
 
 STRING = _type_map["STRING"]
+BOOL = _type_map["BOOL"]
 BOOLEAN = _type_map["BOOLEAN"]
+INT64 = _type_map["INT64"]
 INTEGER = _type_map["INTEGER"]
+FLOAT64 = _type_map["FLOAT64"]
 FLOAT = _type_map["FLOAT"]
 TIMESTAMP = _type_map["TIMESTAMP"]
 DATETIME = _type_map["DATETIME"]
@@ -146,6 +134,7 @@ def format_label(self, label, name=None):
 TIME = _type_map["TIME"]
 RECORD = _type_map["RECORD"]
 NUMERIC = _type_map["NUMERIC"]
+BIGNUMERIC = _type_map["BIGNUMERIC"]
 
 
 class BigQueryExecutionContext(DefaultExecutionContext):
@@ -156,8 +145,49 @@ def create_cursor(self):
         c.arraysize = self.dialect.arraysize
         return c
 
+    def get_insert_default(self, column):  # pragma: NO COVER
+        # Only used by compliance tests
+        if isinstance(column.type, Integer):
+            # randint is inclusive at both ends, so stay within INT64:
+            # -2**63 to 2**63 - 1.
+            return random.randint(-9223372036854775808, 9223372036854775807)
+        elif isinstance(column.type, String):
+            return str(uuid.uuid4())
+
+    def pre_exec(
+        self,
+        in_sub=re.compile(
+            r" IN UNNEST\(\[ "
+            r"(%\([^)]+_\d+\)s(?:, %\([^)]+_\d+\)s)*)?"  # Placeholders. See below.
+            r":([A-Z0-9]+)"  # Type
+            r" \]\)"
+        ).sub,
+    ):
+        # If we have an IN parameter, it sometimes gets expanded to 0 or more
+        # parameters and we need to move the type marker to each
+        # parameter.
+        # (The way SQLAlchemy handles this is a bit awkward for our
+        # purposes.)
+
+        # In the placeholder part of the regex above, the `_\d+`
+        # suffixes reflect that when an array parameter is expanded,
+        # numeric suffixes are added. For example, a placeholder like
+        # `%(foo)s` gets expanded to `%(foo_0)s`, `%(foo_1)s`, ...
+
+        def repl(m):
+            placeholders, type_ = m.groups()
+            if placeholders:
+                placeholders = placeholders.replace(")", f":{type_})")
+            else:
+                placeholders = ""
+            return f" IN UNNEST([ {placeholders} ])"
+
+        self.statement = in_sub(repl, self.statement)
+
 
 class BigQueryCompiler(SQLCompiler):
+
+    compound_keywords = SQLCompiler.compound_keywords.copy()
+    compound_keywords[selectable.CompoundSelect.UNION] = "UNION ALL"
+
     def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
         if isinstance(statement, Column):
             kwargs["compile_kwargs"] = util.immutabledict({"include_table": False})
@@ -165,10 +195,23 @@ def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs)
             dialect, statement, column_keys, inline, **kwargs
         )
 
+    def visit_insert(self, insert_stmt, asfrom=False, **kw):
+        # The (internal) documentation for `inline` is confusing, but
+        # having `inline` be true prevents us from generating default
+        # primary-key values when we're doing executemany, which seems broken.
+
+        # We can probably do this in the constructor, but I want to
+        # make sure this only affects insert, because I'm paranoid. :)
+
+        self.inline = False
+
+        return super(BigQueryCompiler, self).visit_insert(
+            insert_stmt, asfrom=False, **kw
+        )
+
     def visit_column(
         self, column, add_to_result_map=None, include_table=True, **kwargs
     ):
-
         name = orig_name = column.name
         if name is None:
             name = self._fallback_column_name(column)
@@ -188,16 +231,10 @@ def visit_column(
         if table is None or not include_table or not table.named_with_column:
             return name
         else:
-            effective_schema = self.preparer.schema_for_object(table)
-
-            if effective_schema:
-                schema_prefix = self.preparer.quote_schema(effective_schema) + "."
-            else:
-                schema_prefix = ""
             tablename = table.name
             if isinstance(tablename, elements._truncated_label):
                 tablename = self._truncated_identifier("alias", tablename)
-            return schema_prefix + self.preparer.quote(tablename) + "." + name
+            return self.preparer.quote(tablename) + "." + name
 
     def visit_label(self, *args, within_group_by=False, **kwargs):
         # Use labels in GROUP BY clause.
@@ -213,31 +250,162 @@ def group_by_clause(self, select, **kw):
             select, **kw, within_group_by=True
         )
 
+    ############################################################################
+    # Handle parameters in IN
+
+    # Due to details in the way sqlalchemy arranges the compilation we
+    # expect the bind parameter as an array and unnest it.
+
+    # As it happens, bigquery can handle arrays directly, but there's
+    # no way to tell sqlalchemy that, so it works harder than
+    # necessary and makes us do the same.
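+
+    # For example (an illustrative sketch, not captured output): a rendered
+    # condition like
+    #
+    #     x IN ([EXPANDING_x]:INT64)
+    #
+    # is rewritten by the helpers below to
+    #
+    #     x IN UNNEST([ [EXPANDING_x]:INT64 ])
+    #
+    # and, after SQLAlchemy expands the bind to numbered placeholders,
+    # ``pre_exec`` above moves the type marker onto each one:
+    #
+    #     x IN UNNEST([ %(x_1:INT64)s, %(x_2:INT64)s ])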
+ + _in_expanding_bind = re.compile(r" IN \((\[EXPANDING_\w+\](:[A-Z0-9]+)?)\)$") + + def _unnestify_in_expanding_bind(self, in_text): + return self._in_expanding_bind.sub(r" IN UNNEST([ \1 ])", in_text) + + def visit_in_op_binary(self, binary, operator_, **kw): + return self._unnestify_in_expanding_bind( + self._generate_generic_binary(binary, " IN ", **kw) + ) + + def visit_empty_set_expr(self, element_types): + return "" + + def visit_notin_op_binary(self, binary, operator, **kw): + return self._unnestify_in_expanding_bind( + self._generate_generic_binary(binary, " NOT IN ", **kw) + ) + + ############################################################################ + + ############################################################################ + # Correct for differences in the way that SQLAlchemy escape % and _ (/) + # and BigQuery does (\\). + + @staticmethod + def _maybe_reescape(binary): + binary = binary._clone() + escape = binary.modifiers.pop("escape", None) + if escape and escape != "\\": + binary.right.value = escape.join( + v.replace(escape, "\\") + for v in binary.right.value.split(escape + escape) + ) + return binary + + def visit_contains_op_binary(self, binary, operator, **kw): + return super(BigQueryCompiler, self).visit_contains_op_binary( + self._maybe_reescape(binary), operator, **kw + ) + + def visit_notcontains_op_binary(self, binary, operator, **kw): + return super(BigQueryCompiler, self).visit_notcontains_op_binary( + self._maybe_reescape(binary), operator, **kw + ) + + def visit_startswith_op_binary(self, binary, operator, **kw): + return super(BigQueryCompiler, self).visit_startswith_op_binary( + self._maybe_reescape(binary), operator, **kw + ) + + def visit_notstartswith_op_binary(self, binary, operator, **kw): + return super(BigQueryCompiler, self).visit_notstartswith_op_binary( + self._maybe_reescape(binary), operator, **kw + ) + + def visit_endswith_op_binary(self, binary, operator, **kw): + return super(BigQueryCompiler, self).visit_endswith_op_binary( + self._maybe_reescape(binary), operator, **kw + ) + + def visit_notendswith_op_binary(self, binary, operator, **kw): + return super(BigQueryCompiler, self).visit_notendswith_op_binary( + self._maybe_reescape(binary), operator, **kw + ) + + ############################################################################ + + def visit_bindparam( + self, + bindparam, + within_columns_clause=False, + literal_binds=False, + skip_bind_expression=False, + **kwargs, + ): + param = super(BigQueryCompiler, self).visit_bindparam( + bindparam, + within_columns_clause, + literal_binds, + skip_bind_expression, + **kwargs, + ) + + type_ = bindparam.type + if isinstance(type_, NullType): + return param + + if ( + isinstance(type_, Numeric) + and (type_.precision is None or type_.scale is None) + and isinstance(bindparam.value, Decimal) + ): + t = bindparam.value.as_tuple() + + if type_.precision is None: + type_.precision = len(t.digits) + + if type_.scale is None and t.exponent < 0: + type_.scale = -t.exponent + + bq_type = self.dialect.type_compiler.process(type_) + if bq_type[-1] == ">" and bq_type.startswith("ARRAY<"): + # Values get arrayified at a lower level. 
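+            # For example (illustrative): a parameter typed ``ARRAY<INT64>``
+            # is tagged as plain ``INT64`` here; each element placeholder is
+            # typed individually and the ``IN UNNEST([ ... ])`` rewrite
+            # supplies the array.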
+ bq_type = bq_type[6:-1] + + assert param != "%s" + return param.replace(")", f":{bq_type})") + class BigQueryTypeCompiler(GenericTypeCompiler): - def visit_integer(self, type_, **kw): + def visit_INTEGER(self, type_, **kw): return "INT64" - def visit_float(self, type_, **kw): + visit_BIGINT = visit_SMALLINT = visit_INTEGER + + def visit_BOOLEAN(self, type_, **kw): + return "BOOL" + + def visit_FLOAT(self, type_, **kw): return "FLOAT64" - def visit_text(self, type_, **kw): - return "STRING" + visit_REAL = visit_FLOAT - def visit_string(self, type_, **kw): + def visit_STRING(self, type_, **kw): return "STRING" + visit_CHAR = visit_NCHAR = visit_STRING + visit_VARCHAR = visit_NVARCHAR = visit_TEXT = visit_STRING + def visit_ARRAY(self, type_, **kw): return "ARRAY<{}>".format(self.process(type_.item_type, **kw)) def visit_BINARY(self, type_, **kw): return "BYTES" + visit_VARBINARY = visit_BINARY + def visit_NUMERIC(self, type_, **kw): - return "NUMERIC" + if (type_.precision is not None and type_.precision > 38) or ( + type_.scale is not None and type_.scale > 9 + ): + return "BIGNUMERIC" + else: + return "NUMERIC" - def visit_DECIMAL(self, type_, **kw): - return "NUMERIC" + visit_DECIMAL = visit_NUMERIC class BigQueryDDLCompiler(DDLCompiler): @@ -250,31 +418,110 @@ def visit_foreign_key_constraint(self, constraint): def visit_primary_key_constraint(self, constraint): return None + # BigQuery has no support for unique constraints. + def visit_unique_constraint(self, constraint): + return None + def get_column_specification(self, column, **kwargs): colspec = super(BigQueryDDLCompiler, self).get_column_specification( column, **kwargs ) - if column.doc is not None: + if column.comment is not None: colspec = "{} OPTIONS(description={})".format( - colspec, self.preparer.quote(column.doc) + colspec, process_string_literal(column.comment) ) return colspec def post_create_table(self, table): bq_opts = table.dialect_options["bigquery"] opts = [] - if "description" in bq_opts: - opts.append( - "description={}".format(self.preparer.quote(bq_opts["description"])) + + if ("description" in bq_opts) or table.comment: + description = process_string_literal( + bq_opts.get("description", table.comment) ) + opts.append(f"description={description}") + if "friendly_name" in bq_opts: opts.append( - "friendly_name={}".format(self.preparer.quote(bq_opts["friendly_name"])) + "friendly_name={}".format( + process_string_literal(bq_opts["friendly_name"]) + ) ) + if opts: return "\nOPTIONS({})".format(", ".join(opts)) + return "" + def visit_set_table_comment(self, create): + table_name = self.preparer.format_table(create.element) + description = self.sql_compiler.render_literal_value( + create.element.comment, sqlalchemy.sql.sqltypes.String() + ) + return f"ALTER TABLE {table_name} SET OPTIONS(description={description})" + + def visit_drop_table_comment(self, drop): + table_name = self.preparer.format_table(drop.element) + return f"ALTER TABLE {table_name} SET OPTIONS(description=null)" + + +def process_string_literal(value): + return repr(value.replace("%", "%%")) + + +class BQString(String): + def literal_processor(self, dialect): + return process_string_literal + + +class BQBinary(sqlalchemy.sql.sqltypes._Binary): + @staticmethod + def __process_bytes_literal(value): + return repr(value.replace(b"%", b"%%")) + + def literal_processor(self, dialect): + return self.__process_bytes_literal + + +class BQClassTaggedStr(sqlalchemy.sql.type_api.TypeEngine): + """Type that can get literals via str + """ + + @staticmethod + 
def process_literal_as_class_tagged_str(value): + return f"{value.__class__.__name__.upper()} {repr(str(value))}" + + def literal_processor(self, dialect): + return self.process_literal_as_class_tagged_str + + +class BQTimestamp(sqlalchemy.sql.type_api.TypeEngine): + """Type that can get literals via str + """ + + @staticmethod + def process_timestamp_literal(value): + return f"TIMESTAMP {process_string_literal(str(value))}" + + def literal_processor(self, dialect): + return self.process_timestamp_literal + + +class BQArray(sqlalchemy.sql.sqltypes.ARRAY): + def literal_processor(self, dialect): + + item_processor = self.item_type._cached_literal_processor(dialect) + if not item_processor: + raise NotImplementedError( + f"Don't know how to literal-quote values of type {self.item_type}" + ) + + def process_array_literal(value): + return "[" + ", ".join(item_processor(v) for v in value) + "]" + + return process_array_literal + class BigQueryDialect(DefaultDialect): name = "bigquery" @@ -285,6 +532,8 @@ class BigQueryDialect(DefaultDialect): ddl_compiler = BigQueryDDLCompiler execution_ctx_cls = BigQueryExecutionContext supports_alter = False + supports_comments = True + inline_comments = True supports_pk_autoincrement = False supports_default_values = False supports_empty_insert = False @@ -297,6 +546,17 @@ class BigQueryDialect(DefaultDialect): supports_native_boolean = True supports_simple_order_by_label = True postfetch_lastrowid = False + preexecute_autoincrement_sequences = False + + colspecs = { + String: BQString, + sqlalchemy.sql.sqltypes._Binary: BQBinary, + sqlalchemy.sql.sqltypes.Date: BQClassTaggedStr, + sqlalchemy.sql.sqltypes.DateTime: BQClassTaggedStr, + sqlalchemy.sql.sqltypes.Time: BQClassTaggedStr, + sqlalchemy.sql.sqltypes.TIMESTAMP: BQTimestamp, + sqlalchemy.sql.sqltypes.ARRAY: BQArray, + } def __init__( self, @@ -305,7 +565,7 @@ def __init__( location=None, credentials_info=None, *args, - **kwargs + **kwargs, ): super(BigQueryDialect, self).__init__(*args, **kwargs) self.arraysize = arraysize @@ -427,9 +687,7 @@ def _table_reference( dataset_id_from_schema = None if provided_schema_name is not None: provided_schema_name_split = provided_schema_name.split(".") - if len(provided_schema_name_split) == 0: - pass - elif len(provided_schema_name_split) == 1: + if len(provided_schema_name_split) == 1: if dataset_id_from_table: project_id_from_schema = provided_schema_name_split[0] else: @@ -579,10 +837,7 @@ def get_schema_names(self, connection, **kw): connection = connection.connect() datasets = connection.connection._client.list_datasets() - if self.dataset_id is not None: - return [d.dataset_id for d in datasets if d.dataset_id == self.dataset_id] - else: - return [d.dataset_id for d in datasets] + return [d.dataset_id for d in datasets] def get_table_names(self, connection, schema=None, **kw): if isinstance(connection, Engine): @@ -604,6 +859,11 @@ def _check_unicode_returns(self, connection, additional_tests=None): # requests gives back Unicode strings return True - def _check_unicode_description(self, connection): - # requests gives back Unicode strings - return True + def get_view_definition(self, connection, view_name, schema=None, **kw): + if isinstance(connection, Engine): + connection = connection.connect() + client = connection.connection._client + if self.dataset_id: + view_name = f"{self.dataset_id}.{view_name}" + view = client.get_table(view_name) + return view.view_query diff --git a/setup.cfg b/setup.cfg index 95ac0e28..91fcadc7 100644 --- a/setup.cfg +++ 
b/setup.cfg
@@ -20,8 +20,11 @@
 universal = 1
 
 [sqla_testing]
 requirement_cls=pybigquery.requirements:Requirements
-profile_file=.profiles.txt
+profile_file=sqlalchemy_dialect_compliance-profiles.txt
 
 [db]
-default=bigquery://
-bigquery=bigquery://
+default=bigquery:///test_pybigquery_sqla
+
+[tool:pytest]
+addopts= --tb native -v -r fxX -p no:warnings
+python_files=tests/*test_*.py
diff --git a/setup.py b/setup.py
index 97475956..d93f2225 100644
--- a/setup.py
+++ b/setup.py
@@ -65,10 +65,10 @@ def readme():
     ],
     platforms="Posix; MacOS X; Windows",
     install_requires=[
-        "sqlalchemy>=1.1.9,<1.4.0dev",
-        "google-auth>=1.14.0,<2.0dev",  # Work around pip wack.
-        "google-cloud-bigquery>=1.12.0",
-        "google-api-core>=1.19.1",  # Work-around bug in cloud core deps.
+        "sqlalchemy>=1.2.0,<1.4.0dev",
+        "google-auth>=1.24.0,<2.0dev",  # Work around pip wack.
+        "google-cloud-bigquery>=2.15.0",
+        "google-api-core>=1.23.0",  # Work-around bug in cloud core deps.
         "future",
     ],
     python_requires=">=3.6, <3.10",
diff --git a/testing/constraints-3.6.txt b/testing/constraints-3.6.txt
index 34cbdb7a..5bc8ccf5 100644
--- a/testing/constraints-3.6.txt
+++ b/testing/constraints-3.6.txt
@@ -4,6 +4,7 @@
 # Pin the version to the lower bound.
 #
 # e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
-sqlalchemy==1.1.9
-google-auth==1.14.0
-google-cloud-bigquery==1.12.0
+sqlalchemy==1.2.0
+google-auth==1.24.0
+google-cloud-bigquery==2.15.0
+google-api-core==1.23.0
diff --git a/tests/sqlalchemy_dialect_compliance/README.rst b/tests/sqlalchemy_dialect_compliance/README.rst
new file mode 100644
index 00000000..7947ec26
--- /dev/null
+++ b/tests/sqlalchemy_dialect_compliance/README.rst
@@ -0,0 +1,12 @@
+SQLAlchemy Dialect Compliance Tests
+===================================
+
+SQLAlchemy ships a reusable test suite for verifying that dialects
+work properly. This directory applies those tests to the BigQuery
+SQLAlchemy dialect.
+
+These are "system" tests, meaning that they run against a real
+BigQuery account. To run the tests, you need a BigQuery account with
+empty `test_pybigquery_sqla` and `test_schema` schemas. You also need
+the `GOOGLE_APPLICATION_CREDENTIALS` environment variable set to the
+path of a Google Cloud authentication file.
diff --git a/tests/sqlalchemy_dialect_compliance/conftest.py b/tests/sqlalchemy_dialect_compliance/conftest.py
new file mode 100644
index 00000000..eefd3f07
--- /dev/null
+++ b/tests/sqlalchemy_dialect_compliance/conftest.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2021 The PyBigQuery Authors
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +from sqlalchemy.testing.plugin.pytestplugin import * # noqa +from sqlalchemy.testing.plugin.pytestplugin import ( + pytest_sessionstart as _pytest_sessionstart, +) + +import google.cloud.bigquery.dbapi.connection +import pybigquery.sqlalchemy_bigquery +import sqlalchemy +import traceback + +pybigquery.sqlalchemy_bigquery.BigQueryDialect.preexecute_autoincrement_sequences = True +google.cloud.bigquery.dbapi.connection.Connection.rollback = lambda self: None + + +# BigQuery requires delete statements to have where clauses. Other +# databases don't and sqlalchemy doesn't include where clauses when +# cleaning up test data. So we add one when we see a delete without a +# where clause when tearing down tests. We only do this during tear +# down, by inspecting the stack, because we don't want to hide bugs +# outside of test house-keeping. +def visit_delete(self, delete_stmt, *args, **kw): + if delete_stmt._whereclause is None and "teardown" in set( + f.name for f in traceback.extract_stack() + ): + delete_stmt._whereclause = sqlalchemy.true() + + return super(pybigquery.sqlalchemy_bigquery.BigQueryCompiler, self).visit_delete( + delete_stmt, *args, **kw + ) + + +pybigquery.sqlalchemy_bigquery.BigQueryCompiler.visit_delete = visit_delete + + +# Clean up test schemas so we don't get spurious errors when the tests +# try to create tables that already exist. +def pytest_sessionstart(session): + client = google.cloud.bigquery.Client() + for schema in "test_schema", "test_pybigquery_sqla": + for table_item in client.list_tables(f"{client.project}.{schema}"): + table_id = table_item.table_id + list( + client.query( + f"drop {'view' if table_id.endswith('_v') else 'table'}" + f" {schema}.{table_id}" + ).result() + ) + client.close() + _pytest_sessionstart(session) diff --git a/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py new file mode 100644 index 00000000..259a78ec --- /dev/null +++ b/tests/sqlalchemy_dialect_compliance/test_dialect_compliance.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+
+import datetime
+import mock
+import pytest
+import pytz
+from sqlalchemy import and_
+from sqlalchemy.testing.assertions import eq_
+from sqlalchemy.testing.suite import config, select, exists
+from sqlalchemy.testing.suite import *  # noqa
+from sqlalchemy.testing.suite import (
+    ComponentReflectionTest as _ComponentReflectionTest,
+    CTETest as _CTETest,
+    ExistsTest as _ExistsTest,
+    InsertBehaviorTest as _InsertBehaviorTest,
+    LimitOffsetTest as _LimitOffsetTest,
+    LongNameBlowoutTest,
+    QuotedNameArgumentTest,
+    SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest,
+    TimestampMicrosecondsTest as _TimestampMicrosecondsTest,
+)
+
+# Quotes aren't allowed in BigQuery table names.
+del QuotedNameArgumentTest
+
+
+class InsertBehaviorTest(_InsertBehaviorTest):
+    @pytest.mark.skip()
+    def test_insert_from_select_autoinc(cls):
+        """BQ has no autoinc and client-side defaults can't work for select."""
+
+
+class ExistsTest(_ExistsTest):
+    """
+    Override
+
+    Because BigQuery requires a FROM clause when there's a WHERE clause,
+    and the base tests don't provide a FROM.
+    """
+
+    def test_select_exists(self, connection):
+        stuff = self.tables.stuff
+        eq_(
+            connection.execute(
+                select([stuff.c.id]).where(
+                    and_(stuff.c.id == 1, exists().where(stuff.c.data == "some data"),)
+                )
+            ).fetchall(),
+            [(1,)],
+        )
+
+    def test_select_exists_false(self, connection):
+        stuff = self.tables.stuff
+        eq_(
+            connection.execute(
+                select([stuff.c.id]).where(exists().where(stuff.c.data == "no data"))
+            ).fetchall(),
+            [],
+        )
+
+
+class LimitOffsetTest(_LimitOffsetTest):
+    @pytest.mark.skip()
+    def test_simple_offset(self):
+        """BigQuery doesn't allow an offset without a limit."""
+
+    test_bound_offset = test_simple_offset
+
+
+# This test requires features (indexes, primary keys, etc.) that BigQuery doesn't have.
+del LongNameBlowoutTest
+
+
+class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest):
+    """The base tests fail if operations return rows for some reason."""
+
+    def test_update(self):
+        t = self.tables.plain_pk
+        r = config.db.execute(t.update().where(t.c.id == 2), data="d2_new")
+        assert not r.is_insert
+        # assert not r.returns_rows
+
+        eq_(
+            config.db.execute(t.select().order_by(t.c.id)).fetchall(),
+            [(1, "d1"), (2, "d2_new"), (3, "d3")],
+        )
+
+    def test_delete(self):
+        t = self.tables.plain_pk
+        r = config.db.execute(t.delete().where(t.c.id == 2))
+        assert not r.is_insert
+        # assert not r.returns_rows
+        eq_(
+            config.db.execute(t.select().order_by(t.c.id)).fetchall(),
+            [(1, "d1"), (3, "d3")],
+        )
+
+
+class CTETest(_CTETest):
+    @pytest.mark.skip("Can't use CTEs with insert")
+    def test_insert_from_select_round_trip(self):
+        pass
+
+    @pytest.mark.skip("Recursive CTEs aren't supported.")
+    def test_select_recursive_round_trip(self):
+        pass
+
+
+class ComponentReflectionTest(_ComponentReflectionTest):
+    @pytest.mark.skip("BigQuery types don't track precision, length, etc.")
+    def coarse_grained_types():
+        pass
+
+    test_numeric_reflection = test_varchar_reflection = coarse_grained_types
+
+
+class TimestampMicrosecondsTest(_TimestampMicrosecondsTest):
+
+    data = datetime.datetime(2012, 10, 15, 12, 57, 18, 396, tzinfo=pytz.UTC)
+
+    def test_literal(self):
+        # The base test doesn't set up the literal properly, because
+        # it doesn't pass its datatype to `literal`.
+ + def literal(value): + assert value == self.data + import sqlalchemy.sql.sqltypes + + return sqlalchemy.sql.elements.literal(value, self.datatype) + + with mock.patch("sqlalchemy.testing.suite.test_types.literal", literal): + super(TimestampMicrosecondsTest, self).test_literal() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 22def748..801e84a9 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,16 +1,79 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import contextlib import mock +import sqlite3 + import pytest import sqlalchemy import fauxdbi +sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split("."))) +sqlalchemy_1_3_or_higher = pytest.mark.skipif( + sqlalchemy_version_info < (1, 3), reason="requires sqlalchemy 1.3 or higher" +) + @pytest.fixture() def faux_conn(): - with mock.patch( - "google.cloud.bigquery.dbapi.connection.Connection", fauxdbi.Connection - ): - engine = sqlalchemy.create_engine("bigquery://myproject/mydataset") - conn = engine.connect() - yield conn - conn.close() + test_data = dict(execute=[]) + connection = sqlite3.connect(":memory:") + + def factory(*args, **kw): + conn = fauxdbi.Connection(connection, test_data, *args, **kw) + return conn + + with mock.patch("google.cloud.bigquery.dbapi.connection.Connection", factory): + # We want to bypass client creation. We don't need it and it requires creds. 
+ with mock.patch( + "pybigquery._helpers.create_bigquery_client", fauxdbi.FauxClient + ): + with mock.patch("google.auth.default", return_value=("authdb", "authproj")): + engine = sqlalchemy.create_engine("bigquery://myproject/mydataset") + conn = engine.connect() + conn.test_data = test_data + + def ex(sql, *args, **kw): + with contextlib.closing( + conn.connection.connection.connection.cursor() + ) as cursor: + cursor.execute(sql, *args, **kw) + + conn.ex = ex + + ex("create table comments" " (key string primary key, comment string)") + + yield conn + conn.close() + + +@pytest.fixture() +def metadata(): + return sqlalchemy.MetaData() + + +def setup_table(connection, name, *columns, initial_data=(), **kw): + metadata = sqlalchemy.MetaData() + table = sqlalchemy.Table(name, metadata, *columns, **kw) + metadata.create_all(connection.engine) + if initial_data: + connection.execute(table.insert(), initial_data) + return table diff --git a/tests/unit/fauxdbi.py b/tests/unit/fauxdbi.py index 44c4edae..70cbb8aa 100644 --- a/tests/unit/fauxdbi.py +++ b/tests/unit/fauxdbi.py @@ -1,49 +1,280 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
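+
+"""A minimal fake DBAPI driver, backed by sqlite3, used by the unit tests.
+
+It emulates just enough of the BigQuery DBAPI and client surface
+(cursors, typed query parameters, table and comment metadata) to let
+the dialect's unit tests run without credentials or network access.
+"""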
+
+import base64
+import contextlib
+import datetime
+import decimal
+import pickle
+import re
+import sqlite3
+
 import google.api_core.exceptions
 import google.cloud.bigquery.schema
 import google.cloud.bigquery.table
-import contextlib
-import sqlite3
+import google.cloud.bigquery.dbapi.cursor
 
 
 class Connection:
-
-    connection = None
-
-    def __init__(self, client=None, bqstorage_client=None):
-        # share a single connection:
-        if self.connection is None:
-            self.__class__.connection = sqlite3.connect(":memory:")
-        self._client = FauxClient(client, self.connection)
+    def __init__(self, connection, test_data, client, *args, **kw):
+        self.connection = connection
+        self.test_data = test_data
+        self._client = client
+        client.connection = self
 
     def cursor(self):
-        return Cursor(self.connection)
+        return Cursor(self)
 
     def commit(self):
         pass
 
     def rollback(self):
         pass
 
     def close(self):
-        self.connection.close()
+        pass
+
+
+class Cursor:
+    def __init__(self, connection):
+        self.connection = connection
+        self.cursor = connection.connection.cursor()
+        assert self.arraysize == 1
+
+    __arraysize = 1
 
-class Cursor:
+    @property
+    def arraysize(self):
+        return self.__arraysize
 
-    arraysize = 1
+    @arraysize.setter
+    def arraysize(self, v):
+        self.__arraysize = v
+        self.connection.test_data["arraysize"] = v
 
-    def __init__(self, connection):
-        self.connection = connection
-        self.cursor = connection.cursor()
+    # A Note on the use of pickle here
+    # ================================
+    #
+    # BigQuery supports types that sqlite doesn't. We compensate by
+    # pickling unhandled types and saving the pickles as
+    # strings. Bonus: literals require extra handling.
+    #
+    # Note that this only needs to be robust enough for tests. :) So
+    # when reading data, we simply look for pickle protocol 4
+    # prefixes, because we don't have to worry about people providing
+    # non-pickle string values with those prefixes, because we control
+    # the inputs in the tests and we choose not to do that.
 
-    def execute(self, operation, parameters=None):
-        if parameters:
-            parameters = {
-                name: "null" if value is None else repr(value)
-                for name, value in parameters.items()
+    _need_to_be_pickled = (
+        list,
+        dict,
+        decimal.Decimal,
+        bool,
+        datetime.datetime,
+        datetime.date,
+        datetime.time,
+    )
+
+    def __convert_params(
+        self,
+        operation,
+        parameters,
+        placeholder=re.compile(r"%\((\w+)\)s", re.IGNORECASE),
+    ):
+        ordered_parameters = []
+
+        def repl(m):
+            name = m.group(1)
+            value = parameters[name]
+            if isinstance(value, self._need_to_be_pickled):
+                value = pickle.dumps(value, 4).decode("latin1")
+            ordered_parameters.append(value)
+            return "?"
+
+        operation = placeholder.sub(repl, operation)
+        return operation, ordered_parameters
+
+    def __update_comment(self, table, col, comment):
+        key = table + "," + col
+        self.cursor.execute("delete from comments where key=?", [key])
+        self.cursor.execute(f"insert into comments values(?, {comment})", [key])
+
+    __create_table = re.compile(
+        r"\s*create\s+table\s+`(?P<table>\w+)`", re.IGNORECASE
+    ).match
+
+    def __handle_comments(
+        self,
+        operation,
+        alter_table=re.compile(
+            r"\s*ALTER\s+TABLE\s+`(?P<table>\w+)`\s+"
+            r"SET\s+OPTIONS\(description=(?P<comment>[^)]+)\)",
+            re.IGNORECASE,
+        ).match,
+        options=re.compile(
+            r"(?P<prefix>`(?P<col>\w+)`\s+\w+|\))" r"\s+options\((?P<options>[^)]+)\)",
+            re.IGNORECASE,
+        ),
+    ):
+        m = self.__create_table(operation)
+        if m:
+            table_name = m.group("table")
+
+            def repl(m):
+                col = m.group("col") or ""
+                options = {
+                    name.strip().lower(): value.strip()
+                    for name, value in (
+                        o.split("=") for o in m.group("options").split(",")
+                    )
+                }
+
+                comment = options.get("description")
+                if comment:
+                    self.__update_comment(table_name, col, comment)
+
+                return m.group("prefix")
+
+            return options.sub(repl, operation)
+
+        m = alter_table(operation)
+        if m:
+            table_name = m.group("table")
+            comment = m.group("comment")
+            self.__update_comment(table_name, "", comment)
+            return ""
+
+        return operation
+
+    def __handle_array_types(
+        self,
+        operation,
+        array_type=re.compile(
+            r"(?<=[(,])" r"\s*`\w+`\s+\w+<\w+>\s*" r"(?=[,)])", re.IGNORECASE
+        ),
+    ):
+        if self.__create_table(operation):
+
+            def repl(m):
+                return m.group(0).replace("<", "_").replace(">", "_")
+
+            return array_type.sub(repl, operation)
+        else:
+            return operation
+
+    @staticmethod
+    def __parse_dateish(type_, value):
+        type_ = type_.lower()
+        if type_ == "timestamp":
+            type_ = "datetime"
+
+        if type_ == "datetime":
+            return datetime.datetime.strptime(
+                value, "%Y-%m-%d %H:%M:%S.%f" if "." in value else "%Y-%m-%d %H:%M:%S",
+            )
+        elif type_ == "date":
+            return datetime.date(*map(int, value.split("-")))
+        elif type_ == "time":
+            if "." in value:
+                value, micro = value.split(".")
+                micro = [micro]
+            else:
+                micro = []
+
+            return datetime.time(*map(int, value.split(":") + micro))
+        else:
+            raise AssertionError(type_)  # pragma: NO COVER
+
+    def __handle_problematic_literal_inserts(
+        self,
+        operation,
+        literal_insert_values=re.compile(
+            r"\s*(insert\s+into\s+.+\s+values\s*)" r"(\([^)]+\))" r"\s*$", re.IGNORECASE
+        ).match,
+        bq_dateish=re.compile(
+            r"(?<=[[(,])\s*"
+            r"(?P<type>date(?:time)?|time(?:stamp)?) (?P<val>'[^']+')"
+            r"\s*(?=[]),])",
+            re.IGNORECASE,
+        ),
+        need_to_be_pickled_literal=_need_to_be_pickled + (bytes,),
+    ):
+        if "?" in operation:
+            return operation
+        m = literal_insert_values(operation)
+        if m:
+            prefix, values = m.groups()
+            safe_globals = {
+                "__builtins__": {
+                    "parse_datish": self.__parse_dateish,
+                    "true": True,
+                    "false": False,
+                }
+            }
+
+            values = bq_dateish.sub(r"parse_datish('\1', \2)", values)
+            values = eval(values[:-1] + ",)", safe_globals)
+            values = ",".join(
+                map(
+                    repr,
+                    (
+                        (
+                            base64.b16encode(pickle.dumps(v, 4)).decode()
+                            if isinstance(v, need_to_be_pickled_literal)
+                            else v
+                        )
+                        for v in values
+                    ),
+                )
+            )
+            return f"{prefix}({values})"
+        else:
+            return operation
+
+    def __handle_unnest(
+        self, operation, unnest=re.compile(r"UNNEST\(\[ ([^\]]+)? \]\)", re.IGNORECASE),
+    ):
+        return unnest.sub(r"(\1)", operation)
+
+    def __handle_true_false(self, operation):
+        # Older sqlite versions, like those used on the CI servers,
+        # don't support true and false (as aliases for 1 and 0).
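+        # For example (illustrative): "where b = true" becomes "where b = 1".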
+ return operation.replace(" true", " 1").replace(" false", " 0") + + def execute(self, operation, parameters=()): + self.connection.test_data["execute"].append((operation, parameters)) + operation, types_ = google.cloud.bigquery.dbapi.cursor._extract_types(operation) + if parameters: + operation, parameters = self.__convert_params(operation, parameters) + else: + operation = operation.replace("%%", "%") + + operation = self.__handle_comments(operation) + operation = self.__handle_array_types(operation) + operation = self.__handle_problematic_literal_inserts(operation) + operation = self.__handle_unnest(operation) + operation = self.__handle_true_false(operation) + + if operation: + try: + self.cursor.execute(operation, parameters) + except sqlite3.OperationalError as e: # pragma: NO COVER + # Help diagnose errors that shouldn't happen. + # When they do, it's likely due to sqlite versions (environment). + raise sqlite3.OperationalError( + *((operation,) + e.args + (sqlite3.sqlite_version,)) + ) + self.description = self.cursor.description self.rowcount = self.cursor.rowcount @@ -54,45 +285,150 @@ def executemany(self, operation, parameters_list): def close(self): self.cursor.close() - def fetchone(self): - return self.cursor.fetchone() + def _fix_pickled(self, row): + if row is None: + return row - def fetchmany(self, size=None): - self.cursor.fetchmany(size or self.arraysize) + return [ + ( + pickle.loads(v.encode("latin1")) + # \x80\x04 is latin-1 encoded prefix for Pickle protocol 4. + if isinstance(v, str) and v[:2] == "\x80\x04" and v[-1] == "." + else pickle.loads(base64.b16decode(v)) + # 8004 is base64 encoded prefix for Pickle protocol 4. + if isinstance(v, str) and v[:4] == "8004" and v[-2:] == "2E" + else v + ) + for d, v in zip(self.description, row) + ] + + def fetchone(self): + return self._fix_pickled(self.cursor.fetchone()) def fetchall(self): - return self.cursor.fetchall() + return map(self._fix_pickled, self.cursor) - def setinputsizes(self, sizes): - pass - def setoutputsize(self, size, column=None): - pass +class attrdict(dict): + def __setattr__(self, name, val): + self[name] = val + + def __getattr__(self, name): + if name not in self: + self[name] = attrdict() + return self[name] class FauxClient: - def __init__(self, client, connection): - self._client = client - self.project = client.project - self.connection = connection + def __init__(self, project_id=None, default_query_job_config=None, *args, **kw): + + if project_id is None: + if default_query_job_config is not None: + project_id = default_query_job_config.default_dataset.project + else: + project_id = "authproj" # we would still have gotten it from auth. + + self.project = project_id + self.tables = attrdict() + + @staticmethod + def _row_dict(row, cursor): + result = {d[0]: value for d, value in zip(cursor.description, row)} + return result + + def _get_field( + self, + type, + name=None, + notnull=None, + mode=None, + description=None, + fields=(), + columns=None, # Custom column data provided by tests. + **_, # Ignore sqlite PRAGMA data we don't care about. 
+ ): + if columns: + custom = columns.get(name) + if custom: + return self._get_field(name=name, type=type, notnull=notnull, **custom) + + if not mode: + mode = "REQUIRED" if notnull else "NULLABLE" + + field = google.cloud.bigquery.schema.SchemaField( + name=name, + field_type=type, + mode=mode, + description=description, + fields=tuple(self._get_field(**f) for f in fields), + ) + + return field + + def __get_comments(self, cursor, table_name): + cursor.execute( + f"select key, comment" + f" from comments where key like {repr(table_name + '%')}" + ) + + return {key.split(",")[1]: comment for key, comment in cursor} def get_table(self, table_ref): + table_ref = google.cloud.bigquery.table._table_arg_to_table_ref( + table_ref, self.project + ) table_name = table_ref.table_id - with contextlib.closing(self.connection.cursor()) as cursor: - cursor.execute( - f"select name from sqlite_master" - f" where type='table' and name='{table_name}'" - ) - if list(cursor): - cursor.execute("PRAGMA table_info('{table_name}')") + with contextlib.closing(self.connection.connection.cursor()) as cursor: + cursor.execute(f"select * from sqlite_master where name='{table_name}'") + rows = list(cursor) + if rows: + table_data = self._row_dict(rows[0], cursor) + + comments = self.__get_comments(cursor, table_name) + table_comment = comments.pop("", None) + columns = getattr(self.tables, table_name).columns + for col, comment in comments.items(): + getattr(columns, col).description = comment + + cursor.execute(f"PRAGMA table_info('{table_name}')") schema = [ - google.cloud.bigquery.schema.SchemaField( - name=name, - field_type=type_, - mode="REQUIRED" if notnull else "NULLABLE", - ) - for cid, name, type_, notnull, dflt_value, pk in cursor + self._get_field(columns=columns, **self._row_dict(row, cursor)) + for row in cursor ] - return google.cloud.bigquery.table.Table(table_ref, schema) + table = google.cloud.bigquery.table.Table(table_ref, schema) + table.description = table_comment + if table_data["type"] == "view" and table_data["sql"]: + table.view_query = table_data["sql"][ + table_data["sql"].lower().index("select") : + ] + + for aname, value in self.tables.get(table_name, {}).items(): + setattr(table, aname, value) + + return table else: raise google.api_core.exceptions.NotFound(table_ref) + + def list_datasets(self): + return [ + google.cloud.bigquery.Dataset("myproject.mydataset"), + google.cloud.bigquery.Dataset("myproject.yourdataset"), + ] + + def list_tables(self, dataset): + with contextlib.closing(self.connection.connection.cursor()) as cursor: + cursor.execute("select * from sqlite_master") + return [ + google.cloud.bigquery.table.TableListItem( + dict( + tableReference=dict( + projectId=dataset.project, + datasetId=dataset.dataset_id, + tableId=row["name"], + ), + type=row["type"].upper(), + ) + ) + for row in (self._row_dict(row, cursor) for row in cursor) + if row["name"] != "comments" + ] diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py new file mode 100644 index 00000000..61190e7f --- /dev/null +++ b/tests/unit/test_api.py @@ -0,0 +1,37 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished 
to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import mock
+
+
+def test_dry_run():
+
+    with mock.patch("pybigquery._helpers.create_bigquery_client") as create_client:
+        import pybigquery.api
+
+        client = pybigquery.api.ApiClient("/my/creds", "mars")
+        create_client.assert_called_once_with(
+            credentials_path="/my/creds", location="mars"
+        )
+        client.dry_run_query("select 42")
+        [(name, args, kwargs)] = create_client.return_value.query.mock_calls
+        job_config = kwargs.pop("job_config")
+        assert (name, args, kwargs) == ("", (), {"query": "select 42"})
+        assert job_config.dry_run
+        assert not job_config.use_query_cache
diff --git a/tests/unit/test_catalog_functions.py b/tests/unit/test_catalog_functions.py
new file mode 100644
index 00000000..0bbfad75
--- /dev/null
+++ b/tests/unit/test_catalog_functions.py
@@ -0,0 +1,259 @@
+# Copyright (c) 2021 The PyBigQuery Authors
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+import pytest
+import sqlalchemy.types
+
+
+@pytest.mark.parametrize(
+    "table,schema,expect",
+    [
+        ("p.s.t", None, "p.s.t"),
+        ("p.s.t", "p.s", "p.s.t"),
+        # Why is a single schema name a project name when a table
+        # dataset id is given? I guess to provide a missing default.
+ ("p.s.t", "p", "p.s.t"), + ("s.t", "p", "p.s.t"), + ("s.t", "p.s", "p.s.t"), + ("s.t", None, "myproject.s.t"), + ("t", None, "myproject.mydataset.t"), + ("t", "s", "myproject.s.t"), + ("t", "q.s", "q.s.t"), + ], +) +def test__table_reference(faux_conn, table, schema, expect): + assert ( + str( + faux_conn.dialect._table_reference( + schema, table, faux_conn.connection._client.project + ) + ) + == expect + ) + + +@pytest.mark.parametrize( + "table,table_project,schema,schema_project", + [("p.s.t", "p", "q.s", "q"), ("p.s.t", "p", "q", "q")], +) +def test__table_reference_inconsistent_project( + faux_conn, table, table_project, schema, schema_project +): + with pytest.raises( + ValueError, + match=( + f"project_id specified in schema and table_name disagree: " + f"got {schema_project} in schema, and {table_project} in table_name" + ), + ): + faux_conn.dialect._table_reference( + schema, table, faux_conn.connection._client.project + ) + + +@pytest.mark.parametrize( + "table,table_dataset,schema,schema_dataset", + [("s.t", "s", "p.q", "q"), ("p.s.t", "s", "p.q", "q")], +) +def test__table_reference_inconsistent_dataset_id( + faux_conn, table, table_dataset, schema, schema_dataset +): + with pytest.raises( + ValueError, + match=( + f"dataset_id specified in schema and table_name disagree: " + f"got {schema_dataset} in schema, and {table_dataset} in table_name" + ), + ): + faux_conn.dialect._table_reference( + schema, table, faux_conn.connection._client.project + ) + + +@pytest.mark.parametrize("type_", ["view", "table"]) +def test_get_table_names(faux_conn, type_): + cursor = faux_conn.connection.cursor() + cursor.execute("create view view1 as select 1") + cursor.execute("create view view2 as select 2") + cursor.execute("create table table1 (x INT64)") + cursor.execute("create table table2 (x INT64)") + assert sorted(getattr(faux_conn.dialect, f"get_{type_}_names")(faux_conn)) == [ + f"{type_}{d}" for d in "12" + ] + + # once more with engine: + assert sorted( + getattr(faux_conn.dialect, f"get_{type_}_names")(faux_conn.engine) + ) == [f"{type_}{d}" for d in "12"] + + +def test_get_schema_names(faux_conn): + assert list(faux_conn.dialect.get_schema_names(faux_conn)) == [ + "mydataset", + "yourdataset", + ] + # once more with engine: + assert list(faux_conn.dialect.get_schema_names(faux_conn.engine)) == [ + "mydataset", + "yourdataset", + ] + + +def test_get_indexes(faux_conn): + from google.cloud.bigquery.table import TimePartitioning + + cursor = faux_conn.connection.cursor() + cursor.execute("create table foo (x INT64)") + assert faux_conn.dialect.get_indexes(faux_conn, "foo") == [] + + client = faux_conn.connection._client + client.tables.foo.time_partitioning = TimePartitioning(field="tm") + client.tables.foo.clustering_fields = ["user_email", "store_code"] + + assert faux_conn.dialect.get_indexes(faux_conn, "foo") == [ + dict(name="partition", column_names=["tm"], unique=False,), + dict( + name="clustering", column_names=["user_email", "store_code"], unique=False, + ), + ] + + +def test_no_table_pk_constraint(faux_conn): + # BigQuery doesn't do that. + assert faux_conn.dialect.get_pk_constraint(faux_conn, "foo") == ( + dict(constrained_columns=[]) + ) + + +def test_no_table_foreign_keys(faux_conn): + # BigQuery doesn't do that. 
+ assert faux_conn.dialect.get_foreign_keys(faux_conn, "foo") == [] + + +def test_get_table_comment(faux_conn): + cursor = faux_conn.connection.cursor() + cursor.execute("create table foo (x INT64)") + assert faux_conn.dialect.get_table_comment(faux_conn, "foo") == (dict(text=None)) + + client = faux_conn.connection._client + client.tables.foo.description = "special table" + assert faux_conn.dialect.get_table_comment(faux_conn, "foo") == ( + dict(text="special table") + ) + + +@pytest.mark.parametrize( + "btype,atype", + [ + ("STRING", sqlalchemy.types.String), + ("BYTES", sqlalchemy.types.BINARY), + ("INT64", sqlalchemy.types.Integer), + ("FLOAT64", sqlalchemy.types.Float), + ("NUMERIC", sqlalchemy.types.DECIMAL), + ("BIGNUMERIC", sqlalchemy.types.DECIMAL), + ("BOOL", sqlalchemy.types.Boolean), + ("TIMESTAMP", sqlalchemy.types.TIMESTAMP), + ("DATE", sqlalchemy.types.DATE), + ("TIME", sqlalchemy.types.TIME), + ("DATETIME", sqlalchemy.types.DATETIME), + ("THURSDAY", sqlalchemy.types.NullType), + ], +) +def test_get_table_columns(faux_conn, btype, atype): + cursor = faux_conn.connection.cursor() + cursor.execute(f"create table foo (x {btype})") + + assert faux_conn.dialect.get_columns(faux_conn, "foo") == [ + { + "comment": None, + "default": None, + "name": "x", + "nullable": True, + "type": atype, + } + ] + + +def test_get_table_columns_special_cases(faux_conn): + cursor = faux_conn.connection.cursor() + cursor.execute("create table foo (s STRING, n INT64 not null, r RECORD)") + client = faux_conn.connection._client + client.tables.foo.columns.s.description = "a fine column" + client.tables.foo.columns.s.mode = "REPEATED" + client.tables.foo.columns.r.fields = ( + dict(name="i", type="INT64"), + dict(name="f", type="FLOAT64"), + ) + + actual = faux_conn.dialect.get_columns(faux_conn, "foo") + stype = actual[0].pop("type") + assert isinstance(stype, sqlalchemy.types.ARRAY) + assert isinstance(stype.item_type, sqlalchemy.types.String) + assert actual == [ + {"comment": "a fine column", "default": None, "name": "s", "nullable": True}, + { + "comment": None, + "default": None, + "name": "n", + "nullable": False, + "type": sqlalchemy.types.Integer, + }, + { + "comment": None, + "default": None, + "name": "r", + "nullable": True, + "type": sqlalchemy.types.JSON, + }, + { + "comment": None, + "default": None, + "name": "r.i", + "nullable": True, + "type": sqlalchemy.types.Integer, + }, + { + "comment": None, + "default": None, + "name": "r.f", + "nullable": True, + "type": sqlalchemy.types.Float, + }, + ] + + +def test_has_table(faux_conn): + cursor = faux_conn.connection.cursor() + assert not faux_conn.dialect.has_table(faux_conn, "foo") + cursor.execute("create table foo (s STRING)") + assert faux_conn.dialect.has_table(faux_conn, "foo") + # once more with engine: + assert faux_conn.dialect.has_table(faux_conn.engine, "foo") + + +def test_bad_schema_argument(faux_conn): + # with goofy schema name, to exercise some error handling + with pytest.raises(ValueError, match=r"Did not understand schema: a\.b\.c"): + faux_conn.dialect.has_table(faux_conn.engine, "foo", "a.b.c") + + +def test_bad_table_argument(faux_conn): + # with goofy table name, to exercise some error handling + with pytest.raises(ValueError, match=r"Did not understand table_name: a\.b\.c\.d"): + faux_conn.dialect.has_table(faux_conn.engine, "a.b.c.d") diff --git a/tests/unit/test_comments.py b/tests/unit/test_comments.py new file mode 100644 index 00000000..280ce989 --- /dev/null +++ b/tests/unit/test_comments.py @@ -0,0 +1,100 @@ 
+# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import sqlalchemy + +from conftest import setup_table + + +def test_inline_comments(faux_conn): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer, comment="identifier"), + comment="a fine table", + ) + + dialect = faux_conn.dialect + assert dialect.get_table_comment(faux_conn, "some_table") == { + "text": "a fine table" + } + assert dialect.get_columns(faux_conn, "some_table")[0]["comment"] == "identifier" + + +def test_set_drop_table_comment(faux_conn): + table = setup_table( + faux_conn, "some_table", sqlalchemy.Column("id", sqlalchemy.Integer), + ) + + dialect = faux_conn.dialect + assert dialect.get_table_comment(faux_conn, "some_table") == {"text": None} + + table.comment = "a fine table" + faux_conn.execute(sqlalchemy.schema.SetTableComment(table)) + assert dialect.get_table_comment(faux_conn, "some_table") == { + "text": "a fine table" + } + + faux_conn.execute(sqlalchemy.schema.DropTableComment(table)) + assert dialect.get_table_comment(faux_conn, "some_table") == {"text": None} + + +def test_table_description_dialect_option(faux_conn): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + bigquery_description="a fine table", + ) + dialect = faux_conn.dialect + assert dialect.get_table_comment(faux_conn, "some_table") == { + "text": "a fine table" + } + + +def test_table_friendly_name_dialect_option(faux_conn): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + bigquery_friendly_name="bob", + ) + + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64 )" " OPTIONS(friendly_name='bob')" + ) + + +def test_table_friendly_name_description_dialect_option(faux_conn): + setup_table( + faux_conn, + "some_table", + sqlalchemy.Column("id", sqlalchemy.Integer), + bigquery_friendly_name="bob", + bigquery_description="a fine table", + ) + + dialect = faux_conn.dialect + assert dialect.get_table_comment(faux_conn, "some_table") == { + "text": "a fine table" + } + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table` ( `id` INT64 )" + " OPTIONS(description='a fine table', friendly_name='bob')" + ) diff --git a/tests/unit/test_compiler.py b/tests/unit/test_compiler.py new file mode 100644 index 00000000..f4114022 --- /dev/null +++ b/tests/unit/test_compiler.py @@ -0,0 
+1,55 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import pytest +import sqlalchemy.exc + +from conftest import setup_table + + +def test_constraints_are_ignored(faux_conn, metadata): + sqlalchemy.Table( + "ref", metadata, sqlalchemy.Column("id", sqlalchemy.Integer), + ) + sqlalchemy.Table( + "some_table", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column( + "ref_id", sqlalchemy.Integer, sqlalchemy.ForeignKey("ref.id") + ), + sqlalchemy.UniqueConstraint("id", "ref_id", name="uix_1"), + ) + metadata.create_all(faux_conn.engine) + assert " ".join(faux_conn.test_data["execute"][-1][0].strip().split()) == ( + "CREATE TABLE `some_table`" " ( `id` INT64 NOT NULL, `ref_id` INT64 )" + ) + + +def test_compile_column(faux_conn): + table = setup_table(faux_conn, "t", sqlalchemy.Column("c", sqlalchemy.Integer)) + assert table.c.c.compile(faux_conn).string == "`c`" + + +def test_cant_compile_unnamed_column(faux_conn, metadata): + with pytest.raises( + sqlalchemy.exc.CompileError, + match="Cannot compile Column object until its 'name' is assigned.", + ): + sqlalchemy.Column(sqlalchemy.Integer).compile(faux_conn) diff --git a/tests/unit/test_compliance.py b/tests/unit/test_compliance.py new file mode 100644 index 00000000..da2390f6 --- /dev/null +++ b/tests/unit/test_compliance.py @@ -0,0 +1,189 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
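The compliance-port tests below rely on a `setup_table` helper imported from the test suite's conftest, which this diff doesn't show. The following is a minimal sketch of what such a helper plausibly does, inferred only from how it is called below (`connection`, a table name, columns, an optional `initial_data` list, and table kwargs); the body is an assumption, not the repo's actual conftest:

    import sqlalchemy


    def setup_table(connection, name, *columns, initial_data=None, **kw):
        # Sketch only: bind the table to a fresh MetaData, create it on the
        # connection's engine, and optionally insert the initial rows.
        metadata = sqlalchemy.MetaData()
        table = sqlalchemy.Table(name, metadata, *columns, **kw)
        metadata.create_all(connection.engine)
        if initial_data:
            connection.execute(table.insert(), initial_data)
        return table

Extra keyword arguments (e.g. `comment=` or `bigquery_description=`) would pass straight through to `sqlalchemy.Table`, which is how the comment and dialect-option tests exercise them.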
+""" +Ported compliance tests. + +Mainly to get better unit test coverage. +""" + +import pytest +import sqlalchemy +from sqlalchemy import Column, Integer, literal_column, select, String, Table, union +from sqlalchemy.testing.assertions import eq_, in_ + +from conftest import setup_table, sqlalchemy_1_3_or_higher + + +def assert_result(connection, sel, expected): + eq_(connection.execute(sel).fetchall(), expected) + + +def some_table(connection): + return setup_table( + connection, + "some_table", + Column("id", Integer), + Column("x", Integer), + Column("y", Integer), + initial_data=[ + {"id": 1, "x": 1, "y": 2}, + {"id": 2, "x": 2, "y": 3}, + {"id": 3, "x": 3, "y": 4}, + {"id": 4, "x": 4, "y": 5}, + ], + ) + + +def test_distinct_selectable_in_unions(faux_conn): + table = some_table(faux_conn) + s1 = select([table]).where(table.c.id == 2).distinct() + s2 = select([table]).where(table.c.id == 3).distinct() + + u1 = union(s1, s2).limit(2) + assert_result(faux_conn, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) + + +def test_limit_offset_aliased_selectable_in_unions(faux_conn): + table = some_table(faux_conn) + s1 = ( + select([table]) + .where(table.c.id == 2) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + s2 = ( + select([table]) + .where(table.c.id == 3) + .limit(1) + .order_by(table.c.id) + .alias() + .select() + ) + + u1 = union(s1, s2).limit(2) + assert_result(faux_conn, u1.order_by(u1.c.id), [(2, 2, 3), (3, 3, 4)]) + + +def test_percent_sign_round_trip(faux_conn, metadata): + """test that the DBAPI accommodates for escaped / nonescaped + percent signs in a way that matches the compiler + + """ + t = Table("t", metadata, Column("data", String(50))) + t.create(faux_conn.engine) + faux_conn.execute(t.insert(), dict(data="some % value")) + faux_conn.execute(t.insert(), dict(data="some %% other value")) + eq_( + faux_conn.scalar( + select([t.c.data]).where(t.c.data == literal_column("'some % value'")) + ), + "some % value", + ) + + eq_( + faux_conn.scalar( + select([t.c.data]).where( + t.c.data == literal_column("'some %% other value'") + ) + ), + "some %% other value", + ) + + +@sqlalchemy_1_3_or_higher +def test_null_in_empty_set_is_false(faux_conn): + stmt = select( + [ + sqlalchemy.case( + [ + ( + sqlalchemy.null().in_( + sqlalchemy.bindparam("foo", value=(), expanding=True) + ), + sqlalchemy.true(), + ) + ], + else_=sqlalchemy.false(), + ) + ] + ) + in_(faux_conn.execute(stmt).fetchone()[0], (False, 0)) + + +@pytest.mark.parametrize( + "meth,arg,expected", + [ + ("contains", "b%cde", {1, 2, 3, 4, 5, 6, 7, 8, 9}), + ("startswith", "ab%c", {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + ("endswith", "e%fg", {1, 2, 3, 4, 5, 6, 7, 8, 9}), + ], +) +def test_likish(faux_conn, meth, arg, expected): + # See sqlalchemy.testing.suite.test_select.LikeFunctionsTest + table = setup_table( + faux_conn, + "t", + Column("id", Integer, primary_key=True), + Column("data", String(50)), + initial_data=[ + {"id": 1, "data": "abcdefg"}, + {"id": 2, "data": "ab/cdefg"}, + {"id": 3, "data": "ab%cdefg"}, + {"id": 4, "data": "ab_cdefg"}, + {"id": 5, "data": "abcde/fg"}, + {"id": 6, "data": "abcde%fg"}, + {"id": 7, "data": "ab#cdefg"}, + {"id": 8, "data": "ab9cdefg"}, + {"id": 9, "data": "abcde#fg"}, + {"id": 10, "data": "abcd9fg"}, + ], + ) + expr = getattr(table.c.data, meth)(arg) + rows = {value for value, in faux_conn.execute(select([table.c.id]).where(expr))} + eq_(rows, expected) + + all = {i for i in range(1, 11)} + expr = sqlalchemy.not_(expr) + rows = {value for value, in 
faux_conn.execute(select([table.c.id]).where(expr))} + eq_(rows, all - expected) + + +def test_group_by_composed(faux_conn): + table = setup_table( + faux_conn, + "t", + Column("id", Integer, primary_key=True), + Column("x", Integer), + Column("y", Integer), + Column("q", String(50)), + Column("p", String(50)), + initial_data=[ + {"id": 1, "x": 1, "y": 2, "q": "q1", "p": "p3"}, + {"id": 2, "x": 2, "y": 3, "q": "q2", "p": "p2"}, + {"id": 3, "x": 3, "y": 4, "q": "q3", "p": "p1"}, + ], + ) + + expr = (table.c.x + table.c.y).label("lx") + stmt = ( + select([sqlalchemy.func.count(table.c.id), expr]).group_by(expr).order_by(expr) + ) + assert_result(faux_conn, stmt, [(1, 3), (1, 5), (1, 7)]) diff --git a/tests/unit/test_engine.py b/tests/unit/test_engine.py new file mode 100644 index 00000000..ad34ca08 --- /dev/null +++ b/tests/unit/test_engine.py @@ -0,0 +1,54 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
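For context on the engine tests that follow: `arraysize` is accepted as a `create_engine()` keyword, and per these tests the dialect is expected to forward it to each DB-API cursor, while a falsy value (0 or `None`) leaves the cursor default alone. A hedged usage sketch (the project/dataset names are placeholders):

    import sqlalchemy

    # A truthy arraysize should be set on every cursor the connection creates;
    # 0 or None should leave the DB-API default untouched.
    engine = sqlalchemy.create_engine(
        "bigquery://some-project/some-dataset", arraysize=500
    )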
+
+import pytest
+import sqlalchemy
+
+
+def test_engine_dataset_but_no_project(faux_conn):
+    engine = sqlalchemy.create_engine("bigquery:///foo")
+    conn = engine.connect()
+    assert conn.connection._client.project == "authproj"
+
+
+def test_engine_no_dataset_no_project(faux_conn):
+    engine = sqlalchemy.create_engine("bigquery://")
+    conn = engine.connect()
+    assert conn.connection._client.project == "authproj"
+
+
+@pytest.mark.parametrize("arraysize", [0, None])
+def test_set_arraysize_not_set_if_false(faux_conn, metadata, arraysize):
+    engine = sqlalchemy.create_engine("bigquery://", arraysize=arraysize)
+    sqlalchemy.Table("t", metadata, sqlalchemy.Column("c", sqlalchemy.Integer))
+    conn = engine.connect()
+    metadata.create_all(engine)
+
+    # Because we gave a falsy arraysize, the arraysize wasn't set on the cursor:
+    assert "arraysize" not in conn.connection.test_data
+
+
+def test_set_arraysize(faux_conn, metadata):
+    engine = sqlalchemy.create_engine("bigquery://", arraysize=42)
+    sqlalchemy.Table("t", metadata, sqlalchemy.Column("c", sqlalchemy.Integer))
+    conn = engine.connect()
+    metadata.create_all(engine)
+
+    # Because we gave a real arraysize, it was set on the cursor:
+    assert conn.connection.test_data["arraysize"] == 42
diff --git a/tests/unit/test_like_reescape.py b/tests/unit/test_like_reescape.py
new file mode 100644
index 00000000..2c9bf304
--- /dev/null
+++ b/tests/unit/test_like_reescape.py
@@ -0,0 +1,43 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""SQLAlchemy and BigQuery escape % and _ differently in like expressions.
+
+We need to correct for the autoescape option in various string
+functions.
+""" + +import sqlalchemy.sql.operators +import sqlalchemy.sql.schema +import pybigquery.sqlalchemy_bigquery + + +def _check(raw, escaped, escape=None, autoescape=True): + + col = sqlalchemy.sql.schema.Column() + op = col.contains(raw, escape=escape, autoescape=autoescape) + o2 = pybigquery.sqlalchemy_bigquery.BigQueryCompiler._maybe_reescape(op) + assert o2.left.__dict__ == op.left.__dict__ + assert not o2.modifiers.get("escape") + + assert o2.right.value == escaped + + +def test_like_autoescape_reescape(): + + _check("ab%cd", "ab\\%cd") + _check("ab%c_d", "ab\\%c\\_d") + _check("ab%cd", "ab%cd", autoescape=False) + _check("ab%c_d", "ab\\%c\\_d", escape="\\") + _check("ab/%c/_/d", "ab/\\%c/\\_/d") diff --git a/tests/unit/test_parse_url.py b/tests/unit/test_parse_url.py index f50e8676..bf9f8855 100644 --- a/tests/unit/test_parse_url.py +++ b/tests/unit/test_parse_url.py @@ -83,38 +83,51 @@ def test_basic(url_with_everything): @pytest.mark.parametrize( - "param, value", + "param, value, default", [ - ("clustering_fields", ["a", "b", "c"]), - ("create_disposition", "CREATE_IF_NEEDED"), + ("clustering_fields", ["a", "b", "c"], None), + ("create_disposition", "CREATE_IF_NEEDED", None), ( "destination", TableReference( DatasetReference("different-project", "different-dataset"), "table" ), + None, ), ( "destination_encryption_configuration", lambda enc: enc.kms_key_name == EncryptionConfiguration("some-configuration").kms_key_name, + None, + ), + ("dry_run", True, None), + ("labels", {"a": "b", "c": "d"}, {}), + ("maximum_bytes_billed", 1000, None), + ("priority", "INTERACTIVE", None), + ( + "schema_update_options", + ["ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"], + None, ), - ("dry_run", True), - ("labels", {"a": "b", "c": "d"}), - ("maximum_bytes_billed", 1000), - ("priority", "INTERACTIVE"), - ("schema_update_options", ["ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"]), - ("use_query_cache", True), - ("write_disposition", "WRITE_APPEND"), + ("use_query_cache", True, None), + ("write_disposition", "WRITE_APPEND", None), ], ) -def test_all_values(url_with_everything, param, value): - job_config = parse_url(url_with_everything)[5] +def test_all_values(url_with_everything, param, value, default): + url_with_this_one = make_url("bigquery://some-project/some-dataset") + url_with_this_one.query[param] = url_with_everything.query[param] + + for url in url_with_everything, url_with_this_one: + job_config = parse_url(url)[5] + config_value = getattr(job_config, param) + if callable(value): + assert value(config_value) + else: + assert config_value == value - config_value = getattr(job_config, param) - if callable(value): - assert value(config_value) - else: - assert config_value == value + url_with_nothing = make_url("bigquery://some-project/some-dataset") + job_config = parse_url(url_with_nothing)[5] + assert getattr(job_config, param) == default @pytest.mark.parametrize( @@ -209,3 +222,16 @@ def test_not_implemented(not_implemented_arg): ) with pytest.raises(NotImplementedError): parse_url(url) + + +def test_parse_boolean(): + from pybigquery.parse_url import parse_boolean + + assert parse_boolean("true") + assert parse_boolean("True") + assert parse_boolean("TRUE") + assert not parse_boolean("false") + assert not parse_boolean("False") + assert not parse_boolean("FALSE") + with pytest.raises(ValueError): + parse_boolean("Thursday") diff --git a/tests/unit/test_select.py b/tests/unit/test_select.py index f1c9cb09..9cfb5b8b 100644 --- a/tests/unit/test_select.py +++ b/tests/unit/test_select.py 
@@ -1,11 +1,303 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +import datetime +from decimal import Decimal + +import pytest import sqlalchemy +import pybigquery.sqlalchemy_bigquery + +from conftest import setup_table, sqlalchemy_1_3_or_higher + def test_labels_not_forced(faux_conn): - metadata = sqlalchemy.MetaData() - table = sqlalchemy.Table( - "some_table", metadata, sqlalchemy.Column("id", sqlalchemy.Integer) - ) - metadata.create_all(faux_conn.engine) + table = setup_table(faux_conn, "t", sqlalchemy.Column("id", sqlalchemy.Integer)) result = faux_conn.execute(sqlalchemy.select([table.c.id])) assert result.keys() == ["id"] # Look! Just the column name! + + +def dtrepr(v): + return f"{v.__class__.__name__.upper()} {repr(str(v))}" + + +@pytest.mark.parametrize( + "type_,val,btype,vrep", + [ + (sqlalchemy.String, "myString", "STRING", repr), + (sqlalchemy.Text, "myText", "STRING", repr), + (sqlalchemy.Unicode, "myUnicode", "STRING", repr), + (sqlalchemy.UnicodeText, "myUnicodeText", "STRING", repr), + (sqlalchemy.Integer, 424242, "INT64", repr), + (sqlalchemy.SmallInteger, 42, "INT64", repr), + (sqlalchemy.BigInteger, 1 << 60, "INT64", repr), + (sqlalchemy.Numeric, Decimal(42), "NUMERIC", str), + (sqlalchemy.Float, 4.2, "FLOAT64", repr), + ( + sqlalchemy.DateTime, + datetime.datetime(2021, 2, 3, 4, 5, 6, 123456), + "DATETIME", + dtrepr, + ), + (sqlalchemy.Date, datetime.date(2021, 2, 3), "DATE", dtrepr), + (sqlalchemy.Time, datetime.time(4, 5, 6, 123456), "TIME", dtrepr), + (sqlalchemy.Boolean, True, "BOOL", "true"), + (sqlalchemy.REAL, 1.42, "FLOAT64", repr), + (sqlalchemy.FLOAT, 0.42, "FLOAT64", repr), + (sqlalchemy.NUMERIC, Decimal(4.25), "NUMERIC", str), + (sqlalchemy.NUMERIC(39), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.NUMERIC(30, 10), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.NUMERIC(39, 10), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.DECIMAL, Decimal(0.25), "NUMERIC", str), + (sqlalchemy.DECIMAL(39), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.DECIMAL(30, 10), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.DECIMAL(39, 10), Decimal(4.25), "BIGNUMERIC", str), + (sqlalchemy.INTEGER, 434343, "INT64", repr), + (sqlalchemy.INT, 444444, "INT64", repr), + (sqlalchemy.SMALLINT, 43, "INT64", repr), + (sqlalchemy.BIGINT, 1 << 61, "INT64", repr), + ( + sqlalchemy.TIMESTAMP, + datetime.datetime(2021, 2, 3, 4, 5, 7, 123456), + "TIMESTAMP", + lambda v: f"TIMESTAMP {repr(str(v))}", + ), + ( + sqlalchemy.DATETIME, 
+ datetime.datetime(2021, 2, 3, 4, 5, 8, 123456), + "DATETIME", + dtrepr, + ), + (sqlalchemy.DATE, datetime.date(2021, 2, 4), "DATE", dtrepr), + (sqlalchemy.TIME, datetime.time(4, 5, 7, 123456), "TIME", dtrepr), + (sqlalchemy.TIME, datetime.time(4, 5, 7), "TIME", dtrepr), + (sqlalchemy.TEXT, "myTEXT", "STRING", repr), + (sqlalchemy.VARCHAR, "myVARCHAR", "STRING", repr), + (sqlalchemy.NVARCHAR, "myNVARCHAR", "STRING", repr), + (sqlalchemy.CHAR, "myCHAR", "STRING", repr), + (sqlalchemy.NCHAR, "myNCHAR", "STRING", repr), + (sqlalchemy.BINARY, b"myBINARY", "BYTES", repr), + (sqlalchemy.VARBINARY, b"myVARBINARY", "BYTES", repr), + (sqlalchemy.BOOLEAN, False, "BOOL", "false"), + (sqlalchemy.ARRAY(sqlalchemy.Integer), [1, 2, 3], "ARRAY", repr), + ( + sqlalchemy.ARRAY(sqlalchemy.DATETIME), + [ + datetime.datetime(2021, 2, 3, 4, 5, 6), + datetime.datetime(2021, 2, 3, 4, 5, 7, 123456), + datetime.datetime(2021, 2, 3, 4, 5, 8, 123456), + ], + "ARRAY", + lambda a: "[" + ", ".join(dtrepr(v) for v in a) + "]", + ), + ], +) +def test_typed_parameters(faux_conn, type_, val, btype, vrep): + col_name = "foo" + table = setup_table(faux_conn, "t", sqlalchemy.Column(col_name, type_)) + + assert faux_conn.test_data["execute"].pop()[0].strip() == ( + f"CREATE TABLE `t` (\n" f"\t`{col_name}` {btype}\n" f")" + ) + + faux_conn.execute(table.insert().values(**{col_name: val})) + + if btype.startswith("ARRAY<"): + btype = btype[6:-1] + + assert faux_conn.test_data["execute"][-1] == ( + f"INSERT INTO `t` (`{col_name}`) VALUES (%({col_name}:{btype})s)", + {col_name: val}, + ) + + faux_conn.execute( + table.insert() + .values(**{col_name: sqlalchemy.literal(val, type_)}) + .compile( + dialect=pybigquery.sqlalchemy_bigquery.BigQueryDialect(), + compile_kwargs=dict(literal_binds=True), + ) + ) + + if not isinstance(vrep, str): + vrep = vrep(val) + + assert faux_conn.test_data["execute"][-1] == ( + f"INSERT INTO `t` (`{col_name}`) VALUES ({vrep})", + {}, + ) + + assert list(map(list, faux_conn.execute(sqlalchemy.select([table])))) == [[val]] * 2 + assert faux_conn.test_data["execute"][-1][0] == "SELECT `t`.`foo` \nFROM `t`" + + assert ( + list( + map( + list, + faux_conn.execute(sqlalchemy.select([table.c.foo], use_labels=True)), + ) + ) + == [[val]] * 2 + ) + assert faux_conn.test_data["execute"][-1][0] == ( + "SELECT `t`.`foo` AS `t_foo` \nFROM `t`" + ) + + +def test_select_json(faux_conn, metadata): + table = sqlalchemy.Table("t", metadata, sqlalchemy.Column("x", sqlalchemy.JSON)) + + faux_conn.ex("create table t (x RECORD)") + faux_conn.ex("""insert into t values ('{"y": 1}')""") + + row = list(faux_conn.execute(sqlalchemy.select([table])))[0] + # We expect the raw string, because sqlite3, unlike BigQuery + # doesn't deserialize for us. 
+ assert row.x == '{"y": 1}' + + +def test_select_label_starts_w_digit(faux_conn): + # Make sure label names are legal identifiers + faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(1).label("2foo")])) + assert ( + faux_conn.test_data["execute"][-1][0] == "SELECT %(param_1:INT64)s AS `_2foo`" + ) + + +def test_force_quote(faux_conn): + from sqlalchemy.sql.elements import quoted_name + + table = setup_table( + faux_conn, "t", sqlalchemy.Column(quoted_name("foo", True), sqlalchemy.Integer), + ) + faux_conn.execute(sqlalchemy.select([table])) + assert faux_conn.test_data["execute"][-1][0] == ("SELECT `t`.`foo` \nFROM `t`") + + +def test_disable_quote(faux_conn): + from sqlalchemy.sql.elements import quoted_name + + table = setup_table( + faux_conn, + "t", + sqlalchemy.Column(quoted_name("foo", False), sqlalchemy.Integer), + ) + faux_conn.execute(sqlalchemy.select([table])) + assert faux_conn.test_data["execute"][-1][0] == ("SELECT `t`.foo \nFROM `t`") + + +def test_select_in_lit(faux_conn): + [[isin]] = faux_conn.execute( + sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])]) + ) + assert isin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s IN " + "(%(param_2:INT64)s, %(param_3:INT64)s, %(param_4:INT64)s) AS `anon_1`", + {"param_1": 1, "param_2": 1, "param_3": 2, "param_4": 3}, + ) + + +def test_select_in_param(faux_conn): + [[isin]] = faux_conn.execute( + sqlalchemy.select( + [sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] + ), + dict(q=[1, 2, 3]), + ) + assert isin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s IN UNNEST(" + "[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]" + ") AS `anon_1`", + {"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3}, + ) + + +def test_select_in_param1(faux_conn): + [[isin]] = faux_conn.execute( + sqlalchemy.select( + [sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] + ), + dict(q=[1]), + ) + assert isin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s IN UNNEST(" "[ %(q_1:INT64)s ]" ") AS `anon_1`", + {"param_1": 1, "q_1": 1}, + ) + + +@sqlalchemy_1_3_or_higher +def test_select_in_param_empty(faux_conn): + [[isin]] = faux_conn.execute( + sqlalchemy.select( + [sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))] + ), + dict(q=[]), + ) + assert not isin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s IN UNNEST(" "[ ]" ") AS `anon_1`", + {"param_1": 1}, + ) + + +def test_select_notin_lit(faux_conn): + [[isnotin]] = faux_conn.execute( + sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])]) + ) + assert isnotin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s NOT IN " + "(%(param_2:INT64)s, %(param_3:INT64)s, %(param_4:INT64)s) AS `anon_1`", + {"param_1": 0, "param_2": 1, "param_3": 2, "param_4": 3}, + ) + + +def test_select_notin_param(faux_conn): + [[isnotin]] = faux_conn.execute( + sqlalchemy.select( + [sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))] + ), + dict(q=[1, 2, 3]), + ) + assert not isnotin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s NOT IN UNNEST(" + "[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]" + ") AS `anon_1`", + {"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3}, + ) + + +@sqlalchemy_1_3_or_higher +def test_select_notin_param_empty(faux_conn): + [[isnotin]] = faux_conn.execute( + sqlalchemy.select( + [sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))] + ), + 
dict(q=[]), + ) + assert isnotin + assert faux_conn.test_data["execute"][-1] == ( + "SELECT %(param_1:INT64)s NOT IN UNNEST(" "[ ]" ") AS `anon_1`", + {"param_1": 1}, + ) diff --git a/tests/unit/test_view.py b/tests/unit/test_view.py new file mode 100644 index 00000000..0ea943bc --- /dev/null +++ b/tests/unit/test_view.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021 The PyBigQuery Authors +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +def test_view_definition(faux_conn): + cursor = faux_conn.connection.cursor() + cursor.execute("create view foo as select 1") + + # pass the connection: + assert faux_conn.dialect.get_view_definition(faux_conn, "foo") == "select 1" + + # pass the engine: + assert faux_conn.dialect.get_view_definition(faux_conn.engine, "foo") == "select 1" + + # remove dataset id from dialect: + faux_conn.dialect.dataset_id = None + assert ( + faux_conn.dialect.get_view_definition(faux_conn, "mydataset.foo") == "select 1" + )
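As a closing usage note: outside these fixtures, `get_view_definition` is normally reached through SQLAlchemy's inspection API rather than called on the dialect directly; the Inspector delegates to the dialect method exercised above. A minimal sketch (the connection URL is a placeholder):

    import sqlalchemy

    # Reflect a view's SQL through the standard Inspector interface; with this
    # dialect it delegates to the get_view_definition() tested above.
    engine = sqlalchemy.create_engine("bigquery://some-project/mydataset")
    inspector = sqlalchemy.inspect(engine)
    view_sql = inspector.get_view_definition("foo")  # e.g. "select 1"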