diff --git a/google/cloud/spanner_dbapi/__init__.py b/google/cloud/spanner_dbapi/__init__.py new file mode 100644 index 0000000000..e94ecdc0ed --- /dev/null +++ b/google/cloud/spanner_dbapi/__init__.py @@ -0,0 +1,93 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Connection-based DB API for Cloud Spanner.""" + +from google.cloud.spanner_dbapi.connection import Connection +from google.cloud.spanner_dbapi.connection import connect + +from google.cloud.spanner_dbapi.cursor import Cursor + +from google.cloud.spanner_dbapi.exceptions import DatabaseError +from google.cloud.spanner_dbapi.exceptions import DataError +from google.cloud.spanner_dbapi.exceptions import Error +from google.cloud.spanner_dbapi.exceptions import IntegrityError +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.exceptions import InternalError +from google.cloud.spanner_dbapi.exceptions import NotSupportedError +from google.cloud.spanner_dbapi.exceptions import OperationalError +from google.cloud.spanner_dbapi.exceptions import ProgrammingError +from google.cloud.spanner_dbapi.exceptions import Warning + +from google.cloud.spanner_dbapi.parse_utils import get_param_types + +from google.cloud.spanner_dbapi.types import BINARY +from google.cloud.spanner_dbapi.types import DATETIME +from google.cloud.spanner_dbapi.types import NUMBER +from google.cloud.spanner_dbapi.types import ROWID +from google.cloud.spanner_dbapi.types import STRING +from google.cloud.spanner_dbapi.types import Binary +from google.cloud.spanner_dbapi.types import Date +from google.cloud.spanner_dbapi.types import DateFromTicks +from google.cloud.spanner_dbapi.types import Time +from google.cloud.spanner_dbapi.types import TimeFromTicks +from google.cloud.spanner_dbapi.types import Timestamp +from google.cloud.spanner_dbapi.types import TimestampStr +from google.cloud.spanner_dbapi.types import TimestampFromTicks + +from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT + +apilevel = "2.0" # supports DP-API 2.0 level. +paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s. + +# Threads may share the module, but not connections. This is a paranoid threadsafety +# level, but it is necessary for starters to use when debugging failures. +# Eventually once transactions are working properly, we'll update the +# threadsafety level. +threadsafety = 1 + + +__all__ = [ + "Connection", + "connect", + "Cursor", + "DatabaseError", + "DataError", + "Error", + "IntegrityError", + "InterfaceError", + "InternalError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "Warning", + "DEFAULT_USER_AGENT", + "apilevel", + "paramstyle", + "threadsafety", + "get_param_types", + "Binary", + "Date", + "DateFromTicks", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", + "BINARY", + "STRING", + "NUMBER", + "DATETIME", + "ROWID", + "TimestampStr", +] diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py new file mode 100644 index 0000000000..2fcdd59137 --- /dev/null +++ b/google/cloud/spanner_dbapi/_helpers.py @@ -0,0 +1,158 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.cloud.spanner_dbapi.parse_utils import get_param_types +from google.cloud.spanner_dbapi.parse_utils import parse_insert +from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner +from google.cloud.spanner_v1 import param_types + + +SQL_LIST_TABLES = """ + SELECT + t.table_name + FROM + information_schema.tables AS t + WHERE + t.table_catalog = '' and t.table_schema = '' + """ + +SQL_GET_TABLE_COLUMN_SCHEMA = """SELECT + COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + TABLE_SCHEMA = '' + AND + TABLE_NAME = @table_name + """ + +# This table maps spanner_types to Spanner's data type sizes as per +# https://cloud.google.com/spanner/docs/data-types#allowable-types +# It is used to map `display_size` to a known type for Cursor.description +# after a row fetch. +# Since ResultMetadata +# https://cloud.google.com/spanner/docs/reference/rest/v1/ResultSetMetadata +# does not send back the actual size, we have to lookup the respective size. +# Some fields' sizes are dependent upon the dynamic data hence aren't sent back +# by Cloud Spanner. +code_to_display_size = { + param_types.BOOL.code: 1, + param_types.DATE.code: 4, + param_types.FLOAT64.code: 8, + param_types.INT64.code: 8, + param_types.TIMESTAMP.code: 12, +} + + +def _execute_insert_heterogenous(transaction, sql_params_list): + for sql, params in sql_params_list: + sql, params = sql_pyformat_args_to_spanner(sql, params) + param_types = get_param_types(params) + transaction.execute_update(sql, params=params, param_types=param_types) + + +def _execute_insert_homogenous(transaction, parts): + # Perform an insert in one shot. + table = parts.get("table") + columns = parts.get("columns") + values = parts.get("values") + return transaction.insert(table, columns, values) + + +def handle_insert(connection, sql, params): + parts = parse_insert(sql, params) + + # The split between the two styles exists because: + # in the common case of multiple values being passed + # with simple pyformat arguments, + # SQL: INSERT INTO T (f1, f2) VALUES (%s, %s, %s) + # Params: [(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,)] + # we can take advantage of a single RPC with: + # transaction.insert(table, columns, values) + # instead of invoking: + # with transaction: + # for sql, params in sql_params_list: + # transaction.execute_sql(sql, params, param_types) + # which invokes more RPCs and is more costly. + + if parts.get("homogenous"): + # The common case of multiple values being passed in + # non-complex pyformat args and need to be uploaded in one RPC. + return connection.database.run_in_transaction(_execute_insert_homogenous, parts) + else: + # All the other cases that are esoteric and need + # transaction.execute_sql + sql_params_list = parts.get("sql_params_list") + return connection.database.run_in_transaction( + _execute_insert_heterogenous, sql_params_list + ) + + +class ColumnInfo: + """Row column description object.""" + + def __init__( + self, + name, + type_code, + display_size=None, + internal_size=None, + precision=None, + scale=None, + null_ok=False, + ): + self.name = name + self.type_code = type_code + self.display_size = display_size + self.internal_size = internal_size + self.precision = precision + self.scale = scale + self.null_ok = null_ok + + self.fields = ( + self.name, + self.type_code, + self.display_size, + self.internal_size, + self.precision, + self.scale, + self.null_ok, + ) + + def __repr__(self): + return self.__str__() + + def __getitem__(self, index): + return self.fields[index] + + def __str__(self): + str_repr = ", ".join( + filter( + lambda part: part is not None, + [ + "name='%s'" % self.name, + "type_code=%d" % self.type_code, + "display_size=%d" % self.display_size + if self.display_size + else None, + "internal_size=%d" % self.internal_size + if self.internal_size + else None, + "precision='%s'" % self.precision if self.precision else None, + "scale='%s'" % self.scale if self.scale else None, + "null_ok='%s'" % self.null_ok if self.null_ok else None, + ], + ) + ) + return "ColumnInfo(%s)" % str_repr diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py new file mode 100644 index 0000000000..befc760ea5 --- /dev/null +++ b/google/cloud/spanner_dbapi/connection.py @@ -0,0 +1,264 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DB-API Connection for the Google Cloud Spanner.""" + +import warnings + +from google.api_core.gapic_v1.client_info import ClientInfo +from google.cloud import spanner_v1 as spanner + +from google.cloud.spanner_dbapi.cursor import Cursor +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT +from google.cloud.spanner_dbapi.version import PY_VERSION + + +AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode" + + +class Connection: + """Representation of a DB-API connection to a Cloud Spanner database. + + You most likely don't need to instantiate `Connection` objects + directly, use the `connect` module function instead. + + :type instance: :class:`~google.cloud.spanner_v1.instance.Instance` + :param instance: Cloud Spanner instance to connect to. + + :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: The database to which the connection is linked. + """ + + def __init__(self, instance, database): + self._instance = instance + self._database = database + self._ddl_statements = [] + + self._transaction = None + self._session = None + + self.is_closed = False + self._autocommit = False + + @property + def autocommit(self): + """Autocommit mode flag for this connection. + + :rtype: bool + :returns: Autocommit mode flag value. + """ + return self._autocommit + + @autocommit.setter + def autocommit(self, value): + """Change this connection autocommit mode. Setting this value to True + while a transaction is active will commit the current transaction. + + :type value: bool + :param value: New autocommit mode state. + """ + if value and not self._autocommit: + self.commit() + + self._autocommit = value + + @property + def database(self): + """Database to which this connection relates. + + :rtype: :class:`~google.cloud.spanner_v1.database.Database` + :returns: The related database object. + """ + return self._database + + @property + def instance(self): + """Instance to which this connection relates. + + :rtype: :class:`~google.cloud.spanner_v1.instance.Instance` + :returns: The related instance object. + """ + return self._instance + + def _session_checkout(self): + """Get a Cloud Spanner session from the pool. + + If there is already a session associated with + this connection, it'll be used instead. + + :rtype: :class:`google.cloud.spanner_v1.session.Session` + :returns: Cloud Spanner session object ready to use. + """ + if not self._session: + self._session = self.database._pool.get() + + return self._session + + def _release_session(self): + """Release the currently used Spanner session. + + The session will be returned into the sessions pool. + """ + self.database._pool.put(self._session) + self._session = None + + def transaction_checkout(self): + """Get a Cloud Spanner transaction. + + Begin a new transaction, if there is no transaction in + this connection yet. Return the begun one otherwise. + + The method is non operational in autocommit mode. + + :rtype: :class:`google.cloud.spanner_v1.transaction.Transaction` + :returns: A Cloud Spanner transaction object, ready to use. + """ + if not self.autocommit: + if ( + not self._transaction + or self._transaction.committed + or self._transaction.rolled_back + ): + self._transaction = self._session_checkout().transaction() + self._transaction.begin() + + return self._transaction + + def _raise_if_closed(self): + """Helper to check the connection state before running a query. + Raises an exception if this connection is closed. + + :raises: :class:`InterfaceError`: if this connection is closed. + """ + if self.is_closed: + raise InterfaceError("connection is already closed") + + def close(self): + """Closes this connection. + + The connection will be unusable from this point forward. If the + connection has an active transaction, it will be rolled back. + """ + if ( + self._transaction + and not self._transaction.committed + and not self._transaction.rolled_back + ): + self._transaction.rollback() + + self.is_closed = True + + def commit(self): + """Commits any pending transaction to the database. + + This method is non-operational in autocommit mode. + """ + if self._autocommit: + warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + elif self._transaction: + self._transaction.commit() + self._release_session() + + def rollback(self): + """Rolls back any pending transaction. + + This is a no-op if there is no active transaction or if the connection + is in autocommit mode. + """ + if self._autocommit: + warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + elif self._transaction: + self._transaction.rollback() + self._release_session() + + def cursor(self): + """Factory to create a DB-API Cursor.""" + self._raise_if_closed() + + return Cursor(self) + + def run_prior_DDL_statements(self): + self._raise_if_closed() + + if self._ddl_statements: + ddl_statements = self._ddl_statements + self._ddl_statements = [] + + return self.database.update_ddl(ddl_statements).result() + + def __enter__(self): + return self + + def __exit__(self, etype, value, traceback): + self.commit() + self.close() + + +def connect( + instance_id, database_id, project=None, credentials=None, pool=None, user_agent=None +): + """Creates a connection to a Google Cloud Spanner database. + + :type instance_id: str + :param instance_id: The ID of the instance to connect to. + + :type database_id: str + :param database_id: The ID of the database to connect to. + + :type project: str + :param project: (Optional) The ID of the project which owns the + instances, tables and data. If not provided, will + attempt to determine from the environment. + + :type credentials: :class:`~google.auth.credentials.Credentials` + :param credentials: (Optional) The authorization credentials to attach to + requests. These credentials identify this application + to the service. If none are specified, the client will + attempt to ascertain the credentials from the + environment. + + :type pool: Concrete subclass of + :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`. + :param pool: (Optional). Session pool to be used by database. + + :type user_agent: str + :param user_agent: (Optional) User agent to be used with this connection's + requests. + + :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` + :returns: Connection object associated with the given Google Cloud Spanner + resource. + + :raises: :class:`ValueError` in case of given instance/database + doesn't exist. + """ + + client_info = ClientInfo( + user_agent=user_agent or DEFAULT_USER_AGENT, python_version=PY_VERSION + ) + + client = spanner.Client( + project=project, credentials=credentials, client_info=client_info + ) + + instance = client.instance(instance_id) + if not instance.exists(): + raise ValueError("instance '%s' does not exist." % instance_id) + + database = instance.database(database_id, pool=pool) + if not database.exists(): + raise ValueError("database '%s' does not exist." % database_id) + + return Connection(instance, database) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py new file mode 100644 index 0000000000..ceaccccdf3 --- /dev/null +++ b/google/cloud/spanner_dbapi/cursor.py @@ -0,0 +1,333 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Database cursor for Google Cloud Spanner DB-API.""" + +from google.api_core.exceptions import AlreadyExists +from google.api_core.exceptions import FailedPrecondition +from google.api_core.exceptions import InternalServerError +from google.api_core.exceptions import InvalidArgument + +from collections import namedtuple + +from google.cloud import spanner_v1 as spanner + +from google.cloud.spanner_dbapi.exceptions import IntegrityError +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.exceptions import OperationalError +from google.cloud.spanner_dbapi.exceptions import ProgrammingError + +from google.cloud.spanner_dbapi import _helpers +from google.cloud.spanner_dbapi._helpers import ColumnInfo +from google.cloud.spanner_dbapi._helpers import code_to_display_size + +from google.cloud.spanner_dbapi import parse_utils +from google.cloud.spanner_dbapi.parse_utils import get_param_types +from google.cloud.spanner_dbapi.utils import PeekIterator + +_UNSET_COUNT = -1 + +ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) + + +class Cursor(object): + """Database cursor to manage the context of a fetch operation. + + :type connection: :class:`~google.cloud.spanner_dbapi.connection.Connection` + :param connection: A DB-API connection to Google Cloud Spanner. + """ + + def __init__(self, connection): + self._itr = None + self._result_set = None + self._row_count = _UNSET_COUNT + self.connection = connection + self._is_closed = False + + # the number of rows to fetch at a time with fetchmany() + self.arraysize = 1 + + @property + def is_closed(self): + """The cursor close indicator. + + :rtype: bool + :returns: True if the cursor or the parent connection is closed, + otherwise False. + """ + return self._is_closed or self.connection.is_closed + + @property + def description(self): + """Read-only attribute containing a sequence of the following items: + + - ``name`` + - ``type_code`` + - ``display_size`` + - ``internal_size`` + - ``precision`` + - ``scale`` + - ``null_ok`` + """ + if not (self._result_set and self._result_set.metadata): + return None + + row_type = self._result_set.metadata.row_type + columns = [] + + for field in row_type.fields: + column_info = ColumnInfo( + name=field.name, + type_code=field.type.code, + # Size of the SQL type of the column. + display_size=code_to_display_size.get(field.type.code), + # Client perceived size of the column. + internal_size=field.ByteSize(), + ) + columns.append(column_info) + + return tuple(columns) + + @property + def rowcount(self): + """The number of rows produced by the last `.execute()`.""" + return self._row_count + + def _raise_if_closed(self): + """Raise an exception if this cursor is closed. + + Helper to check this cursor's state before running a + SQL/DDL/DML query. If the parent connection is + already closed it also raises an error. + + :raises: :class:`InterfaceError` if this cursor is closed. + """ + if self.is_closed: + raise InterfaceError("Cursor and/or connection is already closed.") + + def callproc(self, procname, args=None): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + + def close(self): + """Closes this Cursor, making it unusable from this point forward.""" + self._is_closed = True + + def _do_execute_update(self, transaction, sql, params, param_types=None): + sql = parse_utils.ensure_where_clause(sql) + sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) + + result = transaction.execute_update( + sql, params=params, param_types=get_param_types(params) + ) + self._itr = None + if type(result) == int: + self._row_count = result + + return result + + def execute(self, sql, args=None): + """Prepares and executes a Spanner database operation. + + :type sql: str + :param sql: A SQL query statement. + + :type args: list + :param args: Additional parameters to supplement the SQL query. + """ + if not self.connection: + raise ProgrammingError("Cursor is not connected to the database") + + self._raise_if_closed() + + self._result_set = None + + # Classify whether this is a read-only SQL statement. + try: + classification = parse_utils.classify_stmt(sql) + if classification == parse_utils.STMT_DDL: + self.connection._ddl_statements.append(sql) + return + + # For every other operation, we've got to ensure that + # any prior DDL statements were run. + # self._run_prior_DDL_statements() + self.connection.run_prior_DDL_statements() + + if not self.connection.autocommit: + transaction = self.connection.transaction_checkout() + + sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, args) + + self._result_set = transaction.execute_sql( + sql, params, param_types=get_param_types(params) + ) + self._itr = PeekIterator(self._result_set) + return + + if classification == parse_utils.STMT_NON_UPDATING: + self._handle_DQL(sql, args or None) + elif classification == parse_utils.STMT_INSERT: + _helpers.handle_insert(self.connection, sql, args or None) + else: + self.connection.database.run_in_transaction( + self._do_execute_update, sql, args or None + ) + except (AlreadyExists, FailedPrecondition) as e: + raise IntegrityError(e.details if hasattr(e, "details") else e) + except InvalidArgument as e: + raise ProgrammingError(e.details if hasattr(e, "details") else e) + except InternalServerError as e: + raise OperationalError(e.details if hasattr(e, "details") else e) + + def executemany(self, operation, seq_of_params): + """Execute the given SQL with every parameters set + from the given sequence of parameters. + + :type operation: str + :param operation: SQL code to execute. + + :type seq_of_params: list + :param seq_of_params: Sequence of additional parameters to run + the query with. + """ + self._raise_if_closed() + + for params in seq_of_params: + self.execute(operation, params) + + def fetchone(self): + """Fetch the next row of a query result set, returning a single + sequence, or None when no more data is available.""" + self._raise_if_closed() + + try: + return next(self) + except StopIteration: + return None + + def fetchmany(self, size=None): + """Fetch the next set of rows of a query result, returning a sequence + of sequences. An empty sequence is returned when no more rows are available. + + :type size: int + :param size: (Optional) The maximum number of results to fetch. + + :raises InterfaceError: + if the previous call to .execute*() did not produce any result set + or if no call was issued yet. + """ + self._raise_if_closed() + + if size is None: + size = self.arraysize + + items = [] + for i in range(size): + try: + items.append(tuple(self.__next__())) + except StopIteration: + break + + return items + + def fetchall(self): + """Fetch all (remaining) rows of a query result, returning them as + a sequence of sequences. + """ + self._raise_if_closed() + + return list(self.__iter__()) + + def nextset(self): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + + def setinputsizes(self, sizes): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + + def setoutputsize(self, size, column=None): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + + def _handle_DQL(self, sql, params): + with self.connection.database.snapshot() as snapshot: + # Reference + # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql + sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) + res = snapshot.execute_sql( + sql, params=params, param_types=get_param_types(params) + ) + if type(res) == int: + self._row_count = res + self._itr = None + else: + # Immediately using: + # iter(response) + # here, because this Spanner API doesn't provide + # easy mechanisms to detect when only a single item + # is returned or many, yet mixing results that + # are for .fetchone() with those that would result in + # many items returns a RuntimeError if .fetchone() is + # invoked and vice versa. + self._result_set = res + # Read the first element so that the StreamedResultSet can + # return the metadata after a DQL statement. See issue #155. + self._itr = PeekIterator(self._result_set) + # Unfortunately, Spanner doesn't seem to send back + # information about the number of rows available. + self._row_count = _UNSET_COUNT + + def __enter__(self): + return self + + def __exit__(self, etype, value, traceback): + self.close() + + def __next__(self): + if self._itr is None: + raise ProgrammingError("no results to return") + return next(self._itr) + + def __iter__(self): + if self._itr is None: + raise ProgrammingError("no results to return") + return self._itr + + def list_tables(self): + return self.run_sql_in_snapshot(_helpers.SQL_LIST_TABLES) + + def run_sql_in_snapshot(self, sql, params=None, param_types=None): + # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions + # hence this method exists to circumvent that limit. + self.connection.run_prior_DDL_statements() + + with self.connection.database.snapshot() as snapshot: + res = snapshot.execute_sql(sql, params=params, param_types=param_types) + return list(res) + + def get_table_column_schema(self, table_name): + rows = self.run_sql_in_snapshot( + sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA, + params={"table_name": table_name}, + param_types={"table_name": spanner.param_types.STRING}, + ) + + column_details = {} + for column_name, is_nullable, spanner_type in rows: + column_details[column_name] = ColumnDetails( + null_ok=is_nullable == "YES", spanner_type=spanner_type + ) + return column_details diff --git a/google/cloud/spanner_dbapi/exceptions.py b/google/cloud/spanner_dbapi/exceptions.py new file mode 100644 index 0000000000..1a9fdd3625 --- /dev/null +++ b/google/cloud/spanner_dbapi/exceptions.py @@ -0,0 +1,102 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Spanner DB API exceptions.""" + + +class Warning(Exception): + """Important DB API warning.""" + + pass + + +class Error(Exception): + """The base class for all the DB API exceptions. + + Does not include :class:`Warning`. + """ + + pass + + +class InterfaceError(Error): + """ + Error related to the database interface + rather than the database itself. + """ + + pass + + +class DatabaseError(Error): + """Error related to the database.""" + + pass + + +class DataError(DatabaseError): + """ + Error due to problems with the processed data like + division by zero, numeric value out of range, etc. + """ + + pass + + +class OperationalError(DatabaseError): + """ + Error related to the database's operation, e.g. an + unexpected disconnect, the data source name is not + found, a transaction could not be processed, a + memory allocation error, etc. + """ + + pass + + +class IntegrityError(DatabaseError): + """ + Error for cases of relational integrity of the database + is affected, e.g. a foreign key check fails. + """ + + pass + + +class InternalError(DatabaseError): + """ + Internal database error, e.g. the cursor is not valid + anymore, the transaction is out of sync, etc. + """ + + pass + + +class ProgrammingError(DatabaseError): + """ + Programming error, e.g. table not found or already + exists, syntax error in the SQL statement, wrong + number of parameters specified, etc. + """ + + pass + + +class NotSupportedError(DatabaseError): + """ + Error for case of a method or database API not + supported by the database was used. + """ + + pass diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py new file mode 100644 index 0000000000..d88dcafb0d --- /dev/null +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -0,0 +1,546 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"SQL parsing and classification utils." + +import datetime +import decimal +import re +from functools import reduce + +import sqlparse +from google.cloud import spanner_v1 as spanner + +from .exceptions import Error, ProgrammingError +from .parser import parse_values +from .types import DateStr, TimestampStr +from .utils import sanitize_literals_for_upload + +TYPES_MAP = { + bool: spanner.param_types.BOOL, + bytes: spanner.param_types.BYTES, + str: spanner.param_types.STRING, + int: spanner.param_types.INT64, + float: spanner.param_types.FLOAT64, + datetime.datetime: spanner.param_types.TIMESTAMP, + datetime.date: spanner.param_types.DATE, + DateStr: spanner.param_types.DATE, + TimestampStr: spanner.param_types.TIMESTAMP, +} + +SPANNER_RESERVED_KEYWORDS = { + "ALL", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "ASSERT_ROWS_MODIFIED", + "AT", + "BETWEEN", + "BY", + "CASE", + "CAST", + "COLLATE", + "CONTAINS", + "CREATE", + "CROSS", + "CUBE", + "CURRENT", + "DEFAULT", + "DEFINE", + "DESC", + "DISTINCT", + "DROP", + "ELSE", + "END", + "ENUM", + "ESCAPE", + "EXCEPT", + "EXCLUDE", + "EXISTS", + "EXTRACT", + "FALSE", + "FETCH", + "FOLLOWING", + "FOR", + "FROM", + "FULL", + "GROUP", + "GROUPING", + "GROUPS", + "HASH", + "HAVING", + "IF", + "IGNORE", + "IN", + "INNER", + "INTERSECT", + "INTERVAL", + "INTO", + "IS", + "JOIN", + "LATERAL", + "LEFT", + "LIKE", + "LIMIT", + "LOOKUP", + "MERGE", + "NATURAL", + "NEW", + "NO", + "NOT", + "NULL", + "NULLS", + "OF", + "ON", + "OR", + "ORDER", + "OUTER", + "OVER", + "PARTITION", + "PRECEDING", + "PROTO", + "RANGE", + "RECURSIVE", + "RESPECT", + "RIGHT", + "ROLLUP", + "ROWS", + "SELECT", + "SET", + "SOME", + "STRUCT", + "TABLESAMPLE", + "THEN", + "TO", + "TREAT", + "TRUE", + "UNBOUNDED", + "UNION", + "UNNEST", + "USING", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHIN", +} + +STMT_DDL = "DDL" +STMT_NON_UPDATING = "NON_UPDATING" +STMT_UPDATING = "UPDATING" +STMT_INSERT = "INSERT" + +# Heuristic for identifying statements that don't need to be run as updates. +RE_NON_UPDATE = re.compile(r"^\s*(SELECT)", re.IGNORECASE) + +RE_WITH = re.compile(r"^\s*(WITH)", re.IGNORECASE) + +# DDL statements follow +# https://cloud.google.com/spanner/docs/data-definition-language +RE_DDL = re.compile(r"^\s*(CREATE|ALTER|DROP)", re.IGNORECASE | re.DOTALL) + +RE_IS_INSERT = re.compile(r"^\s*(INSERT)", re.IGNORECASE | re.DOTALL) + +RE_INSERT = re.compile( + # Only match the `INSERT INTO (columns...) + # otherwise the rest of the statement could be a complex + # operation. + r"^\s*INSERT INTO (?P[^\s\(\)]+)\s*\((?P[^\(\)]+)\)", + re.IGNORECASE | re.DOTALL, +) + +RE_VALUES_TILL_END = re.compile(r"VALUES\s*\(.+$", re.IGNORECASE | re.DOTALL) + +RE_VALUES_PYFORMAT = re.compile( + # To match: (%s, %s,....%s) + r"(\(\s*%s[^\(\)]+\))", + re.DOTALL, +) + +RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL) + + +def classify_stmt(query): + """Determine SQL query type. + + :type query: :class:`str` + :param query: SQL query. + + :rtype: :class:`str` + :returns: Query type name. + """ + if RE_DDL.match(query): + return STMT_DDL + + if RE_IS_INSERT.match(query): + return STMT_INSERT + + if RE_NON_UPDATE.match(query) or RE_WITH.match(query): + # As of 13-March-2020, Cloud Spanner only supports WITH for DQL + # statements and doesn't yet support WITH for DML statements. + return STMT_NON_UPDATING + + return STMT_UPDATING + + +def parse_insert(insert_sql, params): + """ + Parse an INSERT statement an generate a list of tuples of the form: + [ + (SQL, params_per_row1), + (SQL, params_per_row2), + (SQL, params_per_row3), + ... + ] + + There are 4 variants of an INSERT statement: + a) INSERT INTO (columns...) VALUES (): no params + b) INSERT INTO
(columns...) SELECT_STMT: no params + c) INSERT INTO
(columns...) VALUES (%s,...): with params + d) INSERT INTO
(columns...) VALUES (%s,.....) with params and expressions + + Thus given each of the forms, it will produce a dictionary describing + how to upload the contents to Cloud Spanner: + Case a) + SQL: INSERT INTO T (f1, f2) VALUES (1, 2) + it produces: + { + 'sql_params_list': [ + ('INSERT INTO T (f1, f2) VALUES (1, 2)', None), + ], + } + + Case b) + SQL: 'INSERT INTO T (s, c) SELECT st, zc FROM cus ORDER BY fn, ln', + it produces: + { + 'sql_params_list': [ + ('INSERT INTO T (s, c) SELECT st, zc FROM cus ORDER BY fn, ln', None), + ] + } + + Case c) + SQL: INSERT INTO T (f1, f2) VALUES (%s, %s), (%s, %s) + Params: ['a', 'b', 'c', 'd'] + it produces: + { + 'homogenous': True, + 'table': 'T', + 'columns': ['f1', 'f2'], + 'values': [('a', 'b',), ('c', 'd',)], + } + + Case d) + SQL: INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s)), (UPPER(%s), %s) + Params: ['a', 'b', 'c', 'd'] + it produces: + { + 'sql_params_list': [ + ('INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s))', ('a', 'b',)) + ('INSERT INTO T (f1, f2) VALUES (UPPER(%s), %s)', ('c', 'd',)) + ], + } + """ # noqa + match = RE_INSERT.search(insert_sql) + + if not match: + raise ProgrammingError( + "Could not parse an INSERT statement from %s" % insert_sql + ) + + after_values_sql = RE_VALUES_TILL_END.findall(insert_sql) + if not after_values_sql: + # Case b) + insert_sql = sanitize_literals_for_upload(insert_sql) + return {"sql_params_list": [(insert_sql, None)]} + + if not params: + # Case a) perhaps? + # Check if any %s exists. + + # pyformat_str_count = after_values_sql.count("%s") + # if pyformat_str_count > 0: + # raise ProgrammingError( + # 'no params yet there are %d "%%s" tokens' % pyformat_str_count + # ) + for item in after_values_sql: + if item.count("%s") > 0: + raise ProgrammingError( + 'no params yet there are %d "%%s" tokens' % item.count("%s") + ) + + insert_sql = sanitize_literals_for_upload(insert_sql) + # Confirmed case of: + # SQL: INSERT INTO T (a1, a2) VALUES (1, 2) + # Params: None + return {"sql_params_list": [(insert_sql, None)]} + + values_str = after_values_sql[0] + _, values = parse_values(values_str) + + if values.homogenous(): + # Case c) + + columns = [mi.strip(" `") for mi in match.group("columns").split(",")] + sql_params_list = [] + insert_sql_preamble = "INSERT INTO %s (%s) VALUES %s" % ( + match.group("table_name"), + match.group("columns"), + values.argv[0], + ) + values_pyformat = [str(arg) for arg in values.argv] + rows_list = rows_for_insert_or_update(columns, params, values_pyformat) + insert_sql_preamble = sanitize_literals_for_upload(insert_sql_preamble) + for row in rows_list: + sql_params_list.append((insert_sql_preamble, row)) + + return {"sql_params_list": sql_params_list} + + # Case d) + # insert_sql is of the form: + # INSERT INTO T(c1, c2) VALUES (%s, %s), (%s, LOWER(%s)) + + # Sanity check: + # length(all_args) == len(params) + args_len = reduce(lambda a, b: a + b, [len(arg) for arg in values.argv]) + if args_len != len(params): + raise ProgrammingError( + "Invalid length: VALUES(...) len: %d != len(params): %d" + % (args_len, len(params)) + ) + + trim_index = insert_sql.find(values_str) + before_values_sql = insert_sql[:trim_index] + + sql_param_tuples = [] + for token_arg in values.argv: + row_sql = before_values_sql + " VALUES%s" % token_arg + row_sql = sanitize_literals_for_upload(row_sql) + row_params, params = ( + tuple(params[0 : len(token_arg)]), + params[len(token_arg) :], + ) + sql_param_tuples.append((row_sql, row_params)) + + return {"sql_params_list": sql_param_tuples} + + +def rows_for_insert_or_update(columns, params, pyformat_args=None): + """ + Create a tupled list of params to be used as a single value per + value that inserted from a statement such as + SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s)' + Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)] + Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9] + + We'll have to convert both params types into: + Params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] + """ # noqa + + if not pyformat_args: + # This is the case where we have for example: + # SQL: 'INSERT INTO t (f1, f2, f3)' + # Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)] + # Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9] + # + # We'll have to convert both params types into: + # [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] + contains_all_list_or_tuples = True + for param in params: + if not (isinstance(param, list) or isinstance(param, tuple)): + contains_all_list_or_tuples = False + break + + if contains_all_list_or_tuples: + # The case with Params A: [(1, 2, 3), (4, 5, 6)] + # Ensure that each param's length == len(columns) + columns_len = len(columns) + for param in params: + if columns_len != len(param): + raise Error( + "\nlen(`%s`)=%d\n!=\ncolum_len(`%s`)=%d" + % (param, len(param), columns, columns_len) + ) + return params + else: + # The case with Params B: [1, 2, 3] + # Insert statements' params are only passed as tuples or lists, + # yet for do_execute_update, we've got to pass in list of list. + # https://googleapis.dev/python/spanner/latest/transaction-api.html\ + # #google.cloud.spanner_v1.transaction.Transaction.insert + n_stride = len(columns) + else: + # This is the case where we have for example: + # SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s), + # (%s, %s, %s), (%s, %s, %s)' + # Params: [1, 2, 3, 4, 5, 6, 7, 8, 9] + # which should become + # Columns: (f1, f2, f3) + # new_params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] + + # Sanity check 1: all the pyformat_values should have the exact same + # length. + first, rest = pyformat_args[0], pyformat_args[1:] + n_stride = first.count("%s") + for pyfmt_value in rest: + n = pyfmt_value.count("%s") + if n_stride != n: + raise Error( + "\nlen(`%s`)=%d\n!=\nlen(`%s`)=%d" + % (first, n_stride, pyfmt_value, n) + ) + + # Sanity check 2: len(params) MUST be a multiple of n_stride aka + # len(count of %s). + # so that we can properly group for example: + # Given pyformat args: + # (%s, %s, %s) + # Params: + # [1, 2, 3, 4, 5, 6, 7, 8, 9] + # into + # [(1, 2, 3), (4, 5, 6), (7, 8, 9)] + if (len(params) % n_stride) != 0: + raise ProgrammingError( + "Invalid length: len(params)=%d MUST be a multiple of " + "len(pyformat_args)=%d" % (len(params), n_stride) + ) + + # Now chop up the strides. + strides = [] + for step in range(0, len(params), n_stride): + stride = tuple(params[step : step + n_stride :]) + strides.append(stride) + + return strides + + +def sql_pyformat_args_to_spanner(sql, params): + """ + Transform pyformat set SQL to named arguments for Cloud Spanner. + It will also unescape previously escaped format specifiers + like %%s to %s. + For example: + SQL: 'SELECT * from t where f1=%s, f2=%s, f3=%s' + Params: ('a', 23, '888***') + becomes: + SQL: 'SELECT * from t where f1=@a0, f2=@a1, f3=@a2' + Params: {'a0': 'a', 'a1': 23, 'a2': '888***'} + + OR + SQL: 'SELECT * from t where f1=%(f1)s, f2=%(f2)s, f3=%(f3)s' + Params: {'f1': 'a', 'f2': 23, 'f3': '888***', 'extra': 'aye') + becomes: + SQL: 'SELECT * from t where f1=@a0, f2=@a1, f3=@a2' + Params: {'a0': 'a', 'a1': 23, 'a2': '888***'} + """ + if not params: + return sanitize_literals_for_upload(sql), params + + found_pyformat_placeholders = RE_PYFORMAT.findall(sql) + params_is_dict = isinstance(params, dict) + + if params_is_dict: + if not found_pyformat_placeholders: + return sanitize_literals_for_upload(sql), params + else: + n_params = len(params) if params else 0 + n_matches = len(found_pyformat_placeholders) + if n_matches != n_params: + raise Error( + "pyformat_args mismatch\ngot %d args from %s\n" + "want %d args in %s" + % (n_matches, found_pyformat_placeholders, n_params, params) + ) + + named_args = {} + # We've now got for example: + # Case a) Params is a non-dict + # SQL: 'SELECT * from t where f1=%s, f2=%s, f3=%s' + # Params: ('a', 23, '888***') + # Case b) Params is a dict and the matches are %(value)s' + for i, pyfmt in enumerate(found_pyformat_placeholders): + key = "a%d" % i + sql = sql.replace(pyfmt, "@" + key, 1) + if params_is_dict: + # The '%(key)s' case, so interpolate it. + resolved_value = pyfmt % params + named_args[key] = resolved_value + else: + named_args[key] = cast_for_spanner(params[i]) + + return sanitize_literals_for_upload(sql), named_args + + +def cast_for_spanner(value): + """Convert the param to its Cloud Spanner equivalent type. + + :type value: Any + :param value: Value to convert to a Cloud Spanner type. + + :rtype: Any + :returns: Value converted to a Cloud Spanner type. + """ + if isinstance(value, decimal.Decimal): + return str(value) + return value + + +def get_param_types(params): + """Determine Cloud Spanner types for the given parameters. + + :type params: :class:`dict` + :param params: Parameters requiring to find Cloud Spanner types. + + :rtype: :class:`dict` + :returns: The types index for the given parameters. + """ + if params is None: + return + + param_types = {} + + for key, value in params.items(): + type_ = type(value) + if type_ in TYPES_MAP: + param_types[key] = TYPES_MAP[type_] + + return param_types + + +def ensure_where_clause(sql): + """ + Cloud Spanner requires a WHERE clause on UPDATE and DELETE statements. + Add a dummy WHERE clause if necessary. + """ + if any(isinstance(token, sqlparse.sql.Where) for token in sqlparse.parse(sql)[0]): + return sql + return sql + " WHERE 1=1" + + +def escape_name(name): + """ + Apply backticks to the name that either contain '-' or + ' ', or is a Cloud Spanner's reserved keyword. + + :type name: :class:`str` + :param name: Name to escape. + + :rtype: :class:`str` + :returns: Name escaped if it has to be escaped. + """ + if "-" in name or " " in name or name.upper() in SPANNER_RESERVED_KEYWORDS: + return "`" + name + "`" + return name diff --git a/google/cloud/spanner_dbapi/parser.py b/google/cloud/spanner_dbapi/parser.py new file mode 100644 index 0000000000..9271631b25 --- /dev/null +++ b/google/cloud/spanner_dbapi/parser.py @@ -0,0 +1,246 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Grammar for parsing VALUES: + VALUES := `VALUES(` + ARGS + `)` + ARGS := [EXPR,]*EXPR + EXPR := TERMINAL / FUNC + TERMINAL := `%s` + FUNC := alphanum + `(` + ARGS + `)` + alphanum := (a-zA-Z_)[0-9a-ZA-Z_]* + +thus given: + statement: 'VALUES (%s, %s), (%s, LOWER(UPPER(%s))) , (%s)' + It'll parse: + VALUES + |- ARGS + |- (TERMINAL, TERMINAL) + |- (TERMINAL, FUNC + |- FUNC + |- (TERMINAL) + |- (TERMINAL) +""" + +from .exceptions import ProgrammingError + +ARGS = "ARGS" +EXPR = "EXPR" +FUNC = "FUNC" +VALUES = "VALUES" + + +class func(object): + def __init__(self, func_name, args): + self.name = func_name + self.args = args + + def __str__(self): + return "%s%s" % (self.name, self.args) + + def __repr__(self): + return self.__str__() + + def __eq__(self, other): + if type(self) != type(other): + return False + if self.name != other.name: + return False + if not isinstance(other.args, type(self.args)): + return False + if len(self.args) != len(other.args): + return False + return self.args == other.args + + def __len__(self): + return len(self.args) + + +class terminal(str): + """ + terminal represents the unit symbol that can be part of a SQL values clause. + """ + + pass + + +class a_args(object): + def __init__(self, argv): + self.argv = argv + + def __str__(self): + return "(" + ", ".join([str(arg) for arg in self.argv]) + ")" + + def __repr__(self): + return self.__str__() + + def has_expr(self): + return any([token for token in self.argv if not isinstance(token, terminal)]) + + def __len__(self): + return len(self.argv) + + def __eq__(self, other): + if type(self) != type(other): + return False + + if len(self) != len(other): + return False + + for i, item in enumerate(self): + if item != other[i]: + return False + + return True + + def __getitem__(self, index): + return self.argv[index] + + def homogenous(self): + """ + Return True if all the arguments are pyformat + args and have the same number of arguments. + """ + if not self._is_equal_length(): + return False + + for arg in self.argv: + if isinstance(arg, terminal): + continue + elif isinstance(arg, a_args): + if not arg.homogenous(): + return False + else: + return False + return True + + def _is_equal_length(self): + """ + Return False if all the arguments have the same length. + """ + if len(self) == 0: + return True + + arg0_len = len(self.argv[0]) + for arg in self.argv[1:]: + if len(arg) != arg0_len: + return False + + return True + + +class values(a_args): + def __str__(self): + return "VALUES%s" % super().__str__() + + +def parse_values(stmt): + return expect(stmt, VALUES) + + +pyfmt_str = terminal("%s") + + +def expect(word, token): + word = word.strip() + if token == VALUES: + if not word.startswith("VALUES"): + raise ProgrammingError("VALUES: `%s` does not start with VALUES" % word) + word = word[len("VALUES") :].lstrip() + + all_args = [] + while word: + word = word.strip() + + word, arg = expect(word, ARGS) + all_args.append(arg) + word = word.strip() + + if word and not word.startswith(","): + raise ProgrammingError( + "VALUES: expected `,` got %s in %s" % (word[0], word) + ) + word = word[1:] + return "", values(all_args) + + elif token == FUNC: + begins_with_letter = word and (word[0].isalpha() or word[0] == "_") + if not begins_with_letter: + raise ProgrammingError( + "FUNC: `%s` does not begin with `a-zA-z` nor a `_`" % word + ) + + rest = word[1:] + end = 0 + for ch in rest: + if ch.isalnum() or ch == "_": + end += 1 + else: + break + + func_name, rest = word[: end + 1], word[end + 1 :].strip() + + word, args = expect(rest, ARGS) + return word, func(func_name, args) + + elif token == ARGS: + # The form should be: + # (%s) + # (%s, %s...) + # (FUNC, %s...) + # (%s, %s...) + if not (word and word.startswith("(")): + raise ProgrammingError("ARGS: supposed to begin with `(` in `%s`" % word) + + word = word[1:] + + terms = [] + while True: + word = word.strip() + if not word or word.startswith(")"): + break + + if word == "%s": + terms.append(pyfmt_str) + word = "" + elif not word.startswith("%s"): + word, parsed = expect(word, FUNC) + terms.append(parsed) + else: + terms.append(pyfmt_str) + word = word[2:].strip() + + if word.startswith(","): + word = word[1:] + + if not (word and word.startswith(")")): + raise ProgrammingError("ARGS: supposed to end with `)` in `%s`" % word) + + word = word[1:] + return word, a_args(terms) + + elif token == EXPR: + if word == "%s": + # Terminal symbol. + return "", pyfmt_str + + # Otherwise we expect a function. + return expect(word, FUNC) + + raise ProgrammingError("Unknown token `%s`" % token) + + +def as_values(values_stmt): + _, _values = parse_values(values_stmt) + return _values diff --git a/google/cloud/spanner_dbapi/types.py b/google/cloud/spanner_dbapi/types.py new file mode 100644 index 0000000000..80d7030402 --- /dev/null +++ b/google/cloud/spanner_dbapi/types.py @@ -0,0 +1,106 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of the type objects and constructors according to the + PEP-0249 specification. + + See + https://www.python.org/dev/peps/pep-0249/#type-objects-and-constructors +""" + +import datetime +import time +from base64 import b64encode + + +def _date_from_ticks(ticks): + """Based on PEP-249 Implementation Hints for Module Authors: + + https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors + """ + return Date(*time.localtime(ticks)[:3]) + + +def _time_from_ticks(ticks): + """Based on PEP-249 Implementation Hints for Module Authors: + + https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors + """ + return Time(*time.localtime(ticks)[3:6]) + + +def _timestamp_from_ticks(ticks): + """Based on PEP-249 Implementation Hints for Module Authors: + + https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors + """ + return Timestamp(*time.localtime(ticks)[:6]) + + +class _DBAPITypeObject(object): + """Implementation of a helper class used for type comparison among similar + but possibly different types. + + See + https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors + """ + + def __init__(self, *values): + self.values = values + + def __eq__(self, other): + return other in self.values + + +Date = datetime.date +Time = datetime.time +Timestamp = datetime.datetime +DateFromTicks = _date_from_ticks +TimeFromTicks = _time_from_ticks +TimestampFromTicks = _timestamp_from_ticks +Binary = b64encode + +STRING = "STRING" +BINARY = _DBAPITypeObject("TYPE_CODE_UNSPECIFIED", "BYTES", "ARRAY", "STRUCT") +NUMBER = _DBAPITypeObject("BOOL", "INT64", "FLOAT64", "NUMERIC") +DATETIME = _DBAPITypeObject("TIMESTAMP", "DATE") +ROWID = "STRING" + + +class TimestampStr(str): + """[inherited from the alpha release] + + TODO: Decide whether this class is necessary + + TimestampStr exists so that we can purposefully format types as timestamps + compatible with Cloud Spanner's TIMESTAMP type, but right before making + queries, it'll help differentiate between normal strings and the case of + types that should be TIMESTAMP. + """ + + pass + + +class DateStr(str): + """[inherited from the alpha release] + + TODO: Decide whether this class is necessary + + DateStr is a sentinel type to help format Django dates as + compatible with Cloud Spanner's DATE type, but right before making + queries, it'll help differentiate between normal strings and the case of + types that should be DATE. + """ + + pass diff --git a/google/cloud/spanner_dbapi/utils.py b/google/cloud/spanner_dbapi/utils.py new file mode 100644 index 0000000000..b0ad3922a5 --- /dev/null +++ b/google/cloud/spanner_dbapi/utils.py @@ -0,0 +1,89 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + + +class PeekIterator: + """ + PeekIterator peeks at the first element out of an iterator + for the sake of operations like auto-population of fields on reading + the first element. + If next's result is an instance of list, it'll be converted into a tuple + to conform with DBAPI v2's sequence expectations. + """ + + def __init__(self, source): + itr_src = iter(source) + + self.__iters = [] + self.__index = 0 + + try: + head = next(itr_src) + # Restitch and prepare to read from multiple iterators. + self.__iters = [iter(itr) for itr in [[head], itr_src]] + except StopIteration: + pass + + def __next__(self): + if self.__index >= len(self.__iters): + raise StopIteration + + iterator = self.__iters[self.__index] + try: + head = next(iterator) + except StopIteration: + # That iterator has been exhausted, try with the next one. + self.__index += 1 + return self.__next__() + else: + return tuple(head) if isinstance(head, list) else head + + def __iter__(self): + return self + + +re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)") + + +def backtick_unicode(sql): + matches = list(re_UNICODE_POINTS.finditer(sql)) + if not matches: + return sql + + segments = [] + + last_end = 0 + for match in matches: + start, end = match.span() + if sql[start] != "`" and sql[end - 1] != "`": + segments.append(sql[last_end:start] + "`" + sql[start:end] + "`") + else: + segments.append(sql[last_end:end]) + + last_end = end + + return "".join(segments) + + +def sanitize_literals_for_upload(s): + """ + Convert literals in s, to be fit for consumption by Cloud Spanner. + 1. Convert %% (escaped percent literals) to %. Percent signs must be escaped when + values like %s are used as SQL parameter placeholders but Spanner's query language + uses placeholders like @a0 and doesn't expect percent signs to be escaped. + 2. Quote words containing non-ASCII, with backticks, for example föö to `föö`. + """ + return backtick_unicode(s.replace("%%", "%")) diff --git a/google/cloud/spanner_dbapi/version.py b/google/cloud/spanner_dbapi/version.py new file mode 100644 index 0000000000..b0e48cff0b --- /dev/null +++ b/google/cloud/spanner_dbapi/version.py @@ -0,0 +1,19 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform + +PY_VERSION = platform.python_version() +VERSION = "2.2.0a1" +DEFAULT_USER_AGENT = "django_spanner/" + VERSION diff --git a/noxfile.py b/noxfile.py index 1a6227824a..47a9ee3803 100644 --- a/noxfile.py +++ b/noxfile.py @@ -72,7 +72,7 @@ def default(session): # Install all test dependencies, then install this package in-place. session.install("asyncmock", "pytest-asyncio") - session.install("mock", "pytest", "pytest-cov") + session.install("mock", "pytest", "pytest-cov", "sqlparse") session.install("-e", ".") # Run py.test against the unit tests. diff --git a/tests/unit/spanner_dbapi/__init__.py b/tests/unit/spanner_dbapi/__init__.py new file mode 100644 index 0000000000..377df12f71 --- /dev/null +++ b/tests/unit/spanner_dbapi/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py new file mode 100644 index 0000000000..84d6b3e323 --- /dev/null +++ b/tests/unit/spanner_dbapi/test__helpers.py @@ -0,0 +1,119 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cloud Spanner DB-API Connection class unit tests.""" + +import mock +import unittest + + +class TestHelpers(unittest.TestCase): + def test__execute_insert_heterogenous(self): + from google.cloud.spanner_dbapi import _helpers + + sql = "sql" + params = (sql, None) + with mock.patch( + "google.cloud.spanner_dbapi._helpers.sql_pyformat_args_to_spanner", + return_value=params, + ) as mock_pyformat: + with mock.patch( + "google.cloud.spanner_dbapi._helpers.get_param_types", return_value=None + ) as mock_param_types: + transaction = mock.MagicMock() + transaction.execute_update = mock_execute = mock.MagicMock() + _helpers._execute_insert_heterogenous(transaction, [params]) + + mock_pyformat.assert_called_once_with(params[0], params[1]) + mock_param_types.assert_called_once_with(None) + mock_execute.assert_called_once_with(sql, params=None, param_types=None) + + def test__execute_insert_homogenous(self): + from google.cloud.spanner_dbapi import _helpers + + transaction = mock.MagicMock() + transaction.insert = mock.MagicMock() + parts = mock.MagicMock() + parts.get = mock.MagicMock(return_value=0) + + _helpers._execute_insert_homogenous(transaction, parts) + transaction.insert.assert_called_once_with(0, 0, 0) + + def test_handle_insert(self): + from google.cloud.spanner_dbapi import _helpers + + connection = mock.MagicMock() + connection.database.run_in_transaction = mock_run_in = mock.MagicMock() + sql = "sql" + parts = mock.MagicMock() + with mock.patch( + "google.cloud.spanner_dbapi._helpers.parse_insert", return_value=parts + ): + parts.get = mock.MagicMock(return_value=True) + mock_run_in.return_value = 0 + result = _helpers.handle_insert(connection, sql, None) + self.assertEqual(result, 0) + + parts.get = mock.MagicMock(return_value=False) + mock_run_in.return_value = 1 + result = _helpers.handle_insert(connection, sql, None) + self.assertEqual(result, 1) + + +class TestColumnInfo(unittest.TestCase): + def test_ctor(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + name = "col-name" + type_code = 8 + display_size = 5 + internal_size = 10 + precision = 3 + scale = None + null_ok = False + + cols = ColumnInfo( + name, type_code, display_size, internal_size, precision, scale, null_ok + ) + + self.assertEqual(cols.name, name) + self.assertEqual(cols.type_code, type_code) + self.assertEqual(cols.display_size, display_size) + self.assertEqual(cols.internal_size, internal_size) + self.assertEqual(cols.precision, precision) + self.assertEqual(cols.scale, scale) + self.assertEqual(cols.null_ok, null_ok) + self.assertEqual( + cols.fields, + (name, type_code, display_size, internal_size, precision, scale, null_ok), + ) + + def test___get_item__(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + fields = ("col-name", 8, 5, 10, 3, None, False) + cols = ColumnInfo(*fields) + + for i in range(0, 7): + self.assertEqual(cols[i], fields[i]) + + def test___str__(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + cols = ColumnInfo("col-name", 8, None, 10, 3, None, False) + + self.assertEqual( + str(cols), + "ColumnInfo(name='col-name', type_code=8, internal_size=10, precision='3')", + ) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py new file mode 100644 index 0000000000..8cd3bced16 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -0,0 +1,308 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cloud Spanner DB-API Connection class unit tests.""" + +import mock +import sys +import unittest +import warnings + + +def _make_credentials(): + from google.auth import credentials + + class _CredentialsWithScopes(credentials.Credentials, credentials.Scoped): + pass + + return mock.Mock(spec=_CredentialsWithScopes) + + +class TestConnection(unittest.TestCase): + + PROJECT = "test-project" + INSTANCE = "test-instance" + DATABASE = "test-database" + USER_AGENT = "user-agent" + CREDENTIALS = _make_credentials() + + def _get_client_info(self): + from google.api_core.gapic_v1.client_info import ClientInfo + + return ClientInfo(user_agent=self.USER_AGENT) + + def _make_connection(self): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_v1.instance import Instance + + # We don't need a real Client object to test the constructor + instance = Instance(self.INSTANCE, client=None) + database = instance.database(self.DATABASE) + return Connection(instance, database) + + @unittest.skipIf(sys.version_info[0] < 3, "Python 2 patching is outdated") + def test_property_autocommit_setter(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection.autocommit = True + mock_commit.assert_called_once_with() + self.assertEqual(connection._autocommit, True) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection.autocommit = False + mock_commit.assert_not_called() + self.assertEqual(connection._autocommit, False) + + def test_property_database(self): + from google.cloud.spanner_v1.database import Database + + connection = self._make_connection() + self.assertIsInstance(connection.database, Database) + self.assertEqual(connection.database, connection._database) + + def test_property_instance(self): + from google.cloud.spanner_v1.instance import Instance + + connection = self._make_connection() + self.assertIsInstance(connection.instance, Instance) + self.assertEqual(connection.instance, connection._instance) + + def test__session_checkout(self): + from google.cloud.spanner_dbapi import Connection + + with mock.patch("google.cloud.spanner_v1.database.Database") as mock_database: + mock_database._pool = mock.MagicMock() + mock_database._pool.get = mock.MagicMock(return_value="db_session_pool") + connection = Connection(self.INSTANCE, mock_database) + + connection._session_checkout() + mock_database._pool.get.assert_called_once_with() + self.assertEqual(connection._session, "db_session_pool") + + connection._session = "db_session" + connection._session_checkout() + self.assertEqual(connection._session, "db_session") + + def test__release_session(self): + from google.cloud.spanner_dbapi import Connection + + with mock.patch("google.cloud.spanner_v1.database.Database") as mock_database: + mock_database._pool = mock.MagicMock() + mock_database._pool.put = mock.MagicMock() + connection = Connection(self.INSTANCE, mock_database) + connection._session = "session" + + connection._release_session() + mock_database._pool.put.assert_called_once_with("session") + self.assertIsNone(connection._session) + + def test_transaction_checkout(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + connection._session_checkout = mock_checkout = mock.MagicMock(autospec=True) + connection.transaction_checkout() + mock_checkout.assert_called_once_with() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.committed = mock_transaction.rolled_back = False + self.assertEqual(connection.transaction_checkout(), mock_transaction) + + connection._autocommit = True + self.assertIsNone(connection.transaction_checkout()) + + def test_close(self): + from google.cloud.spanner_dbapi import connect, InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True + ): + connection = connect("test-instance", "test-database") + + self.assertFalse(connection.is_closed) + connection.close() + self.assertTrue(connection.is_closed) + + with self.assertRaises(InterfaceError): + connection.cursor() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.committed = mock_transaction.rolled_back = False + mock_transaction.rollback = mock_rollback = mock.MagicMock() + connection.close() + mock_rollback.assert_called_once_with() + + @mock.patch.object(warnings, "warn") + def test_commit(self, mock_warn): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.commit() + mock_release.assert_not_called() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.commit = mock_commit = mock.MagicMock() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.commit() + mock_commit.assert_called_once_with() + mock_release.assert_called_once_with() + + connection._autocommit = True + connection.commit() + mock_warn.assert_called_once_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) + + @mock.patch.object(warnings, "warn") + def test_rollback(self, mock_warn): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.rollback() + mock_release.assert_not_called() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.rollback = mock_rollback = mock.MagicMock() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.rollback() + mock_rollback.assert_called_once_with() + mock_release.assert_called_once_with() + + connection._autocommit = True + connection.rollback() + mock_warn.assert_called_once_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) + + def test_run_prior_DDL_statements(self): + from google.cloud.spanner_dbapi import Connection, InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.database.Database", autospec=True + ) as mock_database: + connection = Connection(self.INSTANCE, mock_database) + + connection.run_prior_DDL_statements() + mock_database.update_ddl.assert_not_called() + + ddl = ["ddl"] + connection._ddl_statements = ddl + + connection.run_prior_DDL_statements() + mock_database.update_ddl.assert_called_once_with(ddl) + + connection.is_closed = True + + with self.assertRaises(InterfaceError): + connection.run_prior_DDL_statements() + + def test_context(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + with connection as conn: + self.assertEqual(conn, connection) + + self.assertTrue(connection.is_closed) + + def test_connect(self): + from google.cloud.spanner_dbapi import Connection, connect + + with mock.patch("google.cloud.spanner_v1.Client"): + with mock.patch( + "google.api_core.gapic_v1.client_info.ClientInfo", + return_value=self._get_client_info(), + ): + connection = connect( + self.INSTANCE, + self.DATABASE, + self.PROJECT, + self.CREDENTIALS, + self.USER_AGENT, + ) + self.assertIsInstance(connection, Connection) + + def test_connect_instance_not_found(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=False + ): + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + def test_connect_database_not_found(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=False + ): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + def test_default_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + connection = connect("test-instance", "test-database") + + self.assertIsNotNone(connection.database._pool) + + def test_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_v1.pool import FixedSizePool + + database_id = "test-database" + pool = FixedSizePool() + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + connect("test-instance", database_id, pool=pool) + database_mock.assert_called_once_with(database_id, pool=pool) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py new file mode 100644 index 0000000000..23ed5010d1 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -0,0 +1,455 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cursor() class unit tests.""" + +import mock +import sys +import unittest + + +class TestCursor(unittest.TestCase): + + INSTANCE = "test-instance" + DATABASE = "test-database" + + def _get_target_class(self): + from google.cloud.spanner_dbapi import Cursor + + return Cursor + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def _make_connection(self, *args, **kwargs): + from google.cloud.spanner_dbapi import Connection + + return Connection(*args, **kwargs) + + def test_property_connection(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + self.assertEqual(cursor.connection, connection) + + def test_property_description(self): + from google.cloud.spanner_dbapi._helpers import ColumnInfo + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + self.assertIsNone(cursor.description) + cursor._result_set = res_set = mock.MagicMock() + res_set.metadata.row_type.fields = [mock.MagicMock()] + self.assertIsNotNone(cursor.description) + self.assertIsInstance(cursor.description[0], ColumnInfo) + + def test_property_rowcount(self): + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + self.assertEqual(cursor.rowcount, _UNSET_COUNT) + + def test_callproc(self): + from google.cloud.spanner_dbapi.exceptions import InterfaceError + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + cursor._is_closed = True + with self.assertRaises(InterfaceError): + cursor.callproc(procname=None) + + def test_close(self): + from google.cloud.spanner_dbapi import connect, InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True + ): + connection = connect(self.INSTANCE, self.DATABASE) + + cursor = connection.cursor() + self.assertFalse(cursor.is_closed) + + cursor.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database") + + def test_do_execute_update(self): + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + transaction = mock.MagicMock() + + def run_helper(ret_value): + transaction.execute_update.return_value = ret_value + res = cursor._do_execute_update( + transaction=transaction, sql="sql", params=None + ) + return res + + expected = "good" + self.assertEqual(run_helper(expected), expected) + self.assertEqual(cursor._row_count, _UNSET_COUNT) + + expected = 1234 + self.assertEqual(run_helper(expected), expected) + self.assertEqual(cursor._row_count, expected) + + def test_execute_programming_error(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + cursor.connection = None + with self.assertRaises(ProgrammingError): + cursor.execute(sql="") + + def test_execute_attribute_error(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + with self.assertRaises(AttributeError): + cursor.execute(sql="") + + def test_execute_autocommit_off(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.connection._autocommit = False + cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) + + cursor.execute("sql") + self.assertIsInstance(cursor._result_set, mock.MagicMock) + self.assertIsInstance(cursor._itr, PeekIterator) + + def test_execute_statement(self): + from google.cloud.spanner_dbapi import parse_utils + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_DDL, + ) as mock_classify_stmt: + sql = "sql" + cursor.execute(sql=sql) + mock_classify_stmt.assert_called_once_with(sql) + self.assertEqual(cursor.connection._ddl_statements, [sql]) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_NON_UPDATING, + ): + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor._handle_DQL", + return_value=parse_utils.STMT_NON_UPDATING, + ) as mock_handle_ddl: + connection.autocommit = True + sql = "sql" + cursor.execute(sql=sql) + mock_handle_ddl.assert_called_once_with(sql, None) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_INSERT, + ): + with mock.patch( + "google.cloud.spanner_dbapi._helpers.handle_insert", + return_value=parse_utils.STMT_INSERT, + ) as mock_handle_insert: + sql = "sql" + cursor.execute(sql=sql) + mock_handle_insert.assert_called_once_with(connection, sql, None) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value="other_statement", + ): + cursor.connection._database = mock_db = mock.MagicMock() + mock_db.run_in_transaction = mock_run_in = mock.MagicMock() + sql = "sql" + cursor.execute(sql=sql) + mock_run_in.assert_called_once_with(cursor._do_execute_update, sql, None) + + def test_execute_integrity_error(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi.exceptions import IntegrityError + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.AlreadyExists("message"), + ): + with self.assertRaises(IntegrityError): + cursor.execute(sql="sql") + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.FailedPrecondition("message"), + ): + with self.assertRaises(IntegrityError): + cursor.execute(sql="sql") + + def test_execute_invalid_argument(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.InvalidArgument("message"), + ): + with self.assertRaises(ProgrammingError): + cursor.execute(sql="sql") + + def test_execute_internal_server_error(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi.exceptions import OperationalError + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.InternalServerError("message"), + ): + with self.assertRaises(OperationalError): + cursor.execute(sql="sql") + + def test_executemany_on_closed_cursor(self): + from google.cloud.spanner_dbapi import InterfaceError + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor.close() + + with self.assertRaises(InterfaceError): + cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ()) + + def test_executemany(self): + from google.cloud.spanner_dbapi import connect + + operation = """SELECT * FROM table1 WHERE "col1" = @a1""" + params_seq = ((1,), (2,)) + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.execute" + ) as execute_mock: + cursor.executemany(operation, params_seq) + + execute_mock.assert_has_calls( + (mock.call(operation, (1,)), mock.call(operation, (2,))) + ) + + @unittest.skipIf( + sys.version_info[0] < 3, "Python 2 has an outdated iterator definition" + ) + def test_fetchone(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [1, 2, 3] + cursor._itr = iter(lst) + for i in range(len(lst)): + self.assertEqual(cursor.fetchone(), lst[i]) + self.assertIsNone(cursor.fetchone()) + + def test_fetchmany(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + + self.assertEqual(cursor.fetchmany(), [lst[0]]) + + result = cursor.fetchmany(len(lst)) + self.assertEqual(result, lst[1:]) + + def test_fetchall(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + self.assertEqual(cursor.fetchall(), lst) + + def test_nextset(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.nextset() + + def test_setinputsizes(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.setinputsizes(sizes=None) + + def test_setoutputsize(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.setoutputsize(size=None) + + # def test_handle_insert(self): + # pass + # + # def test_do_execute_insert_heterogenous(self): + # pass + # + # def test_do_execute_insert_homogenous(self): + # pass + + def test_handle_dql(self): + from google.cloud.spanner_dbapi import utils + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.database.snapshot.return_value.__enter__.return_value = ( + mock_snapshot + ) = mock.MagicMock() + cursor = self._make_one(connection) + + mock_snapshot.execute_sql.return_value = int(0) + cursor._handle_DQL("sql", params=None) + self.assertEqual(cursor._row_count, 0) + self.assertIsNone(cursor._itr) + + mock_snapshot.execute_sql.return_value = "0" + cursor._handle_DQL("sql", params=None) + self.assertEqual(cursor._result_set, "0") + self.assertIsInstance(cursor._itr, utils.PeekIterator) + self.assertEqual(cursor._row_count, _UNSET_COUNT) + + def test_context(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with cursor as c: + self.assertEqual(c, cursor) + + self.assertTrue(c.is_closed) + + def test_next(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with self.assertRaises(exceptions.ProgrammingError): + cursor.__next__() + + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + i = 0 + for c in cursor._itr: + self.assertEqual(c, lst[i]) + i += 1 + + def test_iter(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with self.assertRaises(exceptions.ProgrammingError): + _ = iter(cursor) + + iterator = iter([(1,), (2,), (3,)]) + cursor._itr = iterator + self.assertEqual(iter(cursor), iterator) + + def test_list_tables(self): + from google.cloud.spanner_dbapi import _helpers + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + table_list = ["table1", "table2", "table3"] + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", + return_value=table_list, + ) as mock_run_sql: + cursor.list_tables() + mock_run_sql.assert_called_once_with(_helpers.SQL_LIST_TABLES) + + def test_run_sql_in_snapshot(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.database.snapshot.return_value.__enter__.return_value = ( + mock_snapshot + ) = mock.MagicMock() + cursor = self._make_one(connection) + + results = 1, 2, 3 + mock_snapshot.execute_sql.return_value = results + self.assertEqual(cursor.run_sql_in_snapshot("sql"), list(results)) + + def test_get_table_column_schema(self): + from google.cloud.spanner_dbapi.cursor import ColumnDetails + from google.cloud.spanner_dbapi import _helpers + from google.cloud.spanner_v1 import param_types + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + column_name = "column1" + is_nullable = "YES" + spanner_type = "spanner_type" + rows = [(column_name, is_nullable, spanner_type)] + expected = {column_name: ColumnDetails(null_ok=True, spanner_type=spanner_type)} + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", + return_value=rows, + ) as mock_run_sql: + table_name = "table1" + result = cursor.get_table_column_schema(table_name=table_name) + mock_run_sql.assert_called_once_with( + sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA, + params={"table_name": table_name}, + param_types={"table_name": param_types.STRING}, + ) + self.assertEqual(result, expected) diff --git a/tests/unit/spanner_dbapi/test_globals.py b/tests/unit/spanner_dbapi/test_globals.py new file mode 100644 index 0000000000..2960862ec3 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_globals.py @@ -0,0 +1,28 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + + +class TestDBAPIGlobals(unittest.TestCase): + def test_apilevel(self): + from google.cloud.spanner_dbapi import apilevel + from google.cloud.spanner_dbapi import paramstyle + from google.cloud.spanner_dbapi import threadsafety + + self.assertEqual(apilevel, "2.0", "We implement PEP-0249 version 2.0") + self.assertEqual(paramstyle, "format", "Cloud Spanner uses @param") + self.assertEqual( + threadsafety, 1, "Threads may share module but not connections" + ) diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py new file mode 100644 index 0000000000..a79ad8dc51 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -0,0 +1,439 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +from google.cloud.spanner_v1 import param_types + + +class TestParseUtils(unittest.TestCase): + + skip_condition = sys.version_info[0] < 3 + skip_message = "Subtests are not supported in Python 2" + + def test_classify_stmt(self): + from google.cloud.spanner_dbapi.parse_utils import STMT_DDL + from google.cloud.spanner_dbapi.parse_utils import STMT_INSERT + from google.cloud.spanner_dbapi.parse_utils import STMT_NON_UPDATING + from google.cloud.spanner_dbapi.parse_utils import STMT_UPDATING + from google.cloud.spanner_dbapi.parse_utils import classify_stmt + + cases = ( + ("SELECT 1", STMT_NON_UPDATING), + ("SELECT s.SongName FROM Songs AS s", STMT_NON_UPDATING), + ( + "WITH sq AS (SELECT SchoolID FROM Roster) SELECT * from sq", + STMT_NON_UPDATING, + ), + ( + "CREATE TABLE django_content_type (id STRING(64) NOT NULL, name STRING(100) " + "NOT NULL, app_label STRING(100) NOT NULL, model STRING(100) NOT NULL) PRIMARY KEY(id)", + STMT_DDL, + ), + ( + "CREATE INDEX SongsBySingerAlbumSongNameDesc ON " + "Songs(SingerId, AlbumId, SongName DESC), INTERLEAVE IN Albums", + STMT_DDL, + ), + ("CREATE INDEX SongsBySongName ON Songs(SongName)", STMT_DDL), + ( + "CREATE INDEX AlbumsByAlbumTitle2 ON Albums(AlbumTitle) STORING (MarketingBudget)", + STMT_DDL, + ), + ("INSERT INTO table (col1) VALUES (1)", STMT_INSERT), + ("UPDATE table SET col1 = 1 WHERE col1 = NULL", STMT_UPDATING), + ) + + for query, want_class in cases: + self.assertEqual(classify_stmt(query), want_class) + + @unittest.skipIf(skip_condition, skip_message) + def test_parse_insert(self): + from google.cloud.spanner_dbapi.parse_utils import parse_insert + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + with self.assertRaises(ProgrammingError): + parse_insert("bad-sql", None) + + cases = [ + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + [1, 2, 3, 4, 5, 6], + { + "sql_params_list": [ + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (1, 2, 3), + ), + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (4, 5, 6), + ), + ] + }, + ), + ( + "INSERT INTO django_migrations(app, name, applied) VALUES (%s, %s, %s)", + [1, 2, 3, 4, 5, 6], + { + "sql_params_list": [ + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (1, 2, 3), + ), + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", + (4, 5, 6), + ), + ] + }, + ), + ( + "INSERT INTO sales.addresses (street, city, state, zip_code) " + "SELECT street, city, state, zip_code FROM sales.customers" + "ORDER BY first_name, last_name", + None, + { + "sql_params_list": [ + ( + "INSERT INTO sales.addresses (street, city, state, zip_code) " + "SELECT street, city, state, zip_code FROM sales.customers" + "ORDER BY first_name, last_name", + None, + ) + ] + }, + ), + ( + "INSERT INTO ap (n, ct, cn) " + "VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s),(%s, %s, %s)", + (1, 2, 3, 4, 5, 6, 7, 8, 9), + { + "sql_params_list": [ + ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (1, 2, 3)), + ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (4, 5, 6)), + ("INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", (7, 8, 9)), + ] + }, + ), + ( + "INSERT INTO `no` (`yes`) VALUES (%s)", + (1, 4, 5), + { + "sql_params_list": [ + ("INSERT INTO `no` (`yes`) VALUES (%s)", (1,)), + ("INSERT INTO `no` (`yes`) VALUES (%s)", (4,)), + ("INSERT INTO `no` (`yes`) VALUES (%s)", (5,)), + ] + }, + ), + ( + "INSERT INTO T (f1, f2) VALUES (1, 2)", + None, + {"sql_params_list": [("INSERT INTO T (f1, f2) VALUES (1, 2)", None)]}, + ), + ( + "INSERT INTO `no` (`yes`, tiff) VALUES (%s, LOWER(%s)), (%s, %s), (%s, %s)", + (1, "FOO", 5, 10, 11, 29), + { + "sql_params_list": [ + ( + "INSERT INTO `no` (`yes`, tiff) VALUES(%s, LOWER(%s))", + (1, "FOO"), + ), + ("INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", (5, 10)), + ("INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", (11, 29)), + ] + }, + ), + ] + + sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" + with self.assertRaises(ProgrammingError): + parse_insert(sql, None) + + for sql, params, want in cases: + with self.subTest(sql=sql): + got = parse_insert(sql, params) + self.assertEqual(got, want, "Mismatch with parse_insert of `%s`" % sql) + + @unittest.skipIf(skip_condition, skip_message) + def test_parse_insert_invalid(self): + from google.cloud.spanner_dbapi import exceptions + from google.cloud.spanner_dbapi.parse_utils import parse_insert + + cases = [ + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, %s)", + [1, 2, 3, 4, 5, 6, 7], + "len\\(params\\)=7 MUST be a multiple of len\\(pyformat_args\\)=3", + ), + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s))", + [1, 2, 3, 4, 5, 6, 7], + "Invalid length: VALUES\\(...\\) len: 6 != len\\(params\\): 7", + ), + ( + "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s)))", + [1, 2, 3, 4, 5, 6], + "VALUES: expected `,` got \\) in \\)", + ), + ] + + for sql, params, wantException in cases: + with self.subTest(sql=sql): + self.assertRaisesRegex( + exceptions.ProgrammingError, + wantException, + lambda: parse_insert(sql, params), + ) + + @unittest.skipIf(skip_condition, skip_message) + def test_rows_for_insert_or_update(self): + from google.cloud.spanner_dbapi.parse_utils import rows_for_insert_or_update + from google.cloud.spanner_dbapi.exceptions import Error + + with self.assertRaises(Error): + rows_for_insert_or_update([0], [[]]) + + with self.assertRaises(Error): + rows_for_insert_or_update([0], None, ["0", "%s"]) + + cases = [ + ( + ["id", "app", "name"], + [(5, "ap", "n"), (6, "bp", "m")], + None, + [(5, "ap", "n"), (6, "bp", "m")], + ), + ( + ["app", "name"], + [("ap", "n"), ("bp", "m")], + None, + [("ap", "n"), ("bp", "m")], + ), + ( + ["app", "name", "fn"], + ["ap", "n", "f1", "bp", "m", "f2", "cp", "o", "f3"], + ["(%s, %s, %s)", "(%s, %s, %s)", "(%s, %s, %s)"], + [("ap", "n", "f1"), ("bp", "m", "f2"), ("cp", "o", "f3")], + ), + ( + ["app", "name", "fn", "ln"], + [ + ("ap", "n", (45, "nested"), "ll"), + ("bp", "m", "f2", "mt"), + ("fp", "cp", "o", "f3"), + ], + None, + [ + ("ap", "n", (45, "nested"), "ll"), + ("bp", "m", "f2", "mt"), + ("fp", "cp", "o", "f3"), + ], + ), + (["app", "name", "fn"], ["ap", "n", "f1"], None, [("ap", "n", "f1")]), + ] + + for i, (columns, params, pyformat_args, want) in enumerate(cases): + with self.subTest(i=i): + got = rows_for_insert_or_update(columns, params, pyformat_args) + self.assertEqual(got, want) + + @unittest.skipIf(skip_condition, skip_message) + def test_sql_pyformat_args_to_spanner(self): + import decimal + + from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner + + cases = [ + ( + ( + "SELECT * from t WHERE f1=%s, f2 = %s, f3=%s", + (10, "abc", "y**$22l3f"), + ), + ( + "SELECT * from t WHERE f1=@a0, f2 = @a1, f3=@a2", + {"a0": 10, "a1": "abc", "a2": "y**$22l3f"}, + ), + ), + ( + ( + "INSERT INTO t (f1, f2, f2) VALUES (%s, %s, %s)", + ("app", "name", "applied"), + ), + ( + "INSERT INTO t (f1, f2, f2) VALUES (@a0, @a1, @a2)", + {"a0": "app", "a1": "name", "a2": "applied"}, + ), + ), + ( + ( + "INSERT INTO t (f1, f2, f2) VALUES (%(f1)s, %(f2)s, %(f3)s)", + {"f1": "app", "f2": "name", "f3": "applied"}, + ), + ( + "INSERT INTO t (f1, f2, f2) VALUES (@a0, @a1, @a2)", + {"a0": "app", "a1": "name", "a2": "applied"}, + ), + ), + ( + # Intentionally using a dict with more keys than will be resolved. + ("SELECT * from t WHERE f1=%(f1)s", {"f1": "app", "f2": "name"}), + ("SELECT * from t WHERE f1=@a0", {"a0": "app"}), + ), + ( + # No args to replace, we MUST return the original params dict + # since it might be useful to pass to the next user. + ("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}), + ("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}), + ), + ( + ( + "SELECT (an.p + %s) AS np FROM an WHERE (an.p + %s) = %s", + (1, 1.0, decimal.Decimal("31")), + ), + ( + "SELECT (an.p + @a0) AS np FROM an WHERE (an.p + @a1) = @a2", + {"a0": 1, "a1": 1.0, "a2": str(31)}, + ), + ), + ] + for ((sql_in, params), sql_want) in cases: + with self.subTest(sql=sql_in): + got_sql, got_named_args = sql_pyformat_args_to_spanner(sql_in, params) + want_sql, want_named_args = sql_want + self.assertEqual(got_sql, want_sql, "SQL does not match") + self.assertEqual( + got_named_args, want_named_args, "Named args do not match" + ) + + @unittest.skipIf(skip_condition, skip_message) + def test_sql_pyformat_args_to_spanner_invalid(self): + from google.cloud.spanner_dbapi import exceptions + from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner + + cases = [ + ( + "SELECT * from t WHERE f1=%s, f2 = %s, f3=%s, extra=%s", + (10, "abc", "y**$22l3f"), + ) + ] + for sql, params in cases: + with self.subTest(sql=sql): + self.assertRaisesRegex( + exceptions.Error, + "pyformat_args mismatch", + lambda: sql_pyformat_args_to_spanner(sql, params), + ) + + def test_cast_for_spanner(self): + import decimal + + from google.cloud.spanner_dbapi.parse_utils import cast_for_spanner + + dec = 3 + value = decimal.Decimal(dec) + self.assertEqual(cast_for_spanner(value), str(dec)) + self.assertEqual(cast_for_spanner(5), 5) + self.assertEqual(cast_for_spanner("string"), "string") + + @unittest.skipIf(skip_condition, skip_message) + def test_get_param_types(self): + import datetime + + from google.cloud.spanner_dbapi.parse_utils import DateStr + from google.cloud.spanner_dbapi.parse_utils import TimestampStr + from google.cloud.spanner_dbapi.parse_utils import get_param_types + + params = { + "a1": 10, + "b1": "string", + "c1": 10.39, + "d1": TimestampStr("2005-08-30T01:01:01.000001Z"), + "e1": DateStr("2019-12-05"), + "f1": True, + "g1": datetime.datetime(2011, 9, 1, 13, 20, 30), + "h1": datetime.date(2011, 9, 1), + "i1": b"bytes", + "j1": None, + } + want_types = { + "a1": param_types.INT64, + "b1": param_types.STRING, + "c1": param_types.FLOAT64, + "d1": param_types.TIMESTAMP, + "e1": param_types.DATE, + "f1": param_types.BOOL, + "g1": param_types.TIMESTAMP, + "h1": param_types.DATE, + "i1": param_types.BYTES, + } + got_types = get_param_types(params) + self.assertEqual(got_types, want_types) + + def test_get_param_types_none(self): + from google.cloud.spanner_dbapi.parse_utils import get_param_types + + self.assertEqual(get_param_types(None), None) + + @unittest.skipIf(skip_condition, skip_message) + def test_ensure_where_clause(self): + from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause + + cases = [ + ( + "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", + "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", + ), + ( + "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", + "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5 WHERE 1=1", + ), + ( + "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", + "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", + ), + ( + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + ), + ( + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", + ), + ("DELETE * FROM TABLE", "DELETE * FROM TABLE WHERE 1=1"), + ] + + for sql, want in cases: + with self.subTest(sql=sql): + got = ensure_where_clause(sql) + self.assertEqual(got, want) + + @unittest.skipIf(skip_condition, skip_message) + def test_escape_name(self): + from google.cloud.spanner_dbapi.parse_utils import escape_name + + cases = ( + ("SELECT", "`SELECT`"), + ("dashed-value", "`dashed-value`"), + ("with space", "`with space`"), + ("name", "name"), + ("", ""), + ) + for name, want in cases: + with self.subTest(name=name): + got = escape_name(name) + self.assertEqual(got, want) diff --git a/tests/unit/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py new file mode 100644 index 0000000000..2343800489 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_parser.py @@ -0,0 +1,297 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mock +import sys +import unittest + + +class TestParser(unittest.TestCase): + + skip_condition = sys.version_info[0] < 3 + skip_message = "Subtests are not supported in Python 2" + + @unittest.skipIf(skip_condition, skip_message) + def test_func(self): + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + + cases = [ + ("_91())", ")", func("_91", a_args([]))), + ("_a()", "", func("_a", a_args([]))), + ("___()", "", func("___", a_args([]))), + ("abc()", "", func("abc", a_args([]))), + ( + "AF112(%s, LOWER(%s, %s), rand(%s, %s, TAN(%s, %s)))", + "", + func( + "AF112", + a_args( + [ + pyfmt_str, + func("LOWER", a_args([pyfmt_str, pyfmt_str])), + func( + "rand", + a_args( + [ + pyfmt_str, + pyfmt_str, + func("TAN", a_args([pyfmt_str, pyfmt_str])), + ] + ), + ), + ] + ), + ), + ), + ] + + for text, want_unconsumed, want_parsed in cases: + with self.subTest(text=text): + got_unconsumed, got_parsed = expect(text, FUNC) + self.assertEqual(got_parsed, want_parsed) + self.assertEqual(got_unconsumed, want_unconsumed) + + @unittest.skipIf(skip_condition, skip_message) + def test_func_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import expect + + cases = [ + ("", "FUNC: `` does not begin with `a-zA-z` nor a `_`"), + ("91", "FUNC: `91` does not begin with `a-zA-z` nor a `_`"), + ("_91", "supposed to begin with `\\(`"), + ("_91(", "supposed to end with `\\)`"), + ("_.()", "supposed to begin with `\\(`"), + ("_a.b()", "supposed to begin with `\\(`"), + ] + + for text, wantException in cases: + with self.subTest(text=text): + self.assertRaisesRegex( + ProgrammingError, wantException, lambda: expect(text, FUNC) + ) + + def test_func_eq(self): + from google.cloud.spanner_dbapi.parser import func + + func1 = func("func1", None) + func2 = func("func2", None) + self.assertFalse(func1 == object) + self.assertFalse(func1 == func2) + func2.name = func1.name + func1.args = 0 + func2.args = "0" + self.assertFalse(func1 == func2) + func1.args = [0] + func2.args = [0, 0] + self.assertFalse(func1 == func2) + func2.args = func1.args + self.assertTrue(func1 == func2) + + @unittest.skipIf(skip_condition, skip_message) + def test_a_args(self): + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + + cases = [ + ("()", "", a_args([])), + ("(%s)", "", a_args([pyfmt_str])), + ("(%s,)", "", a_args([pyfmt_str])), + ("(%s),", ",", a_args([pyfmt_str])), + ( + "(%s,%s, f1(%s, %s))", + "", + a_args( + [pyfmt_str, pyfmt_str, func("f1", a_args([pyfmt_str, pyfmt_str]))] + ), + ), + ] + + for text, want_unconsumed, want_parsed in cases: + with self.subTest(text=text): + got_unconsumed, got_parsed = expect(text, ARGS) + self.assertEqual(got_parsed, want_parsed) + self.assertEqual(got_unconsumed, want_unconsumed) + + @unittest.skipIf(skip_condition, skip_message) + def test_a_args_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import expect + + cases = [ + ("", "ARGS: supposed to begin with `\\(`"), + ("(", "ARGS: supposed to end with `\\)`"), + (")", "ARGS: supposed to begin with `\\(`"), + ("(%s,%s, f1(%s, %s), %s", "ARGS: supposed to end with `\\)`"), + ] + + for text, wantException in cases: + with self.subTest(text=text): + self.assertRaisesRegex( + ProgrammingError, wantException, lambda: expect(text, ARGS) + ) + + def test_a_args_has_expr(self): + from google.cloud.spanner_dbapi.parser import a_args + + self.assertFalse(a_args([]).has_expr()) + self.assertTrue(a_args([[0]]).has_expr()) + + def test_a_args_eq(self): + from google.cloud.spanner_dbapi.parser import a_args + + a1 = a_args([0]) + self.assertFalse(a1 == object()) + a2 = a_args([0, 0]) + self.assertFalse(a1 == a2) + a1.argv = [0, 1] + self.assertFalse(a1 == a2) + a2.argv = [0, 1] + self.assertTrue(a1 == a2) + + def test_a_args_homogeneous(self): + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import terminal + + a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) + self.assertTrue(a_obj.homogenous()) + + a_obj = a_args([a_args([[object()]]) for _ in range(10)]) + self.assertFalse(a_obj.homogenous()) + + def test_a_args__is_equal_length(self): + from google.cloud.spanner_dbapi.parser import a_args + + a_obj = a_args([]) + self.assertTrue(a_obj._is_equal_length()) + + @unittest.skipIf(skip_condition, "Python 2 has an outdated iterator definition") + @unittest.skipIf( + skip_condition, "Python 2 does not support 0-argument super() calls" + ) + def test_values(self): + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import terminal + from google.cloud.spanner_dbapi.parser import values + + a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) + self.assertEqual(str(values(a_obj)), "VALUES%s" % str(a_obj)) + + def test_expect(self): + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import EXPR + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import pyfmt_str + from google.cloud.spanner_dbapi import exceptions + + with self.assertRaises(exceptions.ProgrammingError): + expect(word="", token=ARGS) + with self.assertRaises(exceptions.ProgrammingError): + expect(word="ABC", token=ARGS) + with self.assertRaises(exceptions.ProgrammingError): + expect(word="(", token=ARGS) + + expected = "", pyfmt_str + self.assertEqual(expect("%s", EXPR), expected) + + expected = expect("function()", FUNC) + self.assertEqual(expect("function()", EXPR), expected) + + with self.assertRaises(exceptions.ProgrammingError): + expect(word="", token="ABC") + + @unittest.skipIf(skip_condition, skip_message) + def test_expect_values(self): + from google.cloud.spanner_dbapi.parser import VALUES + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + from google.cloud.spanner_dbapi.parser import values + + cases = [ + ("VALUES ()", "", values([a_args([])])), + ("VALUES", "", values([])), + ("VALUES(%s)", "", values([a_args([pyfmt_str])])), + (" VALUES (%s) ", "", values([a_args([pyfmt_str])])), + ("VALUES(%s, %s)", "", values([a_args([pyfmt_str, pyfmt_str])])), + ( + "VALUES(%s, %s, LOWER(%s, %s))", + "", + values( + [ + a_args( + [ + pyfmt_str, + pyfmt_str, + func("LOWER", a_args([pyfmt_str, pyfmt_str])), + ] + ) + ] + ), + ), + ( + "VALUES (UPPER(%s)), (%s)", + "", + values( + [a_args([func("UPPER", a_args([pyfmt_str]))]), a_args([pyfmt_str])] + ), + ), + ] + + for text, want_unconsumed, want_parsed in cases: + with self.subTest(text=text): + got_unconsumed, got_parsed = expect(text, VALUES) + self.assertEqual(got_parsed, want_parsed) + self.assertEqual(got_unconsumed, want_unconsumed) + + @unittest.skipIf(skip_condition, skip_message) + def test_expect_values_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import VALUES + from google.cloud.spanner_dbapi.parser import expect + + cases = [ + ("", "VALUES: `` does not start with VALUES"), + ( + "VALUES(%s, %s, (%s, %s))", + "FUNC: `\\(%s, %s\\)\\)` does not begin with `a-zA-z` nor a `_`", + ), + ("VALUES(%s),,", "ARGS: supposed to begin with `\\(` in `,`"), + ] + + for text, wantException in cases: + with self.subTest(text=text): + self.assertRaisesRegex( + ProgrammingError, wantException, lambda: expect(text, VALUES) + ) + + def test_as_values(self): + from google.cloud.spanner_dbapi.parser import as_values + + values = (1, 2) + with mock.patch( + "google.cloud.spanner_dbapi.parser.parse_values", return_value=values + ): + self.assertEqual(as_values(None), values[1]) diff --git a/tests/unit/spanner_dbapi/test_types.py b/tests/unit/spanner_dbapi/test_types.py new file mode 100644 index 0000000000..8c9dbe6c2b --- /dev/null +++ b/tests/unit/spanner_dbapi/test_types.py @@ -0,0 +1,71 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from time import timezone + + +class TestTypes(unittest.TestCase): + + TICKS = 1572822862.9782631 + timezone # Sun 03 Nov 2019 23:14:22 UTC + + def test__date_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + + actual = types._date_from_ticks(self.TICKS) + expected = datetime.date(2019, 11, 3) + + self.assertEqual(actual, expected) + + def test__time_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + + actual = types._time_from_ticks(self.TICKS) + expected = datetime.time(23, 14, 22) + + self.assertEqual(actual, expected) + + def test__timestamp_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + + actual = types._timestamp_from_ticks(self.TICKS) + expected = datetime.datetime(2019, 11, 3, 23, 14, 22) + + self.assertEqual(actual, expected) + + def test_type_equal(self): + from google.cloud.spanner_dbapi import types + + self.assertEqual(types.BINARY, "TYPE_CODE_UNSPECIFIED") + self.assertEqual(types.BINARY, "BYTES") + self.assertEqual(types.BINARY, "ARRAY") + self.assertEqual(types.BINARY, "STRUCT") + self.assertNotEqual(types.BINARY, "STRING") + + self.assertEqual(types.NUMBER, "BOOL") + self.assertEqual(types.NUMBER, "INT64") + self.assertEqual(types.NUMBER, "FLOAT64") + self.assertEqual(types.NUMBER, "NUMERIC") + self.assertNotEqual(types.NUMBER, "STRING") + + self.assertEqual(types.DATETIME, "TIMESTAMP") + self.assertEqual(types.DATETIME, "DATE") + self.assertNotEqual(types.DATETIME, "STRING") diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py new file mode 100644 index 0000000000..4fe94f30a7 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -0,0 +1,87 @@ +# Copyright 2020 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + + +class TestUtils(unittest.TestCase): + + skip_condition = sys.version_info[0] < 3 + skip_message = "Subtests are not supported in Python 2" + + @unittest.skipIf(skip_condition, skip_message) + def test_PeekIterator(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + + cases = [ + ("list", [1, 2, 3, 4, 6, 7], [1, 2, 3, 4, 6, 7]), + ("iter_from_list", iter([1, 2, 3, 4, 6, 7]), [1, 2, 3, 4, 6, 7]), + ("tuple", ("a", 12, 0xFF), ["a", 12, 0xFF]), + ("iter_from_tuple", iter(("a", 12, 0xFF)), ["a", 12, 0xFF]), + ("no_args", (), []), + ] + + for name, data_in, expected in cases: + with self.subTest(name=name): + pitr = PeekIterator(data_in) + actual = list(pitr) + self.assertEqual(actual, expected) + + @unittest.skipIf(skip_condition, "Python 2 has an outdated iterator definition") + def test_peekIterator_list_rows_converted_to_tuples(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + + # Cloud Spanner returns results in lists e.g. [result]. + # PeekIterator is used by BaseCursor in its fetch* methods. + # This test ensures that anything passed into PeekIterator + # will be returned as a tuple. + pit = PeekIterator([["a"], ["b"], ["c"], ["d"], ["e"]]) + got = list(pit) + want = [("a",), ("b",), ("c",), ("d",), ("e",)] + self.assertEqual(got, want, "Rows of type list must be returned as tuples") + + seventeen = PeekIterator([[17]]) + self.assertEqual(list(seventeen), [(17,)]) + + pit = PeekIterator([["%", "%d"]]) + self.assertEqual(next(pit), ("%", "%d")) + + pit = PeekIterator([("Clark", "Kent")]) + self.assertEqual(next(pit), ("Clark", "Kent")) + + @unittest.skipIf(skip_condition, "Python 2 has an outdated iterator definition") + def test_peekIterator_nonlist_rows_unconverted(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + + pi = PeekIterator(["a", "b", "c", "d", "e"]) + got = list(pi) + want = ["a", "b", "c", "d", "e"] + self.assertEqual(got, want, "Values should be returned unchanged") + + @unittest.skipIf(skip_condition, skip_message) + def test_backtick_unicode(self): + from google.cloud.spanner_dbapi.utils import backtick_unicode + + cases = [ + ("SELECT (1) as foo WHERE 1=1", "SELECT (1) as foo WHERE 1=1"), + ("SELECT (1) as föö", "SELECT (1) as `föö`"), + ("SELECT (1) as `föö`", "SELECT (1) as `föö`"), + ("SELECT (1) as `föö` `umläut", "SELECT (1) as `föö` `umläut"), + ("SELECT (1) as `föö", "SELECT (1) as `föö"), + ] + for sql, want in cases: + with self.subTest(sql=sql): + got = backtick_unicode(sql) + self.assertEqual(got, want)