From 4fedcf18a235c226d062ce7e61070477bfd3a107 Mon Sep 17 00:00:00 2001 From: Gurov Ilya Date: Thu, 27 Aug 2020 20:00:20 +0300 Subject: [PATCH] feat: refactor connect() function, cover it with unit tests (#462) --- django_spanner/base.py | 116 +++++++++++----------- setup.cfg | 4 +- spanner_dbapi/__init__.py | 149 ++++++++++++++++++---------- tests/spanner_dbapi/test_connect.py | 97 ++++++++++++++++++ tox.ini | 0 5 files changed, 253 insertions(+), 113 deletions(-) create mode 100644 tests/spanner_dbapi/test_connect.py create mode 100644 tox.ini diff --git a/django_spanner/base.py b/django_spanner/base.py index 70f5b1ba39..f5d2c835d5 100644 --- a/django_spanner/base.py +++ b/django_spanner/base.py @@ -18,62 +18,62 @@ class DatabaseWrapper(BaseDatabaseWrapper): - vendor = 'spanner' - display_name = 'Cloud Spanner' + vendor = "spanner" + display_name = "Cloud Spanner" # Mapping of Field objects to their column types. # https://cloud.google.com/spanner/docs/data-types#date-type data_types = { - 'AutoField': 'INT64', - 'BigAutoField': 'INT64', - 'BinaryField': 'BYTES(MAX)', - 'BooleanField': 'BOOL', - 'CharField': 'STRING(%(max_length)s)', - 'DateField': 'DATE', - 'DateTimeField': 'TIMESTAMP', - 'DecimalField': 'FLOAT64', - 'DurationField': 'INT64', - 'EmailField': 'STRING(%(max_length)s)', - 'FileField': 'STRING(%(max_length)s)', - 'FilePathField': 'STRING(%(max_length)s)', - 'FloatField': 'FLOAT64', - 'IntegerField': 'INT64', - 'BigIntegerField': 'INT64', - 'IPAddressField': 'STRING(15)', - 'GenericIPAddressField': 'STRING(39)', - 'NullBooleanField': 'BOOL', - 'OneToOneField': 'INT64', - 'PositiveIntegerField': 'INT64', - 'PositiveSmallIntegerField': 'INT64', - 'SlugField': 'STRING(%(max_length)s)', - 'SmallAutoField': 'INT64', - 'SmallIntegerField': 'INT64', - 'TextField': 'STRING(MAX)', - 'TimeField': 'TIMESTAMP', - 'UUIDField': 'STRING(32)', + "AutoField": "INT64", + "BigAutoField": "INT64", + "BinaryField": "BYTES(MAX)", + "BooleanField": "BOOL", + "CharField": "STRING(%(max_length)s)", + "DateField": "DATE", + "DateTimeField": "TIMESTAMP", + "DecimalField": "FLOAT64", + "DurationField": "INT64", + "EmailField": "STRING(%(max_length)s)", + "FileField": "STRING(%(max_length)s)", + "FilePathField": "STRING(%(max_length)s)", + "FloatField": "FLOAT64", + "IntegerField": "INT64", + "BigIntegerField": "INT64", + "IPAddressField": "STRING(15)", + "GenericIPAddressField": "STRING(39)", + "NullBooleanField": "BOOL", + "OneToOneField": "INT64", + "PositiveIntegerField": "INT64", + "PositiveSmallIntegerField": "INT64", + "SlugField": "STRING(%(max_length)s)", + "SmallAutoField": "INT64", + "SmallIntegerField": "INT64", + "TextField": "STRING(MAX)", + "TimeField": "TIMESTAMP", + "UUIDField": "STRING(32)", } operators = { - 'exact': '= %s', - 'iexact': 'REGEXP_CONTAINS(%s, %%%%s)', + "exact": "= %s", + "iexact": "REGEXP_CONTAINS(%s, %%%%s)", # contains uses REGEXP_CONTAINS instead of LIKE to allow # DatabaseOperations.prep_for_like_query() to do regular expression # escaping. prep_for_like_query() is called for all the lookups that # use REGEXP_CONTAINS except regex/iregex (see # django.db.models.lookups.PatternLookup). - 'contains': 'REGEXP_CONTAINS(%s, %%%%s)', - 'icontains': 'REGEXP_CONTAINS(%s, %%%%s)', - 'gt': '> %s', - 'gte': '>= %s', - 'lt': '< %s', - 'lte': '<= %s', + "contains": "REGEXP_CONTAINS(%s, %%%%s)", + "icontains": "REGEXP_CONTAINS(%s, %%%%s)", + "gt": "> %s", + "gte": ">= %s", + "lt": "< %s", + "lte": "<= %s", # Using REGEXP_CONTAINS instead of STARTS_WITH and ENDS_WITH for the # same reasoning as described above for 'contains'. - 'startswith': 'REGEXP_CONTAINS(%s, %%%%s)', - 'endswith': 'REGEXP_CONTAINS(%s, %%%%s)', - 'istartswith': 'REGEXP_CONTAINS(%s, %%%%s)', - 'iendswith': 'REGEXP_CONTAINS(%s, %%%%s)', - 'regex': 'REGEXP_CONTAINS(%s, %%%%s)', - 'iregex': 'REGEXP_CONTAINS(%s, %%%%s)', + "startswith": "REGEXP_CONTAINS(%s, %%%%s)", + "endswith": "REGEXP_CONTAINS(%s, %%%%s)", + "istartswith": "REGEXP_CONTAINS(%s, %%%%s)", + "iendswith": "REGEXP_CONTAINS(%s, %%%%s)", + "regex": "REGEXP_CONTAINS(%s, %%%%s)", + "iregex": "REGEXP_CONTAINS(%s, %%%%s)", } # pattern_esc is used to generate SQL pattern lookup clauses when the @@ -81,16 +81,18 @@ class DatabaseWrapper(BaseDatabaseWrapper): # expression or the result of a bilateral transformation). In those cases, # special characters for REGEXP_CONTAINS operators (e.g. \, *, _) must be # escaped on database side. - pattern_esc = r'REPLACE(REPLACE(REPLACE({}, "\\", "\\\\"), "%%", r"\%%"), "_", r"\_")' + pattern_esc = ( + r'REPLACE(REPLACE(REPLACE({}, "\\", "\\\\"), "%%", r"\%%"), "_", r"\_")' + ) # These are all no-ops in favor of using REGEXP_CONTAINS in the customized # lookups. pattern_ops = { - 'contains': '', - 'icontains': '', - 'startswith': '', - 'istartswith': '', - 'endswith': '', - 'iendswith': '', + "contains": "", + "icontains": "", + "startswith": "", + "istartswith": "", + "endswith": "", + "iendswith": "", } Database = Database @@ -104,7 +106,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): @property def instance(self): - return spanner.Client().instance(self.settings_dict['INSTANCE']) + return spanner.Client().instance(self.settings_dict["INSTANCE"]) @property def _nodb_connection(self): @@ -112,11 +114,11 @@ def _nodb_connection(self): def get_connection_params(self): return { - 'project': self.settings_dict['PROJECT'], - 'instance': self.settings_dict['INSTANCE'], - 'database': self.settings_dict['NAME'], - 'user_agent': 'django_spanner/0.0.1', - **self.settings_dict['OPTIONS'], + "project": self.settings_dict["PROJECT"], + "instance_id": self.settings_dict["INSTANCE"], + "database_id": self.settings_dict["NAME"], + "user_agent": "django_spanner/0.0.1", + **self.settings_dict["OPTIONS"], } def get_new_connection(self, conn_params): @@ -137,7 +139,7 @@ def is_usable(self): return False try: # Use a cursor directly, bypassing Django's utilities. - self.connection.cursor().execute('SELECT 1') + self.connection.cursor().execute("SELECT 1") except Database.Error: return False else: diff --git a/setup.cfg b/setup.cfg index f9e8dff043..43c26175ef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,8 +2,10 @@ max-line-length = 119 [isort] +use_parentheses=True combine_as_imports = true default_section = THIRDPARTY include_trailing_comma = true +force_grid_wrap=0 line_length = 79 -multi_line_output = 5 +multi_line_output = 3 \ No newline at end of file diff --git a/spanner_dbapi/__init__.py b/spanner_dbapi/__init__.py index f5d349a655..58037106ca 100644 --- a/spanner_dbapi/__init__.py +++ b/spanner_dbapi/__init__.py @@ -4,83 +4,122 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from google.cloud import spanner_v1 as spanner +"""Connection-based DB API for Cloud Spanner.""" + +from google.cloud import spanner_v1 from .connection import Connection -# These need to be included in the top-level package for PEP-0249 DB API v2. from .exceptions import ( - DatabaseError, DataError, Error, IntegrityError, InterfaceError, - InternalError, NotSupportedError, OperationalError, ProgrammingError, + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, Warning, ) from .parse_utils import get_param_types from .types import ( - BINARY, DATETIME, NUMBER, ROWID, STRING, Binary, Date, DateFromTicks, Time, - TimeFromTicks, Timestamp, TimestampFromTicks, + BINARY, + DATETIME, + NUMBER, + ROWID, + STRING, + Binary, + Date, + DateFromTicks, + Time, + TimeFromTicks, + Timestamp, + TimestampFromTicks, ) from .version import google_client_info -# Globals that MUST be defined ### -apilevel = "2.0" # Implements the Python Database API specification 2.0 version. -# We accept arguments in the format '%s' aka ANSI C print codes. -# as per https://www.python.org/dev/peps/pep-0249/#paramstyle -paramstyle = 'format' -# Threads may share the module but not connections. This is a paranoid threadsafety level, -# but it is necessary for starters to use when debugging failures. Eventually once transactions -# are working properly, we'll update the threadsafety level. +apilevel = "2.0" # supports DP-API 2.0 level. +paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s. + +# Threads may share the module, but not connections. This is a paranoid threadsafety +# level, but it is necessary for starters to use when debugging failures. +# Eventually once transactions are working properly, we'll update the +# threadsafety level. threadsafety = 1 -def connect(project=None, instance=None, database=None, credentials_uri=None, user_agent=None): +def connect(instance_id, database_id, project=None, credentials=None, user_agent=None): """ - Connect to Cloud Spanner. + Create a connection to Cloud Spanner database. - Args: - project: The id of a project that already exists. - instance: The id of an instance that already exists. - database: The name of a database that already exists. - credentials_uri: An optional string specifying where to retrieve the service - account JSON for the credentials to connect to Cloud Spanner. + :type instance_id: :class:`str` + :param instance_id: ID of the instance to connect to. - Returns: - The Connection object associated to the Cloud Spanner instance. + :type database_id: :class:`str` + :param database_id: The name of the database to connect to. - Raises: - Error if it encounters any unexpected inputs. - """ - if not project: - raise Error("'project' is required.") - if not instance: - raise Error("'instance' is required.") - if not database: - raise Error("'database' is required.") + :type project: :class:`str` + :param project: (Optional) The ID of the project which owns the + instances, tables and data. If not provided, will + attempt to determine from the environment. - client_kwargs = { - 'project': project, - 'client_info': google_client_info(user_agent), - } - if credentials_uri: - client = spanner.Client.from_service_account_json(credentials_uri, **client_kwargs) - else: - client = spanner.Client(**client_kwargs) + :type credentials: :class:`google.auth.credentials.Credentials` + :param credentials: (Optional) The authorization credentials to attach to requests. + These credentials identify this application to the service. + If none are specified, the client will attempt to ascertain + the credentials from the environment. + + :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` + :returns: Connection object associated with the given Cloud Spanner resource. + + :raises: :class:`ValueError` in case of given instance/database + doesn't exist. + """ + client = spanner_v1.Client( + project=project, + credentials=credentials, + client_info=google_client_info(user_agent), + ) - client_instance = client.instance(instance) - if not client_instance.exists(): - raise ProgrammingError("instance '%s' does not exist." % instance) + instance = client.instance(instance_id) + if not instance.exists(): + raise ValueError("instance '%s' does not exist." % instance_id) - db = client_instance.database(database, pool=spanner.pool.BurstyPool()) - if not db.exists(): - raise ProgrammingError("database '%s' does not exist." % database) + database = instance.database(database_id, pool=spanner_v1.pool.BurstyPool()) + if not database.exists(): + raise ValueError("database '%s' does not exist." % database_id) - return Connection(db) + return Connection(database) __all__ = [ - 'DatabaseError', 'DataError', 'Error', 'IntegrityError', 'InterfaceError', - 'InternalError', 'NotSupportedError', 'OperationalError', 'ProgrammingError', - 'Warning', 'DEFAULT_USER_AGENT', 'apilevel', 'connect', 'paramstyle', 'threadsafety', - 'get_param_types', - 'Binary', 'Date', 'DateFromTicks', 'Time', 'TimeFromTicks', 'Timestamp', - 'TimestampFromTicks', - 'BINARY', 'STRING', 'NUMBER', 'DATETIME', 'ROWID', 'TimestampStr', + "DatabaseError", + "DataError", + "Error", + "IntegrityError", + "InterfaceError", + "InternalError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "Warning", + "DEFAULT_USER_AGENT", + "apilevel", + "connect", + "paramstyle", + "threadsafety", + "get_param_types", + "Binary", + "Date", + "DateFromTicks", + "Time", + "TimeFromTicks", + "Timestamp", + "TimestampFromTicks", + "BINARY", + "STRING", + "NUMBER", + "DATETIME", + "ROWID", + "TimestampStr", ] diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py new file mode 100644 index 0000000000..64eaf8cd7d --- /dev/null +++ b/tests/spanner_dbapi/test_connect.py @@ -0,0 +1,97 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""connect() module function unit tests.""" + +import unittest +from unittest import mock + +import google.auth.credentials +from google.api_core.gapic_v1.client_info import ClientInfo +from spanner_dbapi import connect, Connection + + +def _make_credentials(): + class _CredentialsWithScopes( + google.auth.credentials.Credentials, google.auth.credentials.Scoped + ): + pass + + return mock.Mock(spec=_CredentialsWithScopes) + + +class Test_connect(unittest.TestCase): + def test_connect(self): + PROJECT = "test-project" + USER_AGENT = "user-agent" + CREDENTIALS = _make_credentials() + CLIENT_INFO = ClientInfo(user_agent=USER_AGENT) + + with mock.patch("spanner_dbapi.spanner_v1.Client") as client_mock: + with mock.patch( + "spanner_dbapi.google_client_info", return_value=CLIENT_INFO + ) as client_info_mock: + + connection = connect( + "test-instance", "test-database", PROJECT, CREDENTIALS, USER_AGENT + ) + + self.assertIsInstance(connection, Connection) + client_info_mock.assert_called_once_with(USER_AGENT) + + client_mock.assert_called_once_with( + project=PROJECT, credentials=CREDENTIALS, client_info=CLIENT_INFO + ) + + def test_instance_not_found(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=False + ) as exists_mock: + + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + exists_mock.assert_called_once_with() + + def test_database_not_found(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=False + ) as exists_mock: + + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + exists_mock.assert_called_once_with() + + def test_connect_instance_id(self): + INSTANCE = "test-instance" + + with mock.patch( + "google.cloud.spanner_v1.client.Client.instance" + ) as instance_mock: + connection = connect(INSTANCE, "test-database") + + instance_mock.assert_called_once_with(INSTANCE) + + self.assertIsInstance(connection, Connection) + + def test_connect_database_id(self): + DATABASE = "test-database" + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True + ): + connection = connect("test-instance", DATABASE) + + database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) + + self.assertIsInstance(connection, Connection) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000000..e69de29bb2