From 818324708e9ca46fbd80c47745bdf38e8a1a069c Mon Sep 17 00:00:00 2001 From: Ilya Gurov Date: Thu, 26 Nov 2020 01:28:13 +0300 Subject: [PATCH] refactor!: erase dbapi directory and all the related tests (#554) BREAKING CHANGE: DBAPI code was moved into python-spanner in https://github.com/googleapis/python-spanner/pull/160. This change removes it from this repo and bumps the dependency on python-spanner to 2.0.0, the first released version to include DBAPI. --- .../system-tests-against-emulator.yaml | 32 - django_spanner/compiler.py | 3 +- django_spanner/introspection.py | 18 +- google/__init__.py | 0 google/cloud/__init__.py | 0 google/cloud/spanner_dbapi/__init__.py | 85 --- google/cloud/spanner_dbapi/_helpers.py | 159 ----- google/cloud/spanner_dbapi/connection.py | 272 --------- google/cloud/spanner_dbapi/cursor.py | 329 ----------- google/cloud/spanner_dbapi/exceptions.py | 94 --- google/cloud/spanner_dbapi/parse_utils.py | 545 ------------------ google/cloud/spanner_dbapi/parser.py | 246 -------- google/cloud/spanner_dbapi/types.py | 98 ---- google/cloud/spanner_dbapi/utils.py | 81 --- google/cloud/spanner_dbapi/version.py | 11 - noxfile.py | 67 +-- tests/spanner_dbapi/test_connect.py | 135 ----- tests/spanner_dbapi/test_connection.py | 79 --- tests/system/test_system.py | 295 ---------- tests/unit/spanner_dbapi/__init__.py | 5 - tests/unit/spanner_dbapi/test__helpers.py | 130 ----- tests/unit/spanner_dbapi/test_connection.py | 337 ----------- tests/unit/spanner_dbapi/test_cursor.py | 460 --------------- tests/unit/spanner_dbapi/test_globals.py | 20 - tests/unit/spanner_dbapi/test_parse_utils.py | 444 -------------- tests/unit/spanner_dbapi/test_parser.py | 288 --------- tests/unit/spanner_dbapi/test_types.py | 63 -- tests/unit/spanner_dbapi/test_utils.py | 72 --- 28 files changed, 13 insertions(+), 4355 deletions(-) delete mode 100644 .github/workflows/system-tests-against-emulator.yaml delete mode 100644 google/__init__.py delete mode 100644 google/cloud/__init__.py delete mode 100644 google/cloud/spanner_dbapi/__init__.py delete mode 100644 google/cloud/spanner_dbapi/_helpers.py delete mode 100644 google/cloud/spanner_dbapi/connection.py delete mode 100644 google/cloud/spanner_dbapi/cursor.py delete mode 100644 google/cloud/spanner_dbapi/exceptions.py delete mode 100644 google/cloud/spanner_dbapi/parse_utils.py delete mode 100644 google/cloud/spanner_dbapi/parser.py delete mode 100644 google/cloud/spanner_dbapi/types.py delete mode 100644 google/cloud/spanner_dbapi/utils.py delete mode 100644 google/cloud/spanner_dbapi/version.py delete mode 100644 tests/spanner_dbapi/test_connect.py delete mode 100644 tests/spanner_dbapi/test_connection.py delete mode 100644 tests/system/test_system.py delete mode 100644 tests/unit/spanner_dbapi/__init__.py delete mode 100644 tests/unit/spanner_dbapi/test__helpers.py delete mode 100644 tests/unit/spanner_dbapi/test_connection.py delete mode 100644 tests/unit/spanner_dbapi/test_cursor.py delete mode 100644 tests/unit/spanner_dbapi/test_globals.py delete mode 100644 tests/unit/spanner_dbapi/test_parse_utils.py delete mode 100644 tests/unit/spanner_dbapi/test_parser.py delete mode 100644 tests/unit/spanner_dbapi/test_types.py delete mode 100644 tests/unit/spanner_dbapi/test_utils.py diff --git a/.github/workflows/system-tests-against-emulator.yaml b/.github/workflows/system-tests-against-emulator.yaml deleted file mode 100644 index 1ae7af1273..0000000000 --- a/.github/workflows/system-tests-against-emulator.yaml +++ /dev/null @@ -1,32 +0,0 @@ -on: - push: - branches: - - master - pull_request: -name: Run Spanner integration tests against emulator -jobs: - system-tests: - runs-on: ubuntu-latest - - services: - emulator: - image: gcr.io/cloud-spanner-emulator/emulator:latest - ports: - - 9010:9010 - - 9020:9020 - - steps: - - name: Checkout code - uses: actions/checkout@v2 - - name: Setup Python - uses: actions/setup-python@v2 - with: - python-version: 3.8 - - name: Install nox - run: python -m pip install nox - - name: Run system tests - run: nox -s system-3.8 - env: - SPANNER_EMULATOR_HOST: localhost:9010 - GOOGLE_CLOUD_PROJECT: emulator-test-project - GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE: true diff --git a/django_spanner/compiler.py b/django_spanner/compiler.py index 106686d445..61d980a6f3 100644 --- a/django_spanner/compiler.py +++ b/django_spanner/compiler.py @@ -12,7 +12,8 @@ SQLInsertCompiler as BaseSQLInsertCompiler, SQLUpdateCompiler as BaseSQLUpdateCompiler, ) -from django.db.utils import DatabaseError, add_dummy_where +from django.db.utils import DatabaseError +from django_spanner.utils import add_dummy_where class SQLCompiler(BaseSQLCompiler): diff --git a/django_spanner/introspection.py b/django_spanner/introspection.py index ab9d29aa3f..2928c84798 100644 --- a/django_spanner/introspection.py +++ b/django_spanner/introspection.py @@ -10,22 +10,22 @@ TableInfo, ) from django.db.models import Index -from google.cloud.spanner_v1.proto import type_pb2 +from google.cloud.spanner_v1 import TypeCode class DatabaseIntrospection(BaseDatabaseIntrospection): data_types_reverse = { - type_pb2.BOOL: "BooleanField", - type_pb2.BYTES: "BinaryField", - type_pb2.DATE: "DateField", - type_pb2.FLOAT64: "FloatField", - type_pb2.INT64: "IntegerField", - type_pb2.STRING: "CharField", - type_pb2.TIMESTAMP: "DateTimeField", + TypeCode.BOOL: "BooleanField", + TypeCode.BYTES: "BinaryField", + TypeCode.DATE: "DateField", + TypeCode.FLOAT64: "FloatField", + TypeCode.INT64: "IntegerField", + TypeCode.STRING: "CharField", + TypeCode.TIMESTAMP: "DateTimeField", } def get_field_type(self, data_type, description): - if data_type == type_pb2.STRING and description.internal_size == "MAX": + if data_type == TypeCode.STRING and description.internal_size == "MAX": return "TextField" return super().get_field_type(data_type, description) diff --git a/google/__init__.py b/google/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/google/cloud/__init__.py b/google/cloud/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/google/cloud/spanner_dbapi/__init__.py b/google/cloud/spanner_dbapi/__init__.py deleted file mode 100644 index 7695c0058f..0000000000 --- a/google/cloud/spanner_dbapi/__init__.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""Connection-based DB API for Cloud Spanner.""" - -from google.cloud.spanner_dbapi.connection import Connection -from google.cloud.spanner_dbapi.connection import connect - -from google.cloud.spanner_dbapi.cursor import Cursor - -from google.cloud.spanner_dbapi.exceptions import DatabaseError -from google.cloud.spanner_dbapi.exceptions import DataError -from google.cloud.spanner_dbapi.exceptions import Error -from google.cloud.spanner_dbapi.exceptions import IntegrityError -from google.cloud.spanner_dbapi.exceptions import InterfaceError -from google.cloud.spanner_dbapi.exceptions import InternalError -from google.cloud.spanner_dbapi.exceptions import NotSupportedError -from google.cloud.spanner_dbapi.exceptions import OperationalError -from google.cloud.spanner_dbapi.exceptions import ProgrammingError -from google.cloud.spanner_dbapi.exceptions import Warning - -from google.cloud.spanner_dbapi.parse_utils import get_param_types - -from google.cloud.spanner_dbapi.types import BINARY -from google.cloud.spanner_dbapi.types import DATETIME -from google.cloud.spanner_dbapi.types import NUMBER -from google.cloud.spanner_dbapi.types import ROWID -from google.cloud.spanner_dbapi.types import STRING -from google.cloud.spanner_dbapi.types import Binary -from google.cloud.spanner_dbapi.types import Date -from google.cloud.spanner_dbapi.types import DateFromTicks -from google.cloud.spanner_dbapi.types import Time -from google.cloud.spanner_dbapi.types import TimeFromTicks -from google.cloud.spanner_dbapi.types import Timestamp -from google.cloud.spanner_dbapi.types import TimestampStr -from google.cloud.spanner_dbapi.types import TimestampFromTicks - -from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT - -apilevel = "2.0" # supports DP-API 2.0 level. -paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s. - -# Threads may share the module, but not connections. This is a paranoid threadsafety -# level, but it is necessary for starters to use when debugging failures. -# Eventually once transactions are working properly, we'll update the -# threadsafety level. -threadsafety = 1 - - -__all__ = [ - "Connection", - "connect", - "Cursor", - "DatabaseError", - "DataError", - "Error", - "IntegrityError", - "InterfaceError", - "InternalError", - "NotSupportedError", - "OperationalError", - "ProgrammingError", - "Warning", - "DEFAULT_USER_AGENT", - "apilevel", - "paramstyle", - "threadsafety", - "get_param_types", - "Binary", - "Date", - "DateFromTicks", - "Time", - "TimeFromTicks", - "Timestamp", - "TimestampFromTicks", - "BINARY", - "STRING", - "NUMBER", - "DATETIME", - "ROWID", - "TimestampStr", -] diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py deleted file mode 100644 index f581fdebbd..0000000000 --- a/google/cloud/spanner_dbapi/_helpers.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -from google.cloud.spanner_dbapi.parse_utils import get_param_types -from google.cloud.spanner_dbapi.parse_utils import parse_insert -from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner -from google.cloud.spanner_v1 import param_types - - -SQL_LIST_TABLES = """ - SELECT - t.table_name - FROM - information_schema.tables AS t - WHERE - t.table_catalog = '' and t.table_schema = '' - """ - -SQL_GET_TABLE_COLUMN_SCHEMA = """SELECT - COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE - FROM - INFORMATION_SCHEMA.COLUMNS - WHERE - TABLE_SCHEMA = '' - AND - TABLE_NAME = @table_name - """ - -# This table maps spanner_types to Spanner's data type sizes as per -# https://cloud.google.com/spanner/docs/data-types#allowable-types -# It is used to map `display_size` to a known type for Cursor.description -# after a row fetch. -# Since ResultMetadata -# https://cloud.google.com/spanner/docs/reference/rest/v1/ResultSetMetadata -# does not send back the actual size, we have to lookup the respective size. -# Some fields' sizes are dependent upon the dynamic data hence aren't sent back -# by Cloud Spanner. -code_to_display_size = { - param_types.BOOL.code: 1, - param_types.DATE.code: 4, - param_types.FLOAT64.code: 8, - param_types.INT64.code: 8, - param_types.TIMESTAMP.code: 12, -} - - -def _execute_insert_heterogenous(transaction, sql_params_list): - for sql, params in sql_params_list: - sql, params = sql_pyformat_args_to_spanner(sql, params) - param_types = get_param_types(params) - res = transaction.execute_sql( - sql, params=params, param_types=param_types - ) - # TODO: File a bug with Cloud Spanner and the Python client maintainers - # about a lost commit when res isn't read from. - _ = list(res) - - -def _execute_insert_homogenous(transaction, parts): - # Perform an insert in one shot. - table = parts.get("table") - columns = parts.get("columns") - values = parts.get("values") - return transaction.insert(table, columns, values) - - -def handle_insert(connection, sql, params): - parts = parse_insert(sql, params) - - # The split between the two styles exists because: - # in the common case of multiple values being passed - # with simple pyformat arguments, - # SQL: INSERT INTO T (f1, f2) VALUES (%s, %s, %s) - # Params: [(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,)] - # we can take advantage of a single RPC with: - # transaction.insert(table, columns, values) - # instead of invoking: - # with transaction: - # for sql, params in sql_params_list: - # transaction.execute_sql(sql, params, param_types) - # which invokes more RPCs and is more costly. - - if parts.get("homogenous"): - # The common case of multiple values being passed in - # non-complex pyformat args and need to be uploaded in one RPC. - return connection.database.run_in_transaction( - _execute_insert_homogenous, parts - ) - else: - # All the other cases that are esoteric and need - # transaction.execute_sql - sql_params_list = parts.get("sql_params_list") - return connection.database.run_in_transaction( - _execute_insert_heterogenous, sql_params_list - ) - - -class ColumnInfo: - """Row column description object.""" - - def __init__( - self, - name, - type_code, - display_size=None, - internal_size=None, - precision=None, - scale=None, - null_ok=False, - ): - self.name = name - self.type_code = type_code - self.display_size = display_size - self.internal_size = internal_size - self.precision = precision - self.scale = scale - self.null_ok = null_ok - - self.fields = ( - self.name, - self.type_code, - self.display_size, - self.internal_size, - self.precision, - self.scale, - self.null_ok, - ) - - def __repr__(self): - return self.__str__() - - def __getitem__(self, index): - return self.fields[index] - - def __str__(self): - str_repr = ", ".join( - filter( - lambda part: part is not None, - [ - "name='%s'" % self.name, - "type_code=%d" % self.type_code, - "display_size=%d" % self.display_size - if self.display_size - else None, - "internal_size=%d" % self.internal_size - if self.internal_size - else None, - "precision='%s'" % self.precision - if self.precision - else None, - "scale='%s'" % self.scale if self.scale else None, - "null_ok='%s'" % self.null_ok if self.null_ok else None, - ], - ) - ) - return "ColumnInfo(%s)" % str_repr diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py deleted file mode 100644 index beb05a3173..0000000000 --- a/google/cloud/spanner_dbapi/connection.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""DB-API Connection for the Google Cloud Spanner.""" - -import warnings - -from google.api_core.gapic_v1.client_info import ClientInfo -from google.cloud import spanner_v1 as spanner - -from google.cloud.spanner_dbapi.cursor import Cursor -from google.cloud.spanner_dbapi.exceptions import InterfaceError -from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT -from google.cloud.spanner_dbapi.version import PY_VERSION - - -AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode" - - -class Connection: - """Representation of a DB-API connection to a Cloud Spanner database. - - You most likely don't need to instantiate `Connection` objects - directly, use the `connect` module function instead. - - :type instance: :class:`~google.cloud.spanner_v1.instance.Instance` - :param instance: Cloud Spanner instance to connect to. - - :type database: :class:`~google.cloud.spanner_v1.database.Database` - :param database: The database to which the connection is linked. - """ - - def __init__(self, instance, database): - self._instance = instance - self._database = database - self._ddl_statements = [] - - self._transaction = None - self._session = None - - self.is_closed = False - self._autocommit = False - # indicator to know if the session pool used by - # this connection should be cleared on the - # connection close - self._own_pool = True - - @property - def autocommit(self): - """Autocommit mode flag for this connection. - - :rtype: bool - :returns: Autocommit mode flag value. - """ - return self._autocommit - - @autocommit.setter - def autocommit(self, value): - """Change this connection autocommit mode. Setting this value to True - while a transaction is active will commit the current transaction. - - :type value: bool - :param value: New autocommit mode state. - """ - if value and not self._autocommit: - self.commit() - - self._autocommit = value - - @property - def database(self): - """Database to which this connection relates. - - :rtype: :class:`~google.cloud.spanner_v1.database.Database` - :returns: The related database object. - """ - return self._database - - @property - def instance(self): - """Instance to which this connection relates. - - :rtype: :class:`~google.cloud.spanner_v1.instance.Instance` - :returns: The related instance object. - """ - return self._instance - - def _session_checkout(self): - """Get a Cloud Spanner session from the pool. - - If there is already a session associated with - this connection, it'll be used instead. - - :rtype: :class:`google.cloud.spanner_v1.session.Session` - :returns: Cloud Spanner session object ready to use. - """ - if not self._session: - self._session = self.database._pool.get() - - return self._session - - def _release_session(self): - """Release the currently used Spanner session. - - The session will be returned into the sessions pool. - """ - self.database._pool.put(self._session) - self._session = None - - def transaction_checkout(self): - """Get a Cloud Spanner transaction. - - Begin a new transaction, if there is no transaction in - this connection yet. Return the begun one otherwise. - - The method is non operational in autocommit mode. - - :rtype: :class:`google.cloud.spanner_v1.transaction.Transaction` - :returns: A Cloud Spanner transaction object, ready to use. - """ - if not self.autocommit: - if ( - not self._transaction - or self._transaction.committed - or self._transaction.rolled_back - ): - self._transaction = self._session_checkout().transaction() - self._transaction.begin() - - return self._transaction - - def _raise_if_closed(self): - """Helper to check the connection state before running a query. - Raises an exception if this connection is closed. - - :raises: :class:`InterfaceError`: if this connection is closed. - """ - if self.is_closed: - raise InterfaceError("connection is already closed") - - def close(self): - """Closes this connection. - - The connection will be unusable from this point forward. If the - connection has an active transaction, it will be rolled back. - """ - if ( - self._transaction - and not self._transaction.committed - and not self._transaction.rolled_back - ): - self._transaction.rollback() - - if self._own_pool: - self.database._pool.clear() - - self.is_closed = True - - def commit(self): - """Commits any pending transaction to the database. - - This method is non-operational in autocommit mode. - """ - if self._autocommit: - warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) - elif self._transaction: - self._transaction.commit() - self._release_session() - - def rollback(self): - """Rolls back any pending transaction. - - This is a no-op if there is no active transaction or if the connection - is in autocommit mode. - """ - if self._autocommit: - warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) - elif self._transaction: - self._transaction.rollback() - self._release_session() - - def cursor(self): - """Factory to create a DB-API Cursor.""" - self._raise_if_closed() - - return Cursor(self) - - def run_prior_DDL_statements(self): - self._raise_if_closed() - - if self._ddl_statements: - ddl_statements = self._ddl_statements - self._ddl_statements = [] - - return self.database.update_ddl(ddl_statements).result() - - def __enter__(self): - return self - - def __exit__(self, etype, value, traceback): - self.commit() - self.close() - - -def connect( - instance_id, - database_id, - project=None, - credentials=None, - pool=None, - user_agent=None, -): - """Creates a connection to a Google Cloud Spanner database. - - :type instance_id: str - :param instance_id: The ID of the instance to connect to. - - :type database_id: str - :param database_id: The ID of the database to connect to. - - :type project: str - :param project: (Optional) The ID of the project which owns the - instances, tables and data. If not provided, will - attempt to determine from the environment. - - :type credentials: :class:`~google.auth.credentials.Credentials` - :param credentials: (Optional) The authorization credentials to attach to - requests. These credentials identify this application - to the service. If none are specified, the client will - attempt to ascertain the credentials from the - environment. - - :type pool: Concrete subclass of - :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`. - :param pool: (Optional). Session pool to be used by database. - - :type user_agent: str - :param user_agent: (Optional) User agent to be used with this connection's - requests. - - :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` - :returns: Connection object associated with the given Google Cloud Spanner - resource. - - :raises: :class:`ValueError` in case of given instance/database - doesn't exist. - """ - - client_info = ClientInfo( - user_agent=user_agent or DEFAULT_USER_AGENT, python_version=PY_VERSION, - ) - - client = spanner.Client( - project=project, credentials=credentials, client_info=client_info, - ) - - instance = client.instance(instance_id) - if not instance.exists(): - raise ValueError("instance '%s' does not exist." % instance_id) - - database = instance.database(database_id, pool=pool) - if not database.exists(): - raise ValueError("database '%s' does not exist." % database_id) - - conn = Connection(instance, database) - if pool is not None: - conn._own_pool = False - - return conn diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py deleted file mode 100644 index e41f0f381a..0000000000 --- a/google/cloud/spanner_dbapi/cursor.py +++ /dev/null @@ -1,329 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""Database cursor for Google Cloud Spanner DB-API.""" - -from google.api_core.exceptions import AlreadyExists -from google.api_core.exceptions import FailedPrecondition -from google.api_core.exceptions import InternalServerError -from google.api_core.exceptions import InvalidArgument - -from collections import namedtuple - -from google.cloud import spanner_v1 as spanner - -from google.cloud.spanner_dbapi.exceptions import IntegrityError -from google.cloud.spanner_dbapi.exceptions import InterfaceError -from google.cloud.spanner_dbapi.exceptions import OperationalError -from google.cloud.spanner_dbapi.exceptions import ProgrammingError - -from google.cloud.spanner_dbapi import _helpers -from google.cloud.spanner_dbapi._helpers import ColumnInfo -from google.cloud.spanner_dbapi._helpers import code_to_display_size - -from google.cloud.spanner_dbapi import parse_utils -from google.cloud.spanner_dbapi.parse_utils import get_param_types -from google.cloud.spanner_dbapi.utils import PeekIterator - -_UNSET_COUNT = -1 - -ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) - - -class Cursor(object): - """Database cursor to manage the context of a fetch operation. - - :type connection: :class:`~google.cloud.spanner_dbapi.connection.Connection` - :param connection: A DB-API connection to Google Cloud Spanner. - """ - - def __init__(self, connection): - self._itr = None - self._result_set = None - self._row_count = _UNSET_COUNT - self.connection = connection - self._is_closed = False - - # the number of rows to fetch at a time with fetchmany() - self.arraysize = 1 - - @property - def is_closed(self): - """The cursor close indicator. - - :rtype: bool - :returns: True if the cursor or the parent connection is closed, - otherwise False. - """ - return self._is_closed or self.connection.is_closed - - @property - def description(self): - """Read-only attribute containing a sequence of the following items: - - - ``name`` - - ``type_code`` - - ``display_size`` - - ``internal_size`` - - ``precision`` - - ``scale`` - - ``null_ok`` - """ - if not (self._result_set and self._result_set.metadata): - return None - - row_type = self._result_set.metadata.row_type - columns = [] - - for field in row_type.fields: - column_info = ColumnInfo( - name=field.name, - type_code=field.type.code, - # Size of the SQL type of the column. - display_size=code_to_display_size.get(field.type.code), - # Client perceived size of the column. - internal_size=field.ByteSize(), - ) - columns.append(column_info) - - return tuple(columns) - - @property - def rowcount(self): - """The number of rows produced by the last `.execute()`.""" - return self._row_count - - def _raise_if_closed(self): - """Raise an exception if this cursor is closed. - - Helper to check this cursor's state before running a - SQL/DDL/DML query. If the parent connection is - already closed it also raises an error. - - :raises: :class:`InterfaceError` if this cursor is closed. - """ - if self.is_closed: - raise InterfaceError("Cursor and/or connection is already closed.") - - def callproc(self, procname, args=None): - """A no-op, raising an error if the cursor or connection is closed.""" - self._raise_if_closed() - - def close(self): - """Closes this Cursor, making it unusable from this point forward.""" - self._is_closed = True - - def _do_execute_update(self, transaction, sql, params, param_types=None): - parse_utils.ensure_where_clause(sql) - sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) - - result = transaction.execute_update( - sql, params=params, param_types=get_param_types(params) - ) - self._itr = None - if type(result) == int: - self._row_count = result - - return result - - def execute(self, sql, args=None): - """Prepares and executes a Spanner database operation. - - :type sql: str - :param sql: A SQL query statement. - - :type args: list - :param args: Additional parameters to supplement the SQL query. - """ - if not self.connection: - raise ProgrammingError("Cursor is not connected to the database") - - self._raise_if_closed() - - self._result_set = None - - # Classify whether this is a read-only SQL statement. - try: - classification = parse_utils.classify_stmt(sql) - if classification == parse_utils.STMT_DDL: - self.connection._ddl_statements.append(sql) - return - - # For every other operation, we've got to ensure that - # any prior DDL statements were run. - # self._run_prior_DDL_statements() - self.connection.run_prior_DDL_statements() - - if not self.connection.autocommit: - transaction = self.connection.transaction_checkout() - - sql, params = parse_utils.sql_pyformat_args_to_spanner( - sql, args - ) - - self._result_set = transaction.execute_sql( - sql, params, param_types=get_param_types(params) - ) - self._itr = PeekIterator(self._result_set) - return - - if classification == parse_utils.STMT_NON_UPDATING: - self._handle_DQL(sql, args or None) - elif classification == parse_utils.STMT_INSERT: - _helpers.handle_insert(self.connection, sql, args or None) - else: - self.connection.database.run_in_transaction( - self._do_execute_update, sql, args or None - ) - except (AlreadyExists, FailedPrecondition) as e: - raise IntegrityError(e.details if hasattr(e, "details") else e) - except InvalidArgument as e: - raise ProgrammingError(e.details if hasattr(e, "details") else e) - except InternalServerError as e: - raise OperationalError(e.details if hasattr(e, "details") else e) - - def executemany(self, operation, seq_of_params): - """Execute the given SQL with every parameters set - from the given sequence of parameters. - - :type operation: str - :param operation: SQL code to execute. - - :type seq_of_params: list - :param seq_of_params: Sequence of additional parameters to run - the query with. - """ - self._raise_if_closed() - - for params in seq_of_params: - self.execute(operation, params) - - def fetchone(self): - """Fetch the next row of a query result set, returning a single - sequence, or None when no more data is available.""" - self._raise_if_closed() - - try: - return next(self) - except StopIteration: - return None - - def fetchmany(self, size=None): - """Fetch the next set of rows of a query result, returning a sequence - of sequences. An empty sequence is returned when no more rows are available. - - :type size: int - :param size: (Optional) The maximum number of results to fetch. - - :raises InterfaceError: - if the previous call to .execute*() did not produce any result set - or if no call was issued yet. - """ - self._raise_if_closed() - - if size is None: - size = self.arraysize - - items = [] - for i in range(size): - try: - items.append(tuple(self.__next__())) - except StopIteration: - break - - return items - - def fetchall(self): - """Fetch all (remaining) rows of a query result, returning them as - a sequence of sequences. - """ - self._raise_if_closed() - - return list(self.__iter__()) - - def nextset(self): - """A no-op, raising an error if the cursor or connection is closed.""" - self._raise_if_closed() - - def setinputsizes(self, sizes): - """A no-op, raising an error if the cursor or connection is closed.""" - self._raise_if_closed() - - def setoutputsize(self, size, column=None): - """A no-op, raising an error if the cursor or connection is closed.""" - self._raise_if_closed() - - def _handle_DQL(self, sql, params): - with self.connection.database.snapshot() as snapshot: - # Reference - # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql - sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) - res = snapshot.execute_sql( - sql, params=params, param_types=get_param_types(params) - ) - if type(res) == int: - self._row_count = res - self._itr = None - else: - # Immediately using: - # iter(response) - # here, because this Spanner API doesn't provide - # easy mechanisms to detect when only a single item - # is returned or many, yet mixing results that - # are for .fetchone() with those that would result in - # many items returns a RuntimeError if .fetchone() is - # invoked and vice versa. - self._result_set = res - # Read the first element so that the StreamedResultSet can - # return the metadata after a DQL statement. See issue #155. - self._itr = PeekIterator(self._result_set) - # Unfortunately, Spanner doesn't seem to send back - # information about the number of rows available. - self._row_count = _UNSET_COUNT - - def __enter__(self): - return self - - def __exit__(self, etype, value, traceback): - self.close() - - def __next__(self): - if self._itr is None: - raise ProgrammingError("no results to return") - return next(self._itr) - - def __iter__(self): - if self._itr is None: - raise ProgrammingError("no results to return") - return self._itr - - def list_tables(self): - return self.run_sql_in_snapshot(_helpers.SQL_LIST_TABLES) - - def run_sql_in_snapshot(self, sql, params=None, param_types=None): - # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions - # hence this method exists to circumvent that limit. - self.connection.run_prior_DDL_statements() - - with self.connection.database.snapshot() as snapshot: - res = snapshot.execute_sql( - sql, params=params, param_types=param_types - ) - return list(res) - - def get_table_column_schema(self, table_name): - rows = self.run_sql_in_snapshot( - sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA, - params={"table_name": table_name}, - param_types={"table_name": spanner.param_types.STRING}, - ) - - column_details = {} - for column_name, is_nullable, spanner_type in rows: - column_details[column_name] = ColumnDetails( - null_ok=is_nullable == "YES", spanner_type=spanner_type - ) - return column_details diff --git a/google/cloud/spanner_dbapi/exceptions.py b/google/cloud/spanner_dbapi/exceptions.py deleted file mode 100644 index b21be2c949..0000000000 --- a/google/cloud/spanner_dbapi/exceptions.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""Spanner DB API exceptions.""" - - -class Warning(Exception): - """Important DB API warning.""" - - pass - - -class Error(Exception): - """The base class for all the DB API exceptions. - - Does not include :class:`Warning`. - """ - - pass - - -class InterfaceError(Error): - """ - Error related to the database interface - rather than the database itself. - """ - - pass - - -class DatabaseError(Error): - """Error related to the database.""" - - pass - - -class DataError(DatabaseError): - """ - Error due to problems with the processed data like - division by zero, numeric value out of range, etc. - """ - - pass - - -class OperationalError(DatabaseError): - """ - Error related to the database's operation, e.g. an - unexpected disconnect, the data source name is not - found, a transaction could not be processed, a - memory allocation error, etc. - """ - - pass - - -class IntegrityError(DatabaseError): - """ - Error for cases of relational integrity of the database - is affected, e.g. a foreign key check fails. - """ - - pass - - -class InternalError(DatabaseError): - """ - Internal database error, e.g. the cursor is not valid - anymore, the transaction is out of sync, etc. - """ - - pass - - -class ProgrammingError(DatabaseError): - """ - Programming error, e.g. table not found or already - exists, syntax error in the SQL statement, wrong - number of parameters specified, etc. - """ - - pass - - -class NotSupportedError(DatabaseError): - """ - Error for case of a method or database API not - supported by the database was used. - """ - - pass diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py deleted file mode 100644 index 0e69dbc0ca..0000000000 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ /dev/null @@ -1,545 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"SQL parsing and classification utils." - -import datetime -import decimal -import re -from functools import reduce - -import sqlparse -from google.cloud import spanner_v1 as spanner - -from .exceptions import Error, ProgrammingError -from .parser import parse_values -from .types import DateStr, TimestampStr -from .utils import sanitize_literals_for_upload - -TYPES_MAP = { - bool: spanner.param_types.BOOL, - bytes: spanner.param_types.BYTES, - str: spanner.param_types.STRING, - int: spanner.param_types.INT64, - float: spanner.param_types.FLOAT64, - datetime.datetime: spanner.param_types.TIMESTAMP, - datetime.date: spanner.param_types.DATE, - DateStr: spanner.param_types.DATE, - TimestampStr: spanner.param_types.TIMESTAMP, -} - -SPANNER_RESERVED_KEYWORDS = { - "ALL", - "AND", - "ANY", - "ARRAY", - "AS", - "ASC", - "ASSERT_ROWS_MODIFIED", - "AT", - "BETWEEN", - "BY", - "CASE", - "CAST", - "COLLATE", - "CONTAINS", - "CREATE", - "CROSS", - "CUBE", - "CURRENT", - "DEFAULT", - "DEFINE", - "DESC", - "DISTINCT", - "DROP", - "ELSE", - "END", - "ENUM", - "ESCAPE", - "EXCEPT", - "EXCLUDE", - "EXISTS", - "EXTRACT", - "FALSE", - "FETCH", - "FOLLOWING", - "FOR", - "FROM", - "FULL", - "GROUP", - "GROUPING", - "GROUPS", - "HASH", - "HAVING", - "IF", - "IGNORE", - "IN", - "INNER", - "INTERSECT", - "INTERVAL", - "INTO", - "IS", - "JOIN", - "LATERAL", - "LEFT", - "LIKE", - "LIMIT", - "LOOKUP", - "MERGE", - "NATURAL", - "NEW", - "NO", - "NOT", - "NULL", - "NULLS", - "OF", - "ON", - "OR", - "ORDER", - "OUTER", - "OVER", - "PARTITION", - "PRECEDING", - "PROTO", - "RANGE", - "RECURSIVE", - "RESPECT", - "RIGHT", - "ROLLUP", - "ROWS", - "SELECT", - "SET", - "SOME", - "STRUCT", - "TABLESAMPLE", - "THEN", - "TO", - "TREAT", - "TRUE", - "UNBOUNDED", - "UNION", - "UNNEST", - "USING", - "WHEN", - "WHERE", - "WINDOW", - "WITH", - "WITHIN", -} - -STMT_DDL = "DDL" -STMT_NON_UPDATING = "NON_UPDATING" -STMT_UPDATING = "UPDATING" -STMT_INSERT = "INSERT" - -# Heuristic for identifying statements that don't need to be run as updates. -RE_NON_UPDATE = re.compile(r"^\s*(SELECT)", re.IGNORECASE) - -RE_WITH = re.compile(r"^\s*(WITH)", re.IGNORECASE) - -# DDL statements follow -# https://cloud.google.com/spanner/docs/data-definition-language -RE_DDL = re.compile(r"^\s*(CREATE|ALTER|DROP)", re.IGNORECASE | re.DOTALL) - -RE_IS_INSERT = re.compile(r"^\s*(INSERT)", re.IGNORECASE | re.DOTALL) - -RE_INSERT = re.compile( - # Only match the `INSERT INTO (columns...) - # otherwise the rest of the statement could be a complex - # operation. - r"^\s*INSERT INTO (?P[^\s\(\)]+)\s*\((?P[^\(\)]+)\)", - re.IGNORECASE | re.DOTALL, -) - -RE_VALUES_TILL_END = re.compile(r"VALUES\s*\(.+$", re.IGNORECASE | re.DOTALL) - -RE_VALUES_PYFORMAT = re.compile( - # To match: (%s, %s,....%s) - r"(\(\s*%s[^\(\)]+\))", - re.DOTALL, -) - -RE_PYFORMAT = re.compile(r"(%s|%\([^\(\)]+\)s)+", re.DOTALL) - - -def classify_stmt(query): - """Determine SQL query type. - - :type query: :class:`str` - :param query: SQL query. - - :rtype: :class:`str` - :returns: Query type name. - """ - if RE_DDL.match(query): - return STMT_DDL - - if RE_IS_INSERT.match(query): - return STMT_INSERT - - if RE_NON_UPDATE.match(query) or RE_WITH.match(query): - # As of 13-March-2020, Cloud Spanner only supports WITH for DQL - # statements and doesn't yet support WITH for DML statements. - return STMT_NON_UPDATING - - return STMT_UPDATING - - -def parse_insert(insert_sql, params): - """ - Parse an INSERT statement an generate a list of tuples of the form: - [ - (SQL, params_per_row1), - (SQL, params_per_row2), - (SQL, params_per_row3), - ... - ] - - There are 4 variants of an INSERT statement: - a) INSERT INTO (columns...) VALUES (): no params - b) INSERT INTO
(columns...) SELECT_STMT: no params - c) INSERT INTO
(columns...) VALUES (%s,...): with params - d) INSERT INTO
(columns...) VALUES (%s,.....) with params and expressions - - Thus given each of the forms, it will produce a dictionary describing - how to upload the contents to Cloud Spanner: - Case a) - SQL: INSERT INTO T (f1, f2) VALUES (1, 2) - it produces: - { - 'sql_params_list': [ - ('INSERT INTO T (f1, f2) VALUES (1, 2)', None), - ], - } - - Case b) - SQL: 'INSERT INTO T (s, c) SELECT st, zc FROM cus ORDER BY fn, ln', - it produces: - { - 'sql_params_list': [ - ('INSERT INTO T (s, c) SELECT st, zc FROM cus ORDER BY fn, ln', None), - ] - } - - Case c) - SQL: INSERT INTO T (f1, f2) VALUES (%s, %s), (%s, %s) - Params: ['a', 'b', 'c', 'd'] - it produces: - { - 'homogenous': True, - 'table': 'T', - 'columns': ['f1', 'f2'], - 'values': [('a', 'b',), ('c', 'd',)], - } - - Case d) - SQL: INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s)), (UPPER(%s), %s) - Params: ['a', 'b', 'c', 'd'] - it produces: - { - 'sql_params_list': [ - ('INSERT INTO T (f1, f2) VALUES (%s, LOWER(%s))', ('a', 'b',)) - ('INSERT INTO T (f1, f2) VALUES (UPPER(%s), %s)', ('c', 'd',)) - ], - } - """ # noqa - match = RE_INSERT.search(insert_sql) - - if not match: - raise ProgrammingError( - "Could not parse an INSERT statement from %s" % insert_sql - ) - - after_values_sql = RE_VALUES_TILL_END.findall(insert_sql) - if not after_values_sql: - # Case b) - insert_sql = sanitize_literals_for_upload(insert_sql) - return {"sql_params_list": [(insert_sql, None)]} - - if not params: - # Case a) perhaps? - # Check if any %s exists. - - # pyformat_str_count = after_values_sql.count("%s") - # if pyformat_str_count > 0: - # raise ProgrammingError( - # 'no params yet there are %d "%%s" tokens' % pyformat_str_count - # ) - for item in after_values_sql: - if item.count("%s") > 0: - raise ProgrammingError( - 'no params yet there are %d "%%s" tokens' - % item.count("%s") - ) - - insert_sql = sanitize_literals_for_upload(insert_sql) - # Confirmed case of: - # SQL: INSERT INTO T (a1, a2) VALUES (1, 2) - # Params: None - return {"sql_params_list": [(insert_sql, None)]} - - values_str = after_values_sql[0] - _, values = parse_values(values_str) - - if values.homogenous(): - # Case c) - - columns = [mi.strip(" `") for mi in match.group("columns").split(",")] - sql_params_list = [] - insert_sql_preamble = "INSERT INTO %s (%s) VALUES %s" % ( - match.group("table_name"), - match.group("columns"), - values.argv[0], - ) - values_pyformat = [str(arg) for arg in values.argv] - rows_list = rows_for_insert_or_update(columns, params, values_pyformat) - insert_sql_preamble = sanitize_literals_for_upload(insert_sql_preamble) - for row in rows_list: - sql_params_list.append((insert_sql_preamble, row)) - - return {"sql_params_list": sql_params_list} - - # Case d) - # insert_sql is of the form: - # INSERT INTO T(c1, c2) VALUES (%s, %s), (%s, LOWER(%s)) - - # Sanity check: - # length(all_args) == len(params) - args_len = reduce(lambda a, b: a + b, [len(arg) for arg in values.argv]) - if args_len != len(params): - raise ProgrammingError( - "Invalid length: VALUES(...) len: %d != len(params): %d" - % (args_len, len(params)) - ) - - trim_index = insert_sql.find(values_str) - before_values_sql = insert_sql[:trim_index] - - sql_param_tuples = [] - for token_arg in values.argv: - row_sql = before_values_sql + " VALUES%s" % token_arg - row_sql = sanitize_literals_for_upload(row_sql) - row_params, params = ( - tuple(params[0 : len(token_arg)]), - params[len(token_arg) :], - ) - sql_param_tuples.append((row_sql, row_params)) - - return {"sql_params_list": sql_param_tuples} - - -def rows_for_insert_or_update(columns, params, pyformat_args=None): - """ - Create a tupled list of params to be used as a single value per - value that inserted from a statement such as - SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s)' - Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)] - Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9] - - We'll have to convert both params types into: - Params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] - """ # noqa - - if not pyformat_args: - # This is the case where we have for example: - # SQL: 'INSERT INTO t (f1, f2, f3)' - # Params A: [(1, 2, 3), (4, 5, 6), (7, 8, 9)] - # Params B: [1, 2, 3, 4, 5, 6, 7, 8, 9] - # - # We'll have to convert both params types into: - # [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] - contains_all_list_or_tuples = True - for param in params: - if not (isinstance(param, list) or isinstance(param, tuple)): - contains_all_list_or_tuples = False - break - - if contains_all_list_or_tuples: - # The case with Params A: [(1, 2, 3), (4, 5, 6)] - # Ensure that each param's length == len(columns) - columns_len = len(columns) - for param in params: - if columns_len != len(param): - raise Error( - "\nlen(`%s`)=%d\n!=\ncolum_len(`%s`)=%d" - % (param, len(param), columns, columns_len) - ) - return params - else: - # The case with Params B: [1, 2, 3] - # Insert statements' params are only passed as tuples or lists, - # yet for do_execute_update, we've got to pass in list of list. - # https://googleapis.dev/python/spanner/latest/transaction-api.html\ - # #google.cloud.spanner_v1.transaction.Transaction.insert - n_stride = len(columns) - else: - # This is the case where we have for example: - # SQL: 'INSERT INTO t (f1, f2, f3) VALUES (%s, %s, %s), - # (%s, %s, %s), (%s, %s, %s)' - # Params: [1, 2, 3, 4, 5, 6, 7, 8, 9] - # which should become - # Columns: (f1, f2, f3) - # new_params: [(1, 2, 3,), (4, 5, 6,), (7, 8, 9,)] - - # Sanity check 1: all the pyformat_values should have the exact same - # length. - first, rest = pyformat_args[0], pyformat_args[1:] - n_stride = first.count("%s") - for pyfmt_value in rest: - n = pyfmt_value.count("%s") - if n_stride != n: - raise Error( - "\nlen(`%s`)=%d\n!=\nlen(`%s`)=%d" - % (first, n_stride, pyfmt_value, n) - ) - - # Sanity check 2: len(params) MUST be a multiple of n_stride aka - # len(count of %s). - # so that we can properly group for example: - # Given pyformat args: - # (%s, %s, %s) - # Params: - # [1, 2, 3, 4, 5, 6, 7, 8, 9] - # into - # [(1, 2, 3), (4, 5, 6), (7, 8, 9)] - if (len(params) % n_stride) != 0: - raise ProgrammingError( - "Invalid length: len(params)=%d MUST be a multiple of " - "len(pyformat_args)=%d" % (len(params), n_stride) - ) - - # Now chop up the strides. - strides = [] - for step in range(0, len(params), n_stride): - stride = tuple(params[step : step + n_stride :]) - strides.append(stride) - - return strides - - -def sql_pyformat_args_to_spanner(sql, params): - """ - Transform pyformat set SQL to named arguments for Cloud Spanner. - It will also unescape previously escaped format specifiers - like %%s to %s. - For example: - SQL: 'SELECT * from t where f1=%s, f2=%s, f3=%s' - Params: ('a', 23, '888***') - becomes: - SQL: 'SELECT * from t where f1=@a0, f2=@a1, f3=@a2' - Params: {'a0': 'a', 'a1': 23, 'a2': '888***'} - - OR - SQL: 'SELECT * from t where f1=%(f1)s, f2=%(f2)s, f3=%(f3)s' - Params: {'f1': 'a', 'f2': 23, 'f3': '888***', 'extra': 'aye') - becomes: - SQL: 'SELECT * from t where f1=@a0, f2=@a1, f3=@a2' - Params: {'a0': 'a', 'a1': 23, 'a2': '888***'} - """ - if not params: - return sanitize_literals_for_upload(sql), params - - found_pyformat_placeholders = RE_PYFORMAT.findall(sql) - params_is_dict = isinstance(params, dict) - - if params_is_dict: - if not found_pyformat_placeholders: - return sanitize_literals_for_upload(sql), params - else: - n_params = len(params) if params else 0 - n_matches = len(found_pyformat_placeholders) - if n_matches != n_params: - raise Error( - "pyformat_args mismatch\ngot %d args from %s\n" - "want %d args in %s" - % (n_matches, found_pyformat_placeholders, n_params, params) - ) - - named_args = {} - # We've now got for example: - # Case a) Params is a non-dict - # SQL: 'SELECT * from t where f1=%s, f2=%s, f3=%s' - # Params: ('a', 23, '888***') - # Case b) Params is a dict and the matches are %(value)s' - for i, pyfmt in enumerate(found_pyformat_placeholders): - key = "a%d" % i - sql = sql.replace(pyfmt, "@" + key, 1) - if params_is_dict: - # The '%(key)s' case, so interpolate it. - resolved_value = pyfmt % params - named_args[key] = resolved_value - else: - named_args[key] = cast_for_spanner(params[i]) - - return sanitize_literals_for_upload(sql), named_args - - -def cast_for_spanner(value): - """Convert the param to its Cloud Spanner equivalent type. - - :type value: Any - :param value: Value to convert to a Cloud Spanner type. - - :rtype: Any - :returns: Value converted to a Cloud Spanner type. - """ - if isinstance(value, decimal.Decimal): - return float(value) - return value - - -def get_param_types(params): - """Determine Cloud Spanner types for the given parameters. - - :type params: :class:`dict` - :param params: Parameters requiring to find Cloud Spanner types. - - :rtype: :class:`dict` - :returns: The types index for the given parameters. - """ - if params is None: - return - - param_types = {} - - for key, value in params.items(): - type_ = type(value) - if type_ in TYPES_MAP: - param_types[key] = TYPES_MAP[type_] - - return param_types - - -def ensure_where_clause(sql): - """ - Raise unless `sql` includes a WHERE clause. - - :type sql: str - :param sql: SQL statement to check. - """ - if not any( - isinstance(token, sqlparse.sql.Where) - for token in sqlparse.parse(sql)[0] - ): - raise ProgrammingError( - "Cloud Spanner requires a WHERE clause in UPDATE and DELETE statements" - ) - - -def escape_name(name): - """ - Apply backticks to the name that either contain '-' or - ' ', or is a Cloud Spanner's reserved keyword. - - :type name: :class:`str` - :param name: Name to escape. - - :rtype: :class:`str` - :returns: Name escaped if it has to be escaped. - """ - if "-" in name or " " in name or name.upper() in SPANNER_RESERVED_KEYWORDS: - return "`" + name + "`" - return name diff --git a/google/cloud/spanner_dbapi/parser.py b/google/cloud/spanner_dbapi/parser.py deleted file mode 100644 index 2fc0156b57..0000000000 --- a/google/cloud/spanner_dbapi/parser.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -""" -Grammar for parsing VALUES: - VALUES := `VALUES(` + ARGS + `)` - ARGS := [EXPR,]*EXPR - EXPR := TERMINAL / FUNC - TERMINAL := `%s` - FUNC := alphanum + `(` + ARGS + `)` - alphanum := (a-zA-Z_)[0-9a-ZA-Z_]* - -thus given: - statement: 'VALUES (%s, %s), (%s, LOWER(UPPER(%s))) , (%s)' - It'll parse: - VALUES - |- ARGS - |- (TERMINAL, TERMINAL) - |- (TERMINAL, FUNC - |- FUNC - |- (TERMINAL) - |- (TERMINAL) -""" - -from .exceptions import ProgrammingError - -ARGS = "ARGS" -EXPR = "EXPR" -FUNC = "FUNC" -VALUES = "VALUES" - - -class func(object): - def __init__(self, func_name, args): - self.name = func_name - self.args = args - - def __str__(self): - return "%s%s" % (self.name, self.args) - - def __repr__(self): - return self.__str__() - - def __eq__(self, other): - if type(self) != type(other): - return False - if self.name != other.name: - return False - if not isinstance(other.args, type(self.args)): - return False - if len(self.args) != len(other.args): - return False - return self.args == other.args - - def __len__(self): - return len(self.args) - - -class terminal(str): - """ - terminal represents the unit symbol that can be part of a SQL values clause. - """ - - pass - - -class a_args(object): - def __init__(self, argv): - self.argv = argv - - def __str__(self): - return "(" + ", ".join([str(arg) for arg in self.argv]) + ")" - - def __repr__(self): - return self.__str__() - - def has_expr(self): - return any( - [token for token in self.argv if not isinstance(token, terminal)] - ) - - def __len__(self): - return len(self.argv) - - def __eq__(self, other): - if type(self) != type(other): - return False - - if len(self) != len(other): - return False - - for i, item in enumerate(self): - if item != other[i]: - return False - - return True - - def __getitem__(self, index): - return self.argv[index] - - def homogenous(self): - """ - Return True if all the arguments are pyformat - args and have the same number of arguments. - """ - if not self._is_equal_length(): - return False - - for arg in self.argv: - if isinstance(arg, terminal): - continue - elif isinstance(arg, a_args): - if not arg.homogenous(): - return False - else: - return False - return True - - def _is_equal_length(self): - """ - Return False if all the arguments have the same length. - """ - if len(self) == 0: - return True - - arg0_len = len(self.argv[0]) - for arg in self.argv[1:]: - if len(arg) != arg0_len: - return False - - return True - - -class values(a_args): - def __str__(self): - return "VALUES%s" % super().__str__() - - -def parse_values(stmt): - return expect(stmt, VALUES) - - -pyfmt_str = terminal("%s") - - -def expect(word, token): - word = word.strip() - if token == VALUES: - if not word.startswith("VALUES"): - raise ProgrammingError( - "VALUES: `%s` does not start with VALUES" % word - ) - word = word[len("VALUES") :].lstrip() - - all_args = [] - while word: - word = word.strip() - - word, arg = expect(word, ARGS) - all_args.append(arg) - word = word.strip() - - if word and not word.startswith(","): - raise ProgrammingError( - "VALUES: expected `,` got %s in %s" % (word[0], word) - ) - word = word[1:] - return "", values(all_args) - - elif token == FUNC: - begins_with_letter = word and (word[0].isalpha() or word[0] == "_") - if not begins_with_letter: - raise ProgrammingError( - "FUNC: `%s` does not begin with `a-zA-z` nor a `_`" % word - ) - - rest = word[1:] - end = 0 - for ch in rest: - if ch.isalnum() or ch == "_": - end += 1 - else: - break - - func_name, rest = word[: end + 1], word[end + 1 :].strip() - - word, args = expect(rest, ARGS) - return word, func(func_name, args) - - elif token == ARGS: - # The form should be: - # (%s) - # (%s, %s...) - # (FUNC, %s...) - # (%s, %s...) - if not (word and word.startswith("(")): - raise ProgrammingError( - "ARGS: supposed to begin with `(` in `%s`" % word - ) - - word = word[1:] - - terms = [] - while True: - word = word.strip() - if not word or word.startswith(")"): - break - - if word == "%s": - terms.append(pyfmt_str) - word = "" - elif not word.startswith("%s"): - word, parsed = expect(word, FUNC) - terms.append(parsed) - else: - terms.append(pyfmt_str) - word = word[2:].strip() - - if word.startswith(","): - word = word[1:] - - if not (word and word.startswith(")")): - raise ProgrammingError( - "ARGS: supposed to end with `)` in `%s`" % word - ) - - word = word[1:] - return word, a_args(terms) - - elif token == EXPR: - if word == "%s": - # Terminal symbol. - return "", pyfmt_str - - # Otherwise we expect a function. - return expect(word, FUNC) - - raise ProgrammingError("Unknown token `%s`" % token) - - -def as_values(values_stmt): - _, _values = parse_values(values_stmt) - return _values diff --git a/google/cloud/spanner_dbapi/types.py b/google/cloud/spanner_dbapi/types.py deleted file mode 100644 index 8c6bd27577..0000000000 --- a/google/cloud/spanner_dbapi/types.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""Implementation of the type objects and constructors according to the - PEP-0249 specification. - - See - https://www.python.org/dev/peps/pep-0249/#type-objects-and-constructors -""" - -import datetime -import time -from base64 import b64encode - - -def _date_from_ticks(ticks): - """Based on PEP-249 Implementation Hints for Module Authors: - - https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors - """ - return Date(*time.localtime(ticks)[:3]) - - -def _time_from_ticks(ticks): - """Based on PEP-249 Implementation Hints for Module Authors: - - https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors - """ - return Time(*time.localtime(ticks)[3:6]) - - -def _timestamp_from_ticks(ticks): - """Based on PEP-249 Implementation Hints for Module Authors: - - https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors - """ - return Timestamp(*time.localtime(ticks)[:6]) - - -class _DBAPITypeObject(object): - """Implementation of a helper class used for type comparison among similar - but possibly different types. - - See - https://www.python.org/dev/peps/pep-0249/#implementation-hints-for-module-authors - """ - - def __init__(self, *values): - self.values = values - - def __eq__(self, other): - return other in self.values - - -Date = datetime.date -Time = datetime.time -Timestamp = datetime.datetime -DateFromTicks = _date_from_ticks -TimeFromTicks = _time_from_ticks -TimestampFromTicks = _timestamp_from_ticks -Binary = b64encode - -STRING = "STRING" -BINARY = _DBAPITypeObject("TYPE_CODE_UNSPECIFIED", "BYTES", "ARRAY", "STRUCT") -NUMBER = _DBAPITypeObject("BOOL", "INT64", "FLOAT64", "NUMERIC") -DATETIME = _DBAPITypeObject("TIMESTAMP", "DATE") -ROWID = "STRING" - - -class TimestampStr(str): - """[inherited from the alpha release] - - TODO: Decide whether this class is necessary - - TimestampStr exists so that we can purposefully format types as timestamps - compatible with Cloud Spanner's TIMESTAMP type, but right before making - queries, it'll help differentiate between normal strings and the case of - types that should be TIMESTAMP. - """ - - pass - - -class DateStr(str): - """[inherited from the alpha release] - - TODO: Decide whether this class is necessary - - DateStr is a sentinel type to help format Django dates as - compatible with Cloud Spanner's DATE type, but right before making - queries, it'll help differentiate between normal strings and the case of - types that should be DATE. - """ - - pass diff --git a/google/cloud/spanner_dbapi/utils.py b/google/cloud/spanner_dbapi/utils.py deleted file mode 100644 index f4769e80a4..0000000000 --- a/google/cloud/spanner_dbapi/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import re - - -class PeekIterator: - """ - PeekIterator peeks at the first element out of an iterator - for the sake of operations like auto-population of fields on reading - the first element. - If next's result is an instance of list, it'll be converted into a tuple - to conform with DBAPI v2's sequence expectations. - """ - - def __init__(self, source): - itr_src = iter(source) - - self.__iters = [] - self.__index = 0 - - try: - head = next(itr_src) - # Restitch and prepare to read from multiple iterators. - self.__iters = [iter(itr) for itr in [[head], itr_src]] - except StopIteration: - pass - - def __next__(self): - if self.__index >= len(self.__iters): - raise StopIteration - - iterator = self.__iters[self.__index] - try: - head = next(iterator) - except StopIteration: - # That iterator has been exhausted, try with the next one. - self.__index += 1 - return self.__next__() - else: - return tuple(head) if isinstance(head, list) else head - - def __iter__(self): - return self - - -re_UNICODE_POINTS = re.compile(r"([^\s]*[\u0080-\uFFFF]+[^\s]*)") - - -def backtick_unicode(sql): - matches = list(re_UNICODE_POINTS.finditer(sql)) - if not matches: - return sql - - segments = [] - - last_end = 0 - for match in matches: - start, end = match.span() - if sql[start] != "`" and sql[end - 1] != "`": - segments.append(sql[last_end:start] + "`" + sql[start:end] + "`") - else: - segments.append(sql[last_end:end]) - - last_end = end - - return "".join(segments) - - -def sanitize_literals_for_upload(s): - """ - Convert literals in s, to be fit for consumption by Cloud Spanner. - 1. Convert %% (escaped percent literals) to %. Percent signs must be escaped when - values like %s are used as SQL parameter placeholders but Spanner's query language - uses placeholders like @a0 and doesn't expect percent signs to be escaped. - 2. Quote words containing non-ASCII, with backticks, for example föö to `föö`. - """ - return backtick_unicode(s.replace("%%", "%")) diff --git a/google/cloud/spanner_dbapi/version.py b/google/cloud/spanner_dbapi/version.py deleted file mode 100644 index 88d8f7cdaf..0000000000 --- a/google/cloud/spanner_dbapi/version.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import platform - -PY_VERSION = platform.python_version() -VERSION = "2.2.0a1" -DEFAULT_USER_AGENT = "django_spanner/" + VERSION diff --git a/noxfile.py b/noxfile.py index f5e2ab844c..53217c5e26 100644 --- a/noxfile.py +++ b/noxfile.py @@ -16,9 +16,7 @@ BLACK_VERSION = "black==19.10b0" BLACK_PATHS = [ - "django_spanner", "docs", - "google", "tests", "noxfile.py", "setup.py", @@ -34,7 +32,7 @@ def lint(session): """ session.install("flake8", BLACK_VERSION) session.run("black", "--check", *BLACK_PATHS) - session.run("flake8", "django_spanner", "google", "tests") + session.run("flake8", "django_spanner", "tests") @nox.session(python="3.8") @@ -67,71 +65,10 @@ def default(session): # Run py.test against the unit tests. session.run( - "py.test", - "--quiet", - # "--cov=django_spanner", - "--cov=google.cloud", - "--cov=tests.unit", - "--cov-append", - "--cov-config=.coveragerc", - "--cov-report=", - "--cov-fail-under=90", - os.path.join("tests", "unit"), - *session.posargs + "py.test", "--quiet", os.path.join("tests", "unit"), *session.posargs ) -@nox.session(python=["3.6", "3.7", "3.8"]) -def unit(session): - """Run the unit test suite.""" - default(session) - - -@nox.session(python="3.8") -def system(session): - """Run the system test suite.""" - system_test_path = os.path.join("tests", "system.py") - system_test_folder_path = os.path.join("tests", "system") - - # Sanity check: Only run tests if the environment variable is set. - if not os.environ.get( - "GOOGLE_APPLICATION_CREDENTIALS", "" - ) and not os.environ.get("SPANNER_EMULATOR_HOST", ""): - session.skip("Credentials must be set via environment variable") - - system_test_exists = os.path.exists(system_test_path) - system_test_folder_exists = os.path.exists(system_test_folder_path) - - # Sanity check: only run tests if found. - if not system_test_exists and not system_test_folder_exists: - session.skip("System tests were not found") - - # Install all test dependencies, then install this package into the - # virtualenv's dist-packages. - session.install("mock", "pytest", "google-cloud-testutils") - session.install("-e", ".") - - # Run py.test against the system tests. - if system_test_exists: - session.run("py.test", "--quiet", system_test_path, *session.posargs) - if system_test_folder_exists: - session.run( - "py.test", "--quiet", system_test_folder_path, *session.posargs - ) - - -@nox.session(python="3.8") -def cover(session): - """Run the final coverage report. - - This outputs the coverage report aggregating coverage from the unit - test runs (not system test runs), and then erases coverage data. - """ - session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=40") - session.run("coverage", "erase") - - @nox.session(python="3.8") def docs(session): """Build the docs for this library.""" diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py deleted file mode 100644 index fb4d89c373..0000000000 --- a/tests/spanner_dbapi/test_connect.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""connect() module function unit tests.""" - -import unittest -from unittest import mock - -import google.auth.credentials -from google.api_core.gapic_v1.client_info import ClientInfo -from google.cloud.spanner_dbapi import connect, Connection -from google.cloud.spanner_v1.pool import FixedSizePool - - -def _make_credentials(): - class _CredentialsWithScopes( - google.auth.credentials.Credentials, google.auth.credentials.Scoped - ): - pass - - return mock.Mock(spec=_CredentialsWithScopes) - - -class Test_connect(unittest.TestCase): - def test_connect(self): - PROJECT = "test-project" - USER_AGENT = "user-agent" - CREDENTIALS = _make_credentials() - CLIENT_INFO = ClientInfo(user_agent=USER_AGENT) - - with mock.patch( - "google.cloud.spanner_dbapi.spanner_v1.Client" - ) as client_mock: - with mock.patch( - "google.cloud.spanner_dbapi.google_client_info", - return_value=CLIENT_INFO, - ) as client_info_mock: - - connection = connect( - "test-instance", - "test-database", - PROJECT, - CREDENTIALS, - user_agent=USER_AGENT, - ) - - self.assertIsInstance(connection, Connection) - client_info_mock.assert_called_once_with(USER_AGENT) - - client_mock.assert_called_once_with( - project=PROJECT, - credentials=CREDENTIALS, - client_info=CLIENT_INFO, - ) - - def test_instance_not_found(self): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=False, - ) as exists_mock: - - with self.assertRaises(ValueError): - connect("test-instance", "test-database") - - exists_mock.assert_called_once_with() - - def test_database_not_found(self): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=False, - ) as exists_mock: - - with self.assertRaises(ValueError): - connect("test-instance", "test-database") - - exists_mock.assert_called_once_with() - - def test_connect_instance_id(self): - INSTANCE = "test-instance" - - with mock.patch( - "google.cloud.spanner_v1.client.Client.instance" - ) as instance_mock: - connection = connect(INSTANCE, "test-database") - - instance_mock.assert_called_once_with(INSTANCE) - - self.assertIsInstance(connection, Connection) - - def test_connect_database_id(self): - DATABASE = "test-database" - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.database" - ) as database_mock: - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - connection = connect("test-instance", DATABASE) - - database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) - - self.assertIsInstance(connection, Connection) - - def test_default_sessions_pool(self): - with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - self.assertIsNotNone(connection.database._pool) - - def test_sessions_pool(self): - database_id = "test-database" - pool = FixedSizePool() - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.database" - ) as database_mock: - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - connect("test-instance", database_id, pool=pool) - database_mock.assert_called_once_with(database_id, pool=pool) diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py deleted file mode 100644 index 24260de12e..0000000000 --- a/tests/spanner_dbapi/test_connection.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""Connection() class unit tests.""" - -import unittest -from unittest import mock - -# import google.cloud.spanner_dbapi.exceptions as dbapi_exceptions - -from google.cloud.spanner_dbapi import Connection, InterfaceError -from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING -from google.cloud.spanner_v1.database import Database -from google.cloud.spanner_v1.instance import Instance - - -class TestConnection(unittest.TestCase): - instance_name = "instance-name" - database_name = "database-name" - - def _make_connection(self): - # we don't need real Client object to test the constructor - instance = Instance(self.instance_name, client=None) - database = instance.database(self.database_name) - return Connection(instance, database) - - def test_ctor(self): - connection = self._make_connection() - - self.assertIsInstance(connection.instance, Instance) - self.assertEqual(connection.instance.instance_id, self.instance_name) - - self.assertIsInstance(connection.database, Database) - self.assertEqual(connection.database.database_id, self.database_name) - - self.assertFalse(connection.is_closed) - - def test_close(self): - connection = self._make_connection() - - self.assertFalse(connection.is_closed) - connection.close() - self.assertTrue(connection.is_closed) - - with self.assertRaises(InterfaceError): - connection.cursor() - - @mock.patch("warnings.warn") - def test_transaction_autocommit_warnings(self, warn_mock): - connection = self._make_connection() - connection.autocommit = True - - connection.commit() - warn_mock.assert_called_with( - AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 - ) - connection.rollback() - warn_mock.assert_called_with( - AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 - ) - - def test_database_property(self): - connection = self._make_connection() - self.assertIsInstance(connection.database, Database) - self.assertEqual(connection.database, connection._database) - - with self.assertRaises(AttributeError): - connection.database = None - - def test_instance_property(self): - connection = self._make_connection() - self.assertIsInstance(connection.instance, Instance) - self.assertEqual(connection.instance, connection._instance) - - with self.assertRaises(AttributeError): - connection.instance = None diff --git a/tests/system/test_system.py b/tests/system/test_system.py deleted file mode 100644 index f3ee345e15..0000000000 --- a/tests/system/test_system.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import unittest -import os - -from google.api_core import exceptions - -from google.cloud.spanner import Client -from google.cloud.spanner import BurstyPool -from google.cloud.spanner_dbapi.connection import Connection - -from test_utils.retry import RetryErrors -from test_utils.system import unique_resource_id - - -CREATE_INSTANCE = ( - os.getenv("GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE") is not None -) -USE_EMULATOR = os.getenv("SPANNER_EMULATOR_HOST") is not None - -if CREATE_INSTANCE: - INSTANCE_ID = "google-cloud" + unique_resource_id("-") -else: - INSTANCE_ID = os.environ.get( - "GOOGLE_CLOUD_TESTS_SPANNER_INSTANCE", "google-cloud-python-systest" - ) -EXISTING_INSTANCES = [] - -DDL_STATEMENTS = ( - """CREATE TABLE contacts ( - contact_id INT64, - first_name STRING(1024), - last_name STRING(1024), - email STRING(1024) - ) - PRIMARY KEY (contact_id)""", -) - - -class Config(object): - """Run-time configuration to be modified at set-up. - - This is a mutable stand-in to allow test set-up to modify - global state. - """ - - CLIENT = None - INSTANCE_CONFIG = None - INSTANCE = None - - -def _list_instances(): - return list(Config.CLIENT.list_instances()) - - -def setUpModule(): - if USE_EMULATOR: - from google.auth.credentials import AnonymousCredentials - - emulator_project = os.getenv("GCLOUD_PROJECT", "emulator-test-project") - Config.CLIENT = Client( - project=emulator_project, credentials=AnonymousCredentials() - ) - else: - Config.CLIENT = Client() - - retry = RetryErrors(exceptions.ServiceUnavailable) - - configs = list(retry(Config.CLIENT.list_instance_configs)()) - - instances = retry(_list_instances)() - EXISTING_INSTANCES[:] = instances - - if CREATE_INSTANCE: - if not USE_EMULATOR: - # Defend against back-end returning configs for regions we aren't - # actually allowed to use. - configs = [config for config in configs if "-us-" in config.name] - - if not configs: - raise ValueError("List instance configs failed in module set up.") - - Config.INSTANCE_CONFIG = configs[0] - config_name = configs[0].name - - Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID, config_name) - created_op = Config.INSTANCE.create() - created_op.result(30) # block until completion - else: - Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID) - Config.INSTANCE.reload() - - -def tearDownModule(): - """Delete the test instance, if it was created.""" - if CREATE_INSTANCE: - Config.INSTANCE.delete() - - -class TestTransactionsManagement(unittest.TestCase): - """Transactions management support tests.""" - - DATABASE_NAME = "db-api-transactions-management" - - @classmethod - def setUpClass(cls): - """Create a test database.""" - cls._db = Config.INSTANCE.database( - cls.DATABASE_NAME, - ddl_statements=DDL_STATEMENTS, - pool=BurstyPool(labels={"testcase": "database_api"}), - ) - cls._db.create().result(30) # raises on failure / timeout. - - @classmethod - def tearDownClass(cls): - """Delete the test database.""" - cls._db.drop() - - def tearDown(self): - """Clear the test table after every test.""" - self._db.run_in_transaction(clear_table) - - def test_commit(self): - """Test committing a transaction with several statements.""" - want_row = ( - 1, - "updated-first-name", - "last-name", - "test.email_updated@domen.ru", - ) - # connect to the test database - conn = Connection(Config.INSTANCE, self._db) - cursor = conn.cursor() - - # execute several DML statements within one transaction - cursor.execute( - """ -INSERT INTO contacts (contact_id, first_name, last_name, email) -VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') - """ - ) - cursor.execute( - """ -UPDATE contacts -SET first_name = 'updated-first-name' -WHERE first_name = 'first-name' -""" - ) - cursor.execute( - """ -UPDATE contacts -SET email = 'test.email_updated@domen.ru' -WHERE email = 'test.email@domen.ru' -""" - ) - conn.commit() - - # read the resulting data from the database - cursor.execute("SELECT * FROM contacts") - got_rows = cursor.fetchall() - conn.commit() - - self.assertEqual(got_rows, [want_row]) - - cursor.close() - conn.close() - - def test_rollback(self): - """Test rollbacking a transaction with several statements.""" - want_row = (2, "first-name", "last-name", "test.email@domen.ru") - # connect to the test database - conn = Connection(Config.INSTANCE, self._db) - cursor = conn.cursor() - - cursor.execute( - """ -INSERT INTO contacts (contact_id, first_name, last_name, email) -VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') - """ - ) - conn.commit() - - # execute several DMLs with one transaction - cursor.execute( - """ -UPDATE contacts -SET first_name = 'updated-first-name' -WHERE first_name = 'first-name' -""" - ) - cursor.execute( - """ -UPDATE contacts -SET email = 'test.email_updated@domen.ru' -WHERE email = 'test.email@domen.ru' -""" - ) - conn.rollback() - - # read the resulting data from the database - cursor.execute("SELECT * FROM contacts") - got_rows = cursor.fetchall() - conn.commit() - - self.assertEqual(got_rows, [want_row]) - - cursor.close() - conn.close() - - def test_autocommit_mode_change(self): - """Test auto committing a transaction on `autocommit` mode change.""" - want_row = ( - 2, - "updated-first-name", - "last-name", - "test.email@domen.ru", - ) - # connect to the test database - conn = Connection(Config.INSTANCE, self._db) - cursor = conn.cursor() - - cursor.execute( - """ -INSERT INTO contacts (contact_id, first_name, last_name, email) -VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') - """ - ) - cursor.execute( - """ -UPDATE contacts -SET first_name = 'updated-first-name' -WHERE first_name = 'first-name' -""" - ) - conn.autocommit = True - - # read the resulting data from the database - cursor.execute("SELECT * FROM contacts") - got_rows = cursor.fetchall() - - self.assertEqual(got_rows, [want_row]) - - cursor.close() - conn.close() - - def test_rollback_on_connection_closing(self): - """ - When closing a connection all the pending transactions - must be rollbacked. Testing if it's working this way. - """ - want_row = (1, "first-name", "last-name", "test.email@domen.ru") - # connect to the test database - conn = Connection(Config.INSTANCE, self._db) - cursor = conn.cursor() - - cursor.execute( - """ -INSERT INTO contacts (contact_id, first_name, last_name, email) -VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') - """ - ) - conn.commit() - - cursor.execute( - """ -UPDATE contacts -SET first_name = 'updated-first-name' -WHERE first_name = 'first-name' -""" - ) - conn.close() - - # connect again, as the previous connection is no-op after closing - conn = Connection(Config.INSTANCE, self._db) - cursor = conn.cursor() - - # read the resulting data from the database - cursor.execute("SELECT * FROM contacts") - got_rows = cursor.fetchall() - conn.commit() - - self.assertEqual(got_rows, [want_row]) - - cursor.close() - conn.close() - - -def clear_table(transaction): - """Clear the test table.""" - transaction.execute_update("DELETE FROM contacts WHERE true") diff --git a/tests/unit/spanner_dbapi/__init__.py b/tests/unit/spanner_dbapi/__init__.py deleted file mode 100644 index 6b607710ed..0000000000 --- a/tests/unit/spanner_dbapi/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py deleted file mode 100644 index e5316d254e..0000000000 --- a/tests/unit/spanner_dbapi/test__helpers.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""Cloud Spanner DB-API Connection class unit tests.""" - -import unittest - -from unittest import mock - - -class TestHelpers(unittest.TestCase): - def test__execute_insert_heterogenous(self): - from google.cloud.spanner_dbapi import _helpers - - sql = "sql" - params = (sql, None) - with mock.patch( - "google.cloud.spanner_dbapi._helpers.sql_pyformat_args_to_spanner", - return_value=params, - ) as mock_pyformat: - with mock.patch( - "google.cloud.spanner_dbapi._helpers.get_param_types", - return_value=None, - ) as mock_param_types: - transaction = mock.MagicMock() - transaction.execute_sql = mock_execute = mock.MagicMock() - _helpers._execute_insert_heterogenous(transaction, [params]) - - mock_pyformat.assert_called_once_with(params[0], params[1]) - mock_param_types.assert_called_once_with(None) - mock_execute.assert_called_once_with( - sql, params=None, param_types=None - ) - - def test__execute_insert_homogenous(self): - from google.cloud.spanner_dbapi import _helpers - - transaction = mock.MagicMock() - transaction.insert = mock.MagicMock() - parts = mock.MagicMock() - parts.get = mock.MagicMock(return_value=0) - - _helpers._execute_insert_homogenous(transaction, parts) - transaction.insert.assert_called_once_with(0, 0, 0) - - def test_handle_insert(self): - from google.cloud.spanner_dbapi import _helpers - - connection = mock.MagicMock() - connection.database.run_in_transaction = mock_run_in = mock.MagicMock() - sql = "sql" - parts = mock.MagicMock() - with mock.patch( - "google.cloud.spanner_dbapi._helpers.parse_insert", - return_value=parts, - ): - parts.get = mock.MagicMock(return_value=True) - mock_run_in.return_value = 0 - result = _helpers.handle_insert(connection, sql, None) - self.assertEqual(result, 0) - - parts.get = mock.MagicMock(return_value=False) - mock_run_in.return_value = 1 - result = _helpers.handle_insert(connection, sql, None) - self.assertEqual(result, 1) - - -class TestColumnInfo(unittest.TestCase): - def test_ctor(self): - from google.cloud.spanner_dbapi.cursor import ColumnInfo - - name = "col-name" - type_code = 8 - display_size = 5 - internal_size = 10 - precision = 3 - scale = None - null_ok = False - - cols = ColumnInfo( - name, - type_code, - display_size, - internal_size, - precision, - scale, - null_ok, - ) - - self.assertEqual(cols.name, name) - self.assertEqual(cols.type_code, type_code) - self.assertEqual(cols.display_size, display_size) - self.assertEqual(cols.internal_size, internal_size) - self.assertEqual(cols.precision, precision) - self.assertEqual(cols.scale, scale) - self.assertEqual(cols.null_ok, null_ok) - self.assertEqual( - cols.fields, - ( - name, - type_code, - display_size, - internal_size, - precision, - scale, - null_ok, - ), - ) - - def test___get_item__(self): - from google.cloud.spanner_dbapi.cursor import ColumnInfo - - fields = ("col-name", 8, 5, 10, 3, None, False) - cols = ColumnInfo(*fields) - - for i in range(0, 7): - self.assertEqual(cols[i], fields[i]) - - def test___str__(self): - from google.cloud.spanner_dbapi.cursor import ColumnInfo - - cols = ColumnInfo("col-name", 8, None, 10, 3, None, False) - - self.assertEqual( - str(cols), - "ColumnInfo(name='col-name', type_code=8, internal_size=10, precision='3')", - ) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py deleted file mode 100644 index 99aa0aa47b..0000000000 --- a/tests/unit/spanner_dbapi/test_connection.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""Cloud Spanner DB-API Connection class unit tests.""" - -import unittest -import warnings - -from unittest import mock - - -def _make_credentials(): - from google.auth import credentials - - class _CredentialsWithScopes(credentials.Credentials, credentials.Scoped): - pass - - return mock.Mock(spec=_CredentialsWithScopes) - - -class TestConnection(unittest.TestCase): - - PROJECT = "test-project" - INSTANCE = "test-instance" - DATABASE = "test-database" - USER_AGENT = "user-agent" - CREDENTIALS = _make_credentials() - - def _get_client_info(self): - from google.api_core.gapic_v1.client_info import ClientInfo - - return ClientInfo(user_agent=self.USER_AGENT) - - def _make_connection(self, pool=None): - from google.cloud.spanner_dbapi import Connection - from google.cloud.spanner_v1.instance import Instance - - # We don't need a real Client object to test the constructor - instance = Instance(self.INSTANCE, client=None) - database = instance.database(self.DATABASE) - - conn = Connection(instance, database) - if pool is not None: - conn._own_pool = False - - return conn - - def test_property_autocommit_setter(self): - from google.cloud.spanner_dbapi import Connection - - connection = Connection(self.INSTANCE, self.DATABASE) - - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.commit" - ) as mock_commit: - connection.autocommit = True - mock_commit.assert_called_once_with() - self.assertEqual(connection._autocommit, True) - - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.commit" - ) as mock_commit: - connection.autocommit = False - mock_commit.assert_not_called() - self.assertEqual(connection._autocommit, False) - - def test_property_database(self): - from google.cloud.spanner_v1.database import Database - - connection = self._make_connection() - self.assertIsInstance(connection.database, Database) - self.assertEqual(connection.database, connection._database) - - def test_property_instance(self): - from google.cloud.spanner_v1.instance import Instance - - connection = self._make_connection() - self.assertIsInstance(connection.instance, Instance) - self.assertEqual(connection.instance, connection._instance) - - def test__session_checkout(self): - from google.cloud.spanner_dbapi import Connection - - with mock.patch( - "google.cloud.spanner_v1.database.Database", - ) as mock_database: - mock_database._pool = mock.MagicMock() - mock_database._pool.get = mock.MagicMock( - return_value="db_session_pool" - ) - connection = Connection(self.INSTANCE, mock_database) - - connection._session_checkout() - mock_database._pool.get.assert_called_once_with() - self.assertEqual(connection._session, "db_session_pool") - - connection._session = "db_session" - connection._session_checkout() - self.assertEqual(connection._session, "db_session") - - def test__release_session(self): - from google.cloud.spanner_dbapi import Connection - - with mock.patch( - "google.cloud.spanner_v1.database.Database", - ) as mock_database: - mock_database._pool = mock.MagicMock() - mock_database._pool.put = mock.MagicMock() - connection = Connection(self.INSTANCE, mock_database) - connection._session = "session" - - connection._release_session() - mock_database._pool.put.assert_called_once_with("session") - self.assertIsNone(connection._session) - - def test_transaction_checkout(self): - from google.cloud.spanner_dbapi import Connection - - connection = Connection(self.INSTANCE, self.DATABASE) - connection._session_checkout = mock_checkout = mock.MagicMock( - autospec=True - ) - connection.transaction_checkout() - mock_checkout.assert_called_once_with() - - connection._transaction = mock_transaction = mock.MagicMock() - mock_transaction.committed = mock_transaction.rolled_back = False - self.assertEqual(connection.transaction_checkout(), mock_transaction) - - connection._autocommit = True - self.assertIsNone(connection.transaction_checkout()) - - def test_close(self): - from google.cloud.spanner_dbapi import connect, InterfaceError - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - self.assertFalse(connection.is_closed) - connection.close() - self.assertTrue(connection.is_closed) - - with self.assertRaises(InterfaceError): - connection.cursor() - - connection._transaction = mock_transaction = mock.MagicMock() - mock_transaction.committed = mock_transaction.rolled_back = False - mock_transaction.rollback = mock_rollback = mock.MagicMock() - connection.close() - mock_rollback.assert_called_once_with() - - @mock.patch.object(warnings, "warn") - def test_commit(self, mock_warn): - from google.cloud.spanner_dbapi import Connection - from google.cloud.spanner_dbapi.connection import ( - AUTOCOMMIT_MODE_WARNING, - ) - - connection = Connection(self.INSTANCE, self.DATABASE) - - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection._release_session" - ) as mock_release: - connection.commit() - mock_release.assert_not_called() - - connection._transaction = mock_transaction = mock.MagicMock() - mock_transaction.commit = mock_commit = mock.MagicMock() - - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection._release_session" - ) as mock_release: - connection.commit() - mock_commit.assert_called_once_with() - mock_release.assert_called_once_with() - - connection._autocommit = True - connection.commit() - mock_warn.assert_called_once_with( - AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 - ) - - @mock.patch.object(warnings, "warn") - def test_rollback(self, mock_warn): - from google.cloud.spanner_dbapi import Connection - from google.cloud.spanner_dbapi.connection import ( - AUTOCOMMIT_MODE_WARNING, - ) - - connection = Connection(self.INSTANCE, self.DATABASE) - - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection._release_session" - ) as mock_release: - connection.rollback() - mock_release.assert_not_called() - - connection._transaction = mock_transaction = mock.MagicMock() - mock_transaction.rollback = mock_rollback = mock.MagicMock() - - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection._release_session" - ) as mock_release: - connection.rollback() - mock_rollback.assert_called_once_with() - mock_release.assert_called_once_with() - - connection._autocommit = True - connection.rollback() - mock_warn.assert_called_once_with( - AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 - ) - - def test_run_prior_DDL_statements(self): - from google.cloud.spanner_dbapi import Connection, InterfaceError - - with mock.patch( - "google.cloud.spanner_v1.database.Database", autospec=True, - ) as mock_database: - connection = Connection(self.INSTANCE, mock_database) - - connection.run_prior_DDL_statements() - mock_database.update_ddl.assert_not_called() - - ddl = ["ddl"] - connection._ddl_statements = ddl - - connection.run_prior_DDL_statements() - mock_database.update_ddl.assert_called_once_with(ddl) - - connection.is_closed = True - - with self.assertRaises(InterfaceError): - connection.run_prior_DDL_statements() - - def test_context(self): - connection = self._make_connection() - with connection as conn: - self.assertEqual(conn, connection) - - self.assertTrue(connection.is_closed) - - def test_connect(self): - from google.cloud.spanner_dbapi import Connection, connect - - with mock.patch("google.cloud.spanner_v1.Client"): - with mock.patch( - "google.api_core.gapic_v1.client_info.ClientInfo", - return_value=self._get_client_info(), - ): - connection = connect( - self.INSTANCE, - self.DATABASE, - self.PROJECT, - self.CREDENTIALS, - self.USER_AGENT, - ) - self.assertIsInstance(connection, Connection) - - def test_connect_instance_not_found(self): - from google.cloud.spanner_dbapi import connect - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=False, - ): - with self.assertRaises(ValueError): - connect("test-instance", "test-database") - - def test_connect_database_not_found(self): - from google.cloud.spanner_dbapi import connect - - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=False, - ): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with self.assertRaises(ValueError): - connect("test-instance", "test-database") - - def test_default_sessions_pool(self): - from google.cloud.spanner_dbapi import connect - - with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - self.assertIsNotNone(connection.database._pool) - - def test_sessions_pool(self): - from google.cloud.spanner_dbapi import connect - from google.cloud.spanner_v1.pool import FixedSizePool - - database_id = "test-database" - pool = FixedSizePool() - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.database" - ) as database_mock: - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - connect("test-instance", database_id, pool=pool) - database_mock.assert_called_once_with(database_id, pool=pool) - - def test_clearing_pool_on_close(self): - connection = self._make_connection() - with mock.patch.object( - connection.database._pool, "clear" - ) as pool_clear_mock: - connection.close() - pool_clear_mock.assert_called_once_with() - - def test_global_pool(self): - connection = self._make_connection(pool=mock.Mock()) - with mock.patch.object( - connection.database._pool, "clear" - ) as pool_clear_mock: - connection.close() - assert not pool_clear_mock.called diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py deleted file mode 100644 index a73265e932..0000000000 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ /dev/null @@ -1,460 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""Cursor() class unit tests.""" - -import unittest - -from unittest import mock - - -class TestCursor(unittest.TestCase): - - INSTANCE = "test-instance" - DATABASE = "test-database" - - def _get_target_class(self): - from google.cloud.spanner_dbapi import Cursor - - return Cursor - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - - def _make_connection(self, *args, **kwargs): - from google.cloud.spanner_dbapi import Connection - - return Connection(*args, **kwargs) - - def test_property_connection(self): - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - self.assertEqual(cursor.connection, connection) - - def test_property_description(self): - from google.cloud.spanner_dbapi._helpers import ColumnInfo - - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - - self.assertIsNone(cursor.description) - cursor._result_set = res_set = mock.MagicMock() - res_set.metadata.row_type.fields = [mock.MagicMock()] - self.assertIsNotNone(cursor.description) - self.assertIsInstance(cursor.description[0], ColumnInfo) - - def test_property_rowcount(self): - from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT - - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - self.assertEqual(cursor.rowcount, _UNSET_COUNT) - - def test_callproc(self): - from google.cloud.spanner_dbapi.exceptions import InterfaceError - - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - cursor._is_closed = True - with self.assertRaises(InterfaceError): - cursor.callproc(procname=None) - - def test_close(self): - from google.cloud.spanner_dbapi import connect, InterfaceError - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, - ): - connection = connect(self.INSTANCE, self.DATABASE) - - cursor = connection.cursor() - self.assertFalse(cursor.is_closed) - - cursor.close() - - self.assertTrue(cursor.is_closed) - with self.assertRaises(InterfaceError): - cursor.execute("SELECT * FROM database") - - def test_do_execute_update(self): - from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT - - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - transaction = mock.MagicMock() - - def run_helper(ret_value): - transaction.execute_update.return_value = ret_value - res = cursor._do_execute_update( - transaction=transaction, sql="SELECT * WHERE true", params={}, - ) - return res - - expected = "good" - self.assertEqual(run_helper(expected), expected) - self.assertEqual(cursor._row_count, _UNSET_COUNT) - - expected = 1234 - self.assertEqual(run_helper(expected), expected) - self.assertEqual(cursor._row_count, expected) - - def test_execute_programming_error(self): - from google.cloud.spanner_dbapi.exceptions import ProgrammingError - - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - cursor.connection = None - with self.assertRaises(ProgrammingError): - cursor.execute(sql="") - - def test_execute_attribute_error(self): - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - - with self.assertRaises(AttributeError): - cursor.execute(sql="") - - def test_execute_autocommit_off(self): - from google.cloud.spanner_dbapi.utils import PeekIterator - - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - cursor = self._make_one(connection) - cursor.connection._autocommit = False - cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) - - cursor.execute("sql") - self.assertIsInstance(cursor._result_set, mock.MagicMock) - self.assertIsInstance(cursor._itr, PeekIterator) - - def test_execute_statement(self): - from google.cloud.spanner_dbapi import parse_utils - - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - cursor = self._make_one(connection) - - with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - return_value=parse_utils.STMT_DDL, - ) as mock_classify_stmt: - sql = "sql" - cursor.execute(sql=sql) - mock_classify_stmt.assert_called_once_with(sql) - self.assertEqual(cursor.connection._ddl_statements, [sql]) - - with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - return_value=parse_utils.STMT_NON_UPDATING, - ): - with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor._handle_DQL", - return_value=parse_utils.STMT_NON_UPDATING, - ) as mock_handle_ddl: - connection.autocommit = True - sql = "sql" - cursor.execute(sql=sql) - mock_handle_ddl.assert_called_once_with(sql, None) - - with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - return_value=parse_utils.STMT_INSERT, - ): - with mock.patch( - "google.cloud.spanner_dbapi._helpers.handle_insert", - return_value=parse_utils.STMT_INSERT, - ) as mock_handle_insert: - sql = "sql" - cursor.execute(sql=sql) - mock_handle_insert.assert_called_once_with( - connection, sql, None - ) - - with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - return_value="other_statement", - ): - cursor.connection._database = mock_db = mock.MagicMock() - mock_db.run_in_transaction = mock_run_in = mock.MagicMock() - sql = "sql" - cursor.execute(sql=sql) - mock_run_in.assert_called_once_with( - cursor._do_execute_update, sql, None - ) - - def test_execute_integrity_error(self): - from google.api_core import exceptions - from google.cloud.spanner_dbapi.exceptions import IntegrityError - - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - cursor = self._make_one(connection) - - with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - side_effect=exceptions.AlreadyExists("message"), - ): - with self.assertRaises(IntegrityError): - cursor.execute(sql="sql") - - with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - side_effect=exceptions.FailedPrecondition("message"), - ): - with self.assertRaises(IntegrityError): - cursor.execute(sql="sql") - - def test_execute_invalid_argument(self): - from google.api_core import exceptions - from google.cloud.spanner_dbapi.exceptions import ProgrammingError - - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - cursor = self._make_one(connection) - - with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - side_effect=exceptions.InvalidArgument("message"), - ): - with self.assertRaises(ProgrammingError): - cursor.execute(sql="sql") - - def test_execute_internal_server_error(self): - from google.api_core import exceptions - from google.cloud.spanner_dbapi.exceptions import OperationalError - - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - cursor = self._make_one(connection) - - with mock.patch( - "google.cloud.spanner_dbapi.parse_utils.classify_stmt", - side_effect=exceptions.InternalServerError("message"), - ): - with self.assertRaises(OperationalError): - cursor.execute(sql="sql") - - def test_executemany_on_closed_cursor(self): - from google.cloud.spanner_dbapi import InterfaceError - from google.cloud.spanner_dbapi import connect - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - cursor.close() - - with self.assertRaises(InterfaceError): - cursor.executemany( - """SELECT * FROM table1 WHERE "col1" = @a1""", () - ) - - def test_executemany(self): - from google.cloud.spanner_dbapi import connect - - operation = """SELECT * FROM table1 WHERE "col1" = @a1""" - params_seq = ((1,), (2,)) - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor.execute" - ) as execute_mock: - cursor.executemany(operation, params_seq) - - execute_mock.assert_has_calls( - (mock.call(operation, (1,)), mock.call(operation, (2,))) - ) - - def test_fetchone(self): - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - cursor = self._make_one(connection) - lst = [1, 2, 3] - cursor._itr = iter(lst) - for i in range(len(lst)): - self.assertEqual(cursor.fetchone(), lst[i]) - self.assertIsNone(cursor.fetchone()) - - def test_fetchmany(self): - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - cursor = self._make_one(connection) - lst = [(1,), (2,), (3,)] - cursor._itr = iter(lst) - - self.assertEqual(cursor.fetchmany(), [lst[0]]) - - result = cursor.fetchmany(len(lst)) - self.assertEqual(result, lst[1:]) - - def test_fetchall(self): - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - cursor = self._make_one(connection) - lst = [(1,), (2,), (3,)] - cursor._itr = iter(lst) - self.assertEqual(cursor.fetchall(), lst) - - def test_nextset(self): - from google.cloud.spanner_dbapi import exceptions - - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - cursor = self._make_one(connection) - cursor.close() - with self.assertRaises(exceptions.InterfaceError): - cursor.nextset() - - def test_setinputsizes(self): - from google.cloud.spanner_dbapi import exceptions - - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - cursor = self._make_one(connection) - cursor.close() - with self.assertRaises(exceptions.InterfaceError): - cursor.setinputsizes(sizes=None) - - def test_setoutputsize(self): - from google.cloud.spanner_dbapi import exceptions - - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - cursor = self._make_one(connection) - cursor.close() - with self.assertRaises(exceptions.InterfaceError): - cursor.setoutputsize(size=None) - - # def test_handle_insert(self): - # pass - # - # def test_do_execute_insert_heterogenous(self): - # pass - # - # def test_do_execute_insert_homogenous(self): - # pass - - def test_handle_dql(self): - from google.cloud.spanner_dbapi import utils - from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT - - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - connection.database.snapshot.return_value.__enter__.return_value = ( - mock_snapshot - ) = mock.MagicMock() - cursor = self._make_one(connection) - - mock_snapshot.execute_sql.return_value = int(0) - cursor._handle_DQL("sql", params=None) - self.assertEqual(cursor._row_count, 0) - self.assertIsNone(cursor._itr) - - mock_snapshot.execute_sql.return_value = "0" - cursor._handle_DQL("sql", params=None) - self.assertEqual(cursor._result_set, "0") - self.assertIsInstance(cursor._itr, utils.PeekIterator) - self.assertEqual(cursor._row_count, _UNSET_COUNT) - - def test_context(self): - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - with cursor as c: - self.assertEqual(c, cursor) - - self.assertTrue(c.is_closed) - - def test_next(self): - from google.cloud.spanner_dbapi import exceptions - - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - with self.assertRaises(exceptions.ProgrammingError): - cursor.__next__() - - lst = [(1,), (2,), (3,)] - cursor._itr = iter(lst) - i = 0 - for c in cursor._itr: - self.assertEqual(c, lst[i]) - i += 1 - - def test_iter(self): - from google.cloud.spanner_dbapi import exceptions - - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - with self.assertRaises(exceptions.ProgrammingError): - _ = iter(cursor) - - iterator = iter([(1,), (2,), (3,)]) - cursor._itr = iterator - self.assertEqual(iter(cursor), iterator) - - def test_list_tables(self): - from google.cloud.spanner_dbapi import _helpers - - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - - table_list = ["table1", "table2", "table3"] - with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", - return_value=table_list, - ) as mock_run_sql: - cursor.list_tables() - mock_run_sql.assert_called_once_with(_helpers.SQL_LIST_TABLES) - - def test_run_sql_in_snapshot(self): - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) - connection.database.snapshot.return_value.__enter__.return_value = ( - mock_snapshot - ) = mock.MagicMock() - cursor = self._make_one(connection) - - results = 1, 2, 3 - mock_snapshot.execute_sql.return_value = results - self.assertEqual(cursor.run_sql_in_snapshot("sql"), list(results)) - - def test_get_table_column_schema(self): - from google.cloud.spanner_dbapi.cursor import ColumnDetails - from google.cloud.spanner_dbapi import _helpers - from google.cloud.spanner_v1 import param_types - - connection = self._make_connection(self.INSTANCE, self.DATABASE) - cursor = self._make_one(connection) - - column_name = "column1" - is_nullable = "YES" - spanner_type = "spanner_type" - rows = [(column_name, is_nullable, spanner_type)] - expected = { - column_name: ColumnDetails( - null_ok=True, spanner_type=spanner_type, - ) - } - with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", - return_value=rows, - ) as mock_run_sql: - table_name = "table1" - result = cursor.get_table_column_schema(table_name=table_name) - mock_run_sql.assert_called_once_with( - sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA, - params={"table_name": table_name}, - param_types={"table_name": param_types.STRING}, - ) - self.assertEqual(result, expected) diff --git a/tests/unit/spanner_dbapi/test_globals.py b/tests/unit/spanner_dbapi/test_globals.py deleted file mode 100644 index 3f8360e2ea..0000000000 --- a/tests/unit/spanner_dbapi/test_globals.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import unittest - - -class TestDBAPIGlobals(unittest.TestCase): - def test_apilevel(self): - from google.cloud.spanner_dbapi import apilevel - from google.cloud.spanner_dbapi import paramstyle - from google.cloud.spanner_dbapi import threadsafety - - self.assertEqual(apilevel, "2.0", "We implement PEP-0249 version 2.0") - self.assertEqual(paramstyle, "format", "Cloud Spanner uses @param") - self.assertEqual( - threadsafety, 1, "Threads may share module but not connections" - ) diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py deleted file mode 100644 index d68e4118fd..0000000000 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ /dev/null @@ -1,444 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import unittest - -from google.cloud.spanner_v1 import param_types - - -class TestParseUtils(unittest.TestCase): - def test_classify_stmt(self): - from google.cloud.spanner_dbapi.parse_utils import STMT_DDL - from google.cloud.spanner_dbapi.parse_utils import STMT_INSERT - from google.cloud.spanner_dbapi.parse_utils import STMT_NON_UPDATING - from google.cloud.spanner_dbapi.parse_utils import STMT_UPDATING - from google.cloud.spanner_dbapi.parse_utils import classify_stmt - - cases = ( - ("SELECT 1", STMT_NON_UPDATING), - ("SELECT s.SongName FROM Songs AS s", STMT_NON_UPDATING), - ( - "WITH sq AS (SELECT SchoolID FROM Roster) SELECT * from sq", - STMT_NON_UPDATING, - ), - ( - "CREATE TABLE django_content_type (id STRING(64) NOT NULL, name STRING(100) " - "NOT NULL, app_label STRING(100) NOT NULL, model STRING(100) NOT NULL) PRIMARY KEY(id)", - STMT_DDL, - ), - ( - "CREATE INDEX SongsBySingerAlbumSongNameDesc ON " - "Songs(SingerId, AlbumId, SongName DESC), INTERLEAVE IN Albums", - STMT_DDL, - ), - ("CREATE INDEX SongsBySongName ON Songs(SongName)", STMT_DDL), - ( - "CREATE INDEX AlbumsByAlbumTitle2 ON Albums(AlbumTitle) STORING (MarketingBudget)", - STMT_DDL, - ), - ("INSERT INTO table (col1) VALUES (1)", STMT_INSERT), - ("UPDATE table SET col1 = 1 WHERE col1 = NULL", STMT_UPDATING), - ) - - for query, want_class in cases: - self.assertEqual(classify_stmt(query), want_class) - - def test_parse_insert(self): - from google.cloud.spanner_dbapi.parse_utils import parse_insert - from google.cloud.spanner_dbapi.exceptions import ProgrammingError - - with self.assertRaises(ProgrammingError): - parse_insert("bad-sql", None) - - cases = [ - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", - [1, 2, 3, 4, 5, 6], - { - "sql_params_list": [ - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", - (1, 2, 3), - ), - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", - (4, 5, 6), - ), - ] - }, - ), - ( - "INSERT INTO django_migrations(app, name, applied) VALUES (%s, %s, %s)", - [1, 2, 3, 4, 5, 6], - { - "sql_params_list": [ - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", - (1, 2, 3), - ), - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", - (4, 5, 6), - ), - ] - }, - ), - ( - "INSERT INTO sales.addresses (street, city, state, zip_code) " - "SELECT street, city, state, zip_code FROM sales.customers" - "ORDER BY first_name, last_name", - None, - { - "sql_params_list": [ - ( - "INSERT INTO sales.addresses (street, city, state, zip_code) " - "SELECT street, city, state, zip_code FROM sales.customers" - "ORDER BY first_name, last_name", - None, - ) - ] - }, - ), - ( - "INSERT INTO ap (n, ct, cn) " - "VALUES (%s, %s, %s), (%s, %s, %s), (%s, %s, %s),(%s, %s, %s)", - (1, 2, 3, 4, 5, 6, 7, 8, 9), - { - "sql_params_list": [ - ( - "INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", - (1, 2, 3), - ), - ( - "INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", - (4, 5, 6), - ), - ( - "INSERT INTO ap (n, ct, cn) VALUES (%s, %s, %s)", - (7, 8, 9), - ), - ] - }, - ), - ( - "INSERT INTO `no` (`yes`) VALUES (%s)", - (1, 4, 5), - { - "sql_params_list": [ - ("INSERT INTO `no` (`yes`) VALUES (%s)", (1,)), - ("INSERT INTO `no` (`yes`) VALUES (%s)", (4,)), - ("INSERT INTO `no` (`yes`) VALUES (%s)", (5,)), - ] - }, - ), - ( - "INSERT INTO T (f1, f2) VALUES (1, 2)", - None, - { - "sql_params_list": [ - ("INSERT INTO T (f1, f2) VALUES (1, 2)", None) - ] - }, - ), - ( - "INSERT INTO `no` (`yes`, tiff) VALUES (%s, LOWER(%s)), (%s, %s), (%s, %s)", - (1, "FOO", 5, 10, 11, 29), - { - "sql_params_list": [ - ( - "INSERT INTO `no` (`yes`, tiff) VALUES(%s, LOWER(%s))", - (1, "FOO"), - ), - ( - "INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", - (5, 10), - ), - ( - "INSERT INTO `no` (`yes`, tiff) VALUES(%s, %s)", - (11, 29), - ), - ] - }, - ), - ] - - sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" - with self.assertRaises(ProgrammingError): - parse_insert(sql, None) - - for sql, params, want in cases: - with self.subTest(sql=sql): - got = parse_insert(sql, params) - self.assertEqual( - got, want, "Mismatch with parse_insert of `%s`" % sql - ) - - def test_parse_insert_invalid(self): - from google.cloud.spanner_dbapi import exceptions - from google.cloud.spanner_dbapi.parse_utils import parse_insert - - cases = [ - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, %s)", - [1, 2, 3, 4, 5, 6, 7], - "len\\(params\\)=7 MUST be a multiple of len\\(pyformat_args\\)=3", - ), - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s))", - [1, 2, 3, 4, 5, 6, 7], - "Invalid length: VALUES\\(...\\) len: 6 != len\\(params\\): 7", - ), - ( - "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, LOWER(%s)))", - [1, 2, 3, 4, 5, 6], - "VALUES: expected `,` got \\) in \\)", - ), - ] - - for sql, params, wantException in cases: - with self.subTest(sql=sql): - self.assertRaisesRegex( - exceptions.ProgrammingError, - wantException, - lambda: parse_insert(sql, params), - ) - - def test_rows_for_insert_or_update(self): - from google.cloud.spanner_dbapi.parse_utils import ( - rows_for_insert_or_update, - ) - from google.cloud.spanner_dbapi.exceptions import Error - - with self.assertRaises(Error): - rows_for_insert_or_update([0], [[]]) - - with self.assertRaises(Error): - rows_for_insert_or_update([0], None, ["0", "%s"]) - - cases = [ - ( - ["id", "app", "name"], - [(5, "ap", "n"), (6, "bp", "m")], - None, - [(5, "ap", "n"), (6, "bp", "m")], - ), - ( - ["app", "name"], - [("ap", "n"), ("bp", "m")], - None, - [("ap", "n"), ("bp", "m")], - ), - ( - ["app", "name", "fn"], - ["ap", "n", "f1", "bp", "m", "f2", "cp", "o", "f3"], - ["(%s, %s, %s)", "(%s, %s, %s)", "(%s, %s, %s)"], - [("ap", "n", "f1"), ("bp", "m", "f2"), ("cp", "o", "f3")], - ), - ( - ["app", "name", "fn", "ln"], - [ - ("ap", "n", (45, "nested"), "ll"), - ("bp", "m", "f2", "mt"), - ("fp", "cp", "o", "f3"), - ], - None, - [ - ("ap", "n", (45, "nested"), "ll"), - ("bp", "m", "f2", "mt"), - ("fp", "cp", "o", "f3"), - ], - ), - ( - ["app", "name", "fn"], - ["ap", "n", "f1"], - None, - [("ap", "n", "f1")], - ), - ] - - for i, (columns, params, pyformat_args, want) in enumerate(cases): - with self.subTest(i=i): - got = rows_for_insert_or_update(columns, params, pyformat_args) - self.assertEqual(got, want) - - def test_sql_pyformat_args_to_spanner(self): - import decimal - - from google.cloud.spanner_dbapi.parse_utils import ( - sql_pyformat_args_to_spanner, - ) - - cases = [ - ( - ( - "SELECT * from t WHERE f1=%s, f2 = %s, f3=%s", - (10, "abc", "y**$22l3f"), - ), - ( - "SELECT * from t WHERE f1=@a0, f2 = @a1, f3=@a2", - {"a0": 10, "a1": "abc", "a2": "y**$22l3f"}, - ), - ), - ( - ( - "INSERT INTO t (f1, f2, f2) VALUES (%s, %s, %s)", - ("app", "name", "applied"), - ), - ( - "INSERT INTO t (f1, f2, f2) VALUES (@a0, @a1, @a2)", - {"a0": "app", "a1": "name", "a2": "applied"}, - ), - ), - ( - ( - "INSERT INTO t (f1, f2, f2) VALUES (%(f1)s, %(f2)s, %(f3)s)", - {"f1": "app", "f2": "name", "f3": "applied"}, - ), - ( - "INSERT INTO t (f1, f2, f2) VALUES (@a0, @a1, @a2)", - {"a0": "app", "a1": "name", "a2": "applied"}, - ), - ), - ( - # Intentionally using a dict with more keys than will be resolved. - ( - "SELECT * from t WHERE f1=%(f1)s", - {"f1": "app", "f2": "name"}, - ), - ("SELECT * from t WHERE f1=@a0", {"a0": "app"}), - ), - ( - # No args to replace, we MUST return the original params dict - # since it might be useful to pass to the next user. - ("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}), - ("SELECT * from t WHERE id=10", {"f1": "app", "f2": "name"}), - ), - ( - ( - "SELECT (an.p + %s) AS np FROM an WHERE (an.p + %s) = %s", - (1, 1.0, decimal.Decimal("31")), - ), - ( - "SELECT (an.p + @a0) AS np FROM an WHERE (an.p + @a1) = @a2", - {"a0": 1, "a1": 1.0, "a2": 31.0}, - ), - ), - ] - for ((sql_in, params), sql_want) in cases: - with self.subTest(sql=sql_in): - got_sql, got_named_args = sql_pyformat_args_to_spanner( - sql_in, params - ) - want_sql, want_named_args = sql_want - self.assertEqual(got_sql, want_sql, "SQL does not match") - self.assertEqual( - got_named_args, want_named_args, "Named args do not match" - ) - - def test_sql_pyformat_args_to_spanner_invalid(self): - from google.cloud.spanner_dbapi import exceptions - from google.cloud.spanner_dbapi.parse_utils import ( - sql_pyformat_args_to_spanner, - ) - - cases = [ - ( - "SELECT * from t WHERE f1=%s, f2 = %s, f3=%s, extra=%s", - (10, "abc", "y**$22l3f"), - ) - ] - for sql, params in cases: - with self.subTest(sql=sql): - self.assertRaisesRegex( - exceptions.Error, - "pyformat_args mismatch", - lambda: sql_pyformat_args_to_spanner(sql, params), - ) - - def test_cast_for_spanner(self): - import decimal - - from google.cloud.spanner_dbapi.parse_utils import cast_for_spanner - - value = decimal.Decimal(3) - self.assertEqual(cast_for_spanner(value), float(3.0)) - self.assertEqual(cast_for_spanner(5), 5) - self.assertEqual(cast_for_spanner("string"), "string") - - def test_get_param_types(self): - import datetime - - from google.cloud.spanner_dbapi.parse_utils import DateStr - from google.cloud.spanner_dbapi.parse_utils import TimestampStr - from google.cloud.spanner_dbapi.parse_utils import get_param_types - - params = { - "a1": 10, - "b1": "string", - "c1": 10.39, - "d1": TimestampStr("2005-08-30T01:01:01.000001Z"), - "e1": DateStr("2019-12-05"), - "f1": True, - "g1": datetime.datetime(2011, 9, 1, 13, 20, 30), - "h1": datetime.date(2011, 9, 1), - "i1": b"bytes", - "j1": None, - } - want_types = { - "a1": param_types.INT64, - "b1": param_types.STRING, - "c1": param_types.FLOAT64, - "d1": param_types.TIMESTAMP, - "e1": param_types.DATE, - "f1": param_types.BOOL, - "g1": param_types.TIMESTAMP, - "h1": param_types.DATE, - "i1": param_types.BYTES, - } - got_types = get_param_types(params) - self.assertEqual(got_types, want_types) - - def test_get_param_types_none(self): - from google.cloud.spanner_dbapi.parse_utils import get_param_types - - self.assertEqual(get_param_types(None), None) - - def test_ensure_where_clause(self): - from google.cloud.spanner_dbapi.exceptions import ProgrammingError - from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause - - cases = ( - "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", - "UPDATE T SET A = 1 WHERE C1 = 1 AND C2 = 2", - "UPDATE T SET r=r*0.9 WHERE id IN (SELECT id FROM items WHERE r / w >= 1.3 AND q > 100)", - ) - err_cases = ( - "UPDATE (SELECT * FROM A JOIN c ON ai.id = c.id WHERE cl.ci = 1) SET d=5", - "DELETE * FROM TABLE", - ) - for sql in cases: - with self.subTest(sql=sql): - ensure_where_clause(sql) - - for sql in err_cases: - with self.subTest(sql=sql): - with self.assertRaises(ProgrammingError): - ensure_where_clause(sql) - - def test_escape_name(self): - from google.cloud.spanner_dbapi.parse_utils import escape_name - - cases = ( - ("SELECT", "`SELECT`"), - ("dashed-value", "`dashed-value`"), - ("with space", "`with space`"), - ("name", "name"), - ("", ""), - ) - for name, want in cases: - with self.subTest(name=name): - got = escape_name(name) - self.assertEqual(got, want) diff --git a/tests/unit/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py deleted file mode 100644 index d5baf9d824..0000000000 --- a/tests/unit/spanner_dbapi/test_parser.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import unittest - -from unittest import mock - - -class TestParser(unittest.TestCase): - def test_func(self): - from google.cloud.spanner_dbapi.parser import FUNC - from google.cloud.spanner_dbapi.parser import a_args - from google.cloud.spanner_dbapi.parser import expect - from google.cloud.spanner_dbapi.parser import func - from google.cloud.spanner_dbapi.parser import pyfmt_str - - cases = [ - ("_91())", ")", func("_91", a_args([]))), - ("_a()", "", func("_a", a_args([]))), - ("___()", "", func("___", a_args([]))), - ("abc()", "", func("abc", a_args([]))), - ( - "AF112(%s, LOWER(%s, %s), rand(%s, %s, TAN(%s, %s)))", - "", - func( - "AF112", - a_args( - [ - pyfmt_str, - func("LOWER", a_args([pyfmt_str, pyfmt_str])), - func( - "rand", - a_args( - [ - pyfmt_str, - pyfmt_str, - func( - "TAN", - a_args([pyfmt_str, pyfmt_str]), - ), - ] - ), - ), - ] - ), - ), - ), - ] - - for text, want_unconsumed, want_parsed in cases: - with self.subTest(text=text): - got_unconsumed, got_parsed = expect(text, FUNC) - self.assertEqual(got_parsed, want_parsed) - self.assertEqual(got_unconsumed, want_unconsumed) - - def test_func_fail(self): - from google.cloud.spanner_dbapi.exceptions import ProgrammingError - from google.cloud.spanner_dbapi.parser import FUNC - from google.cloud.spanner_dbapi.parser import expect - - cases = [ - ("", "FUNC: `` does not begin with `a-zA-z` nor a `_`"), - ("91", "FUNC: `91` does not begin with `a-zA-z` nor a `_`"), - ("_91", "supposed to begin with `\\(`"), - ("_91(", "supposed to end with `\\)`"), - ("_.()", "supposed to begin with `\\(`"), - ("_a.b()", "supposed to begin with `\\(`"), - ] - - for text, wantException in cases: - with self.subTest(text=text): - self.assertRaisesRegex( - ProgrammingError, wantException, lambda: expect(text, FUNC) - ) - - def test_func_eq(self): - from google.cloud.spanner_dbapi.parser import func - - func1 = func("func1", None) - func2 = func("func2", None) - self.assertFalse(func1 == object) - self.assertFalse(func1 == func2) - func2.name = func1.name - func1.args = 0 - func2.args = "0" - self.assertFalse(func1 == func2) - func1.args = [0] - func2.args = [0, 0] - self.assertFalse(func1 == func2) - func2.args = func1.args - self.assertTrue(func1 == func2) - - def test_a_args(self): - from google.cloud.spanner_dbapi.parser import ARGS - from google.cloud.spanner_dbapi.parser import a_args - from google.cloud.spanner_dbapi.parser import expect - from google.cloud.spanner_dbapi.parser import func - from google.cloud.spanner_dbapi.parser import pyfmt_str - - cases = [ - ("()", "", a_args([])), - ("(%s)", "", a_args([pyfmt_str])), - ("(%s,)", "", a_args([pyfmt_str])), - ("(%s),", ",", a_args([pyfmt_str])), - ( - "(%s,%s, f1(%s, %s))", - "", - a_args( - [ - pyfmt_str, - pyfmt_str, - func("f1", a_args([pyfmt_str, pyfmt_str])), - ] - ), - ), - ] - - for text, want_unconsumed, want_parsed in cases: - with self.subTest(text=text): - got_unconsumed, got_parsed = expect(text, ARGS) - self.assertEqual(got_parsed, want_parsed) - self.assertEqual(got_unconsumed, want_unconsumed) - - def test_a_args_fail(self): - from google.cloud.spanner_dbapi.exceptions import ProgrammingError - from google.cloud.spanner_dbapi.parser import ARGS - from google.cloud.spanner_dbapi.parser import expect - - cases = [ - ("", "ARGS: supposed to begin with `\\(`"), - ("(", "ARGS: supposed to end with `\\)`"), - (")", "ARGS: supposed to begin with `\\(`"), - ("(%s,%s, f1(%s, %s), %s", "ARGS: supposed to end with `\\)`"), - ] - - for text, wantException in cases: - with self.subTest(text=text): - self.assertRaisesRegex( - ProgrammingError, wantException, lambda: expect(text, ARGS) - ) - - def test_a_args_has_expr(self): - from google.cloud.spanner_dbapi.parser import a_args - - self.assertFalse(a_args([]).has_expr()) - self.assertTrue(a_args([[0]]).has_expr()) - - def test_a_args_eq(self): - from google.cloud.spanner_dbapi.parser import a_args - - a1 = a_args([0]) - self.assertFalse(a1 == object()) - a2 = a_args([0, 0]) - self.assertFalse(a1 == a2) - a1.argv = [0, 1] - self.assertFalse(a1 == a2) - a2.argv = [0, 1] - self.assertTrue(a1 == a2) - - def test_a_args_homogeneous(self): - from google.cloud.spanner_dbapi.parser import a_args - from google.cloud.spanner_dbapi.parser import terminal - - a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) - self.assertTrue(a_obj.homogenous()) - - a_obj = a_args([a_args([[object()]]) for _ in range(10)]) - self.assertFalse(a_obj.homogenous()) - - def test_a_args__is_equal_length(self): - from google.cloud.spanner_dbapi.parser import a_args - - a_obj = a_args([]) - self.assertTrue(a_obj._is_equal_length()) - - def test_values(self): - from google.cloud.spanner_dbapi.parser import a_args - from google.cloud.spanner_dbapi.parser import terminal - from google.cloud.spanner_dbapi.parser import values - - a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) - self.assertEqual(str(values(a_obj)), "VALUES%s" % str(a_obj)) - - def test_expect(self): - from google.cloud.spanner_dbapi.parser import ARGS - from google.cloud.spanner_dbapi.parser import EXPR - from google.cloud.spanner_dbapi.parser import FUNC - from google.cloud.spanner_dbapi.parser import expect - from google.cloud.spanner_dbapi.parser import pyfmt_str - from google.cloud.spanner_dbapi import exceptions - - with self.assertRaises(exceptions.ProgrammingError): - expect(word="", token=ARGS) - with self.assertRaises(exceptions.ProgrammingError): - expect(word="ABC", token=ARGS) - with self.assertRaises(exceptions.ProgrammingError): - expect(word="(", token=ARGS) - - expected = "", pyfmt_str - self.assertEqual(expect("%s", EXPR), expected) - - expected = expect("function()", FUNC) - self.assertEqual(expect("function()", EXPR), expected) - - with self.assertRaises(exceptions.ProgrammingError): - expect(word="", token="ABC") - - def test_expect_values(self): - from google.cloud.spanner_dbapi.parser import VALUES - from google.cloud.spanner_dbapi.parser import a_args - from google.cloud.spanner_dbapi.parser import expect - from google.cloud.spanner_dbapi.parser import func - from google.cloud.spanner_dbapi.parser import pyfmt_str - from google.cloud.spanner_dbapi.parser import values - - cases = [ - ("VALUES ()", "", values([a_args([])])), - ("VALUES", "", values([])), - ("VALUES(%s)", "", values([a_args([pyfmt_str])])), - (" VALUES (%s) ", "", values([a_args([pyfmt_str])])), - ("VALUES(%s, %s)", "", values([a_args([pyfmt_str, pyfmt_str])])), - ( - "VALUES(%s, %s, LOWER(%s, %s))", - "", - values( - [ - a_args( - [ - pyfmt_str, - pyfmt_str, - func("LOWER", a_args([pyfmt_str, pyfmt_str])), - ] - ) - ] - ), - ), - ( - "VALUES (UPPER(%s)), (%s)", - "", - values( - [ - a_args([func("UPPER", a_args([pyfmt_str]))]), - a_args([pyfmt_str]), - ] - ), - ), - ] - - for text, want_unconsumed, want_parsed in cases: - with self.subTest(text=text): - got_unconsumed, got_parsed = expect(text, VALUES) - self.assertEqual(got_parsed, want_parsed) - self.assertEqual(got_unconsumed, want_unconsumed) - - def test_expect_values_fail(self): - from google.cloud.spanner_dbapi.exceptions import ProgrammingError - from google.cloud.spanner_dbapi.parser import VALUES - from google.cloud.spanner_dbapi.parser import expect - - cases = [ - ("", "VALUES: `` does not start with VALUES"), - ( - "VALUES(%s, %s, (%s, %s))", - "FUNC: `\\(%s, %s\\)\\)` does not begin with `a-zA-z` nor a `_`", - ), - ("VALUES(%s),,", "ARGS: supposed to begin with `\\(` in `,`"), - ] - - for text, wantException in cases: - with self.subTest(text=text): - self.assertRaisesRegex( - ProgrammingError, - wantException, - lambda: expect(text, VALUES), - ) - - def test_as_values(self): - from google.cloud.spanner_dbapi.parser import as_values - - values = (1, 2) - with mock.patch( - "google.cloud.spanner_dbapi.parser.parse_values", - return_value=values, - ): - self.assertEqual(as_values(None), values[1]) diff --git a/tests/unit/spanner_dbapi/test_types.py b/tests/unit/spanner_dbapi/test_types.py deleted file mode 100644 index 4246a43e45..0000000000 --- a/tests/unit/spanner_dbapi/test_types.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import unittest - -from time import timezone - - -class TestTypes(unittest.TestCase): - - TICKS = 1572822862.9782631 + timezone # Sun 03 Nov 2019 23:14:22 UTC - - def test__date_from_ticks(self): - import datetime - - from google.cloud.spanner_dbapi import types - - actual = types._date_from_ticks(self.TICKS) - expected = datetime.date(2019, 11, 3) - - self.assertEqual(actual, expected) - - def test__time_from_ticks(self): - import datetime - - from google.cloud.spanner_dbapi import types - - actual = types._time_from_ticks(self.TICKS) - expected = datetime.time(23, 14, 22) - - self.assertEqual(actual, expected) - - def test__timestamp_from_ticks(self): - import datetime - - from google.cloud.spanner_dbapi import types - - actual = types._timestamp_from_ticks(self.TICKS) - expected = datetime.datetime(2019, 11, 3, 23, 14, 22) - - self.assertEqual(actual, expected) - - def test_type_equal(self): - from google.cloud.spanner_dbapi import types - - self.assertEqual(types.BINARY, "TYPE_CODE_UNSPECIFIED") - self.assertEqual(types.BINARY, "BYTES") - self.assertEqual(types.BINARY, "ARRAY") - self.assertEqual(types.BINARY, "STRUCT") - self.assertNotEqual(types.BINARY, "STRING") - - self.assertEqual(types.NUMBER, "BOOL") - self.assertEqual(types.NUMBER, "INT64") - self.assertEqual(types.NUMBER, "FLOAT64") - self.assertEqual(types.NUMBER, "NUMERIC") - self.assertNotEqual(types.NUMBER, "STRING") - - self.assertEqual(types.DATETIME, "TIMESTAMP") - self.assertEqual(types.DATETIME, "DATE") - self.assertNotEqual(types.DATETIME, "STRING") diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py deleted file mode 100644 index 90e1b7cf04..0000000000 --- a/tests/unit/spanner_dbapi/test_utils.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import unittest - - -class TestUtils(unittest.TestCase): - def test_PeekIterator(self): - from google.cloud.spanner_dbapi.utils import PeekIterator - - cases = [ - ("list", [1, 2, 3, 4, 6, 7], [1, 2, 3, 4, 6, 7]), - ("iter_from_list", iter([1, 2, 3, 4, 6, 7]), [1, 2, 3, 4, 6, 7]), - ("tuple", ("a", 12, 0xFF), ["a", 12, 0xFF]), - ("iter_from_tuple", iter(("a", 12, 0xFF)), ["a", 12, 0xFF]), - ("no_args", (), []), - ] - - for name, data_in, expected in cases: - with self.subTest(name=name): - pitr = PeekIterator(data_in) - actual = list(pitr) - self.assertEqual(actual, expected) - - def test_peekIterator_list_rows_converted_to_tuples(self): - from google.cloud.spanner_dbapi.utils import PeekIterator - - # Cloud Spanner returns results in lists e.g. [result]. - # PeekIterator is used by BaseCursor in its fetch* methods. - # This test ensures that anything passed into PeekIterator - # will be returned as a tuple. - pit = PeekIterator([["a"], ["b"], ["c"], ["d"], ["e"]]) - got = list(pit) - want = [("a",), ("b",), ("c",), ("d",), ("e",)] - self.assertEqual( - got, want, "Rows of type list must be returned as tuples" - ) - - seventeen = PeekIterator([[17]]) - self.assertEqual(list(seventeen), [(17,)]) - - pit = PeekIterator([["%", "%d"]]) - self.assertEqual(next(pit), ("%", "%d")) - - pit = PeekIterator([("Clark", "Kent")]) - self.assertEqual(next(pit), ("Clark", "Kent")) - - def test_peekIterator_nonlist_rows_unconverted(self): - from google.cloud.spanner_dbapi.utils import PeekIterator - - pi = PeekIterator(["a", "b", "c", "d", "e"]) - got = list(pi) - want = ["a", "b", "c", "d", "e"] - self.assertEqual(got, want, "Values should be returned unchanged") - - def test_backtick_unicode(self): - from google.cloud.spanner_dbapi.utils import backtick_unicode - - cases = [ - ("SELECT (1) as foo WHERE 1=1", "SELECT (1) as foo WHERE 1=1"), - ("SELECT (1) as föö", "SELECT (1) as `föö`"), - ("SELECT (1) as `föö`", "SELECT (1) as `föö`"), - ("SELECT (1) as `föö` `umläut", "SELECT (1) as `föö` `umläut"), - ("SELECT (1) as `föö", "SELECT (1) as `föö"), - ] - for sql, want in cases: - with self.subTest(sql=sql): - got = backtick_unicode(sql) - self.assertEqual(got, want)