From d8453c7e458b0b476b91785d32ba234e333a4b9f Mon Sep 17 00:00:00 2001 From: MF2199 <38331387+mf2199@users.noreply.github.com> Date: Wed, 16 Dec 2020 23:51:40 -0500 Subject: [PATCH] fix: DatabaseWrapper method impl and potential bugfix (#545) --- django_spanner/base.py | 24 +++--- noxfile.py | 18 ++++- tests/unit/django_spanner/test_base.py | 108 +++++++++++++++++++++++++ 3 files changed, 138 insertions(+), 12 deletions(-) create mode 100644 tests/unit/django_spanner/test_base.py diff --git a/django_spanner/base.py b/django_spanner/base.py index b6620b7dfb..35743d0385 100644 --- a/django_spanner/base.py +++ b/django_spanner/base.py @@ -4,8 +4,10 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd +import google.cloud.spanner_v1 as spanner + from django.db.backends.base.base import BaseDatabaseWrapper -from google.cloud import spanner_dbapi as Database, spanner_v1 as spanner +from google.cloud import spanner_dbapi from .client import DatabaseClient from .creation import DatabaseCreation @@ -81,6 +83,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): # special characters for REGEXP_CONTAINS operators (e.g. \, *, _) must be # escaped on database side. pattern_esc = r'REPLACE(REPLACE(REPLACE({}, "\\", "\\\\"), "%%", r"\%%"), "_", r"\_")' + # These are all no-ops in favor of using REGEXP_CONTAINS in the customized # lookups. pattern_ops = { @@ -92,7 +95,7 @@ class DatabaseWrapper(BaseDatabaseWrapper): "iendswith": "", } - Database = Database + Database = spanner_dbapi SchemaEditorClass = DatabaseSchemaEditor creation_class = DatabaseCreation features_class = DatabaseFeatures @@ -131,7 +134,7 @@ def get_connection_params(self): **self.settings_dict["OPTIONS"], } - def get_new_connection(self, **conn_params): + def get_new_connection(self, conn_params): """Create a new connection with corresponding connection parameters. :type conn_params: list @@ -145,11 +148,13 @@ def get_new_connection(self, **conn_params): :raises: :class:`ValueError` in case the given instance/database doesn't exist. """ - return Database.connect(**conn_params) + return self.Database.connect(**conn_params) def init_connection_state(self): """Initialize the state of the existing connection.""" - pass + self.connection.close() + database = self.connection.database + self.connection.__init__(self.instance, database) def create_cursor(self, name=None): """Create a new Database cursor. @@ -177,12 +182,13 @@ def is_usable(self): :rtype: bool :returns: True if the connection is open, otherwise False. """ - if self.connection is None: + if self.connection is None or self.connection.is_closed: return False + try: # Use a cursor directly, bypassing Django's utilities. self.connection.cursor().execute("SELECT 1") - except Database.Error: + except self.Database.Error: return False - else: - return True + + return True diff --git a/noxfile.py b/noxfile.py index 53217c5e26..2c1edbe573 100644 --- a/noxfile.py +++ b/noxfile.py @@ -9,10 +9,10 @@ from __future__ import absolute_import -import nox import os import shutil +import nox BLACK_VERSION = "black==19.10b0" BLACK_PATHS = [ @@ -60,12 +60,24 @@ def lint_setup_py(session): def default(session): # Install all test dependencies, then install this package in-place. - session.install("mock", "pytest", "pytest-cov") + session.install( + "django~=2.2", "mock", "mock-import", "pytest", "pytest-cov" + ) session.install("-e", ".") # Run py.test against the unit tests. session.run( - "py.test", "--quiet", os.path.join("tests", "unit"), *session.posargs + "py.test", + "--quiet", + "--cov=django_spanner", + "--cov=google.cloud", + "--cov=tests.unit", + "--cov-append", + "--cov-config=.coveragerc", + "--cov-report=", + "--cov-fail-under=60", + os.path.join("tests", "unit"), + *session.posargs ) diff --git a/tests/unit/django_spanner/test_base.py b/tests/unit/django_spanner/test_base.py new file mode 100644 index 0000000000..c45cd1380d --- /dev/null +++ b/tests/unit/django_spanner/test_base.py @@ -0,0 +1,108 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import sys +import unittest + +from mock_import import mock_import +from unittest import mock + + +@mock_import() +@unittest.skipIf(sys.version_info < (3, 6), reason="Skipping Python 3.5") +class TestBase(unittest.TestCase): + PROJECT = "project" + INSTANCE_ID = "instance_id" + DATABASE_ID = "database_id" + USER_AGENT = "django_spanner/2.2.0a1" + OPTIONS = {"option": "dummy"} + + settings_dict = { + "PROJECT": PROJECT, + "INSTANCE": INSTANCE_ID, + "NAME": DATABASE_ID, + "user_agent": USER_AGENT, + "OPTIONS": OPTIONS, + } + + def _get_target_class(self): + from django_spanner.base import DatabaseWrapper + + return DatabaseWrapper + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_property_instance(self): + settings_dict = {"INSTANCE": "instance"} + db_wrapper = self._make_one(settings_dict=settings_dict) + + with mock.patch("django_spanner.base.spanner") as mock_spanner: + mock_spanner.Client = mock_client = mock.MagicMock() + mock_client().instance = mock_instance = mock.MagicMock() + _ = db_wrapper.instance + mock_instance.assert_called_once_with(settings_dict["INSTANCE"]) + + def test_property__nodb_connection(self): + db_wrapper = self._make_one(None) + with self.assertRaises(NotImplementedError): + db_wrapper._nodb_connection() + + def test_get_connection_params(self): + db_wrapper = self._make_one(self.settings_dict) + params = db_wrapper.get_connection_params() + + self.assertEqual(params["project"], self.PROJECT) + self.assertEqual(params["instance_id"], self.INSTANCE_ID) + self.assertEqual(params["database_id"], self.DATABASE_ID) + self.assertEqual(params["user_agent"], self.USER_AGENT) + self.assertEqual(params["option"], self.OPTIONS["option"]) + + def test_get_new_connection(self): + db_wrapper = self._make_one(self.settings_dict) + db_wrapper.Database = mock_database = mock.MagicMock() + mock_database.connect = mock_connect = mock.MagicMock() + conn_params = {"test_param": "dummy"} + db_wrapper.get_new_connection(conn_params) + mock_connect.assert_called_once_with(**conn_params) + + def test_init_connection_state(self): + db_wrapper = self._make_one(self.settings_dict) + db_wrapper.connection = mock_connection = mock.MagicMock() + mock_connection.close = mock_close = mock.MagicMock() + db_wrapper.init_connection_state() + mock_close.assert_called_once_with() + + def test_create_cursor(self): + db_wrapper = self._make_one(self.settings_dict) + db_wrapper.connection = mock_connection = mock.MagicMock() + mock_connection.cursor = mock_cursor = mock.MagicMock() + db_wrapper.create_cursor() + mock_cursor.assert_called_once_with() + + def test__set_autocommit(self): + db_wrapper = self._make_one(self.settings_dict) + db_wrapper.connection = mock_connection = mock.MagicMock() + mock_connection.autocommit = False + db_wrapper._set_autocommit(True) + self.assertEqual(mock_connection.autocommit, True) + + def test_is_usable(self): + from google.cloud.spanner_dbapi.exceptions import Error + + db_wrapper = self._make_one(self.settings_dict) + db_wrapper.connection = None + self.assertFalse(db_wrapper.is_usable()) + + db_wrapper.connection = mock_connection = mock.MagicMock() + mock_connection.is_closed = True + self.assertFalse(db_wrapper.is_usable()) + + mock_connection.is_closed = False + self.assertTrue(db_wrapper.is_usable()) + + mock_connection.cursor = mock.MagicMock(side_effect=Error) + self.assertFalse(db_wrapper.is_usable())