Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: missing method implementation and potential bug in the DatabaseWrapper class #545

Merged
merged 8 commits into from Dec 17, 2020
22 changes: 14 additions & 8 deletions django_spanner/base.py
Expand Up @@ -4,8 +4,10 @@
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import google.cloud.spanner_v1 as spanner

from django.db.backends.base.base import BaseDatabaseWrapper
from google.cloud import spanner_dbapi as Database, spanner_v1 as spanner
from google.cloud import spanner_dbapi

from .client import DatabaseClient
from .creation import DatabaseCreation
Expand Down Expand Up @@ -81,6 +83,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
# special characters for REGEXP_CONTAINS operators (e.g. \, *, _) must be
# escaped on database side.
pattern_esc = r'REPLACE(REPLACE(REPLACE({}, "\\", "\\\\"), "%%", r"\%%"), "_", r"\_")'

# These are all no-ops in favor of using REGEXP_CONTAINS in the customized
# lookups.
pattern_ops = {
Expand All @@ -92,7 +95,7 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"iendswith": "",
}

Database = Database
Database = spanner_dbapi
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW this odd "import x as Database; Database = Database" pattern seems to come from sqlite (https://github.com/django/django/blob/e893c0ad8b0b5b0a1e5be3345c287044868effc4/django/db/backends/sqlite3/base.py#L156), and is copied in all the other built in backends.

SchemaEditorClass = DatabaseSchemaEditor
creation_class = DatabaseCreation
features_class = DatabaseFeatures
Expand Down Expand Up @@ -121,10 +124,12 @@ def get_connection_params(self):
}

def get_new_connection(self, conn_params):
return Database.connect(**conn_params)
return self.Database.connect(**conn_params)

def init_connection_state(self):
pass
self.connection.close()
c24t marked this conversation as resolved.
Show resolved Hide resolved
database = self.connection.database
self.connection.__init__(self.instance, database)

def create_cursor(self, name=None):
return self.connection.cursor()
Expand All @@ -134,12 +139,13 @@ def _set_autocommit(self, autocommit):
self.connection.autocommit = autocommit

def is_usable(self):
if self.connection is None:
if self.connection is None or self.connection.is_closed:
return False

try:
# Use a cursor directly, bypassing Django's utilities.
self.connection.cursor().execute("SELECT 1")
except Database.Error:
except self.Database.Error:
return False
else:
return True

return True
6 changes: 3 additions & 3 deletions noxfile.py
Expand Up @@ -62,20 +62,20 @@ def lint_setup_py(session):

def default(session):
# Install all test dependencies, then install this package in-place.
session.install("mock", "pytest", "pytest-cov")
session.install("mock", "pytest", "pytest-cov", "django", "mock-import")
session.install("-e", ".")

# Run py.test against the unit tests.
session.run(
"py.test",
"--quiet",
# "--cov=django_spanner",
"--cov=django_spanner",
"--cov=google.cloud",
"--cov=tests.unit",
"--cov-append",
"--cov-config=.coveragerc",
"--cov-report=",
"--cov-fail-under=90",
"--cov-fail-under=60",
os.path.join("tests", "unit"),
*session.posargs
)
Expand Down
108 changes: 108 additions & 0 deletions tests/unit/django_spanner/test_base.py
@@ -0,0 +1,108 @@
# Copyright 2020 Google LLC
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

import sys
import unittest

from mock_import import mock_import
from unittest import mock


@mock_import()
@unittest.skipIf(sys.version_info < (3, 6), reason="Skipping Python 3.5")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can safely drop 3.5 in tests since we already decided not to support it elsewhere.

class TestBase(unittest.TestCase):
PROJECT = "project"
INSTANCE_ID = "instance_id"
DATABASE_ID = "database_id"
USER_AGENT = "django_spanner/2.2.0a1"
OPTIONS = {"option": "dummy"}

settings_dict = {
"PROJECT": PROJECT,
"INSTANCE": INSTANCE_ID,
"NAME": DATABASE_ID,
"user_agent": USER_AGENT,
"OPTIONS": OPTIONS,
}

def _get_target_class(self):
from django_spanner.base import DatabaseWrapper

return DatabaseWrapper

def _make_one(self, *args, **kwargs):
return self._get_target_class()(*args, **kwargs)

def test_property_instance(self):
settings_dict = {"INSTANCE": "instance"}
db_wrapper = self._make_one(settings_dict=settings_dict)

with mock.patch("django_spanner.base.spanner") as mock_spanner:
mock_spanner.Client = mock_client = mock.MagicMock()
mock_client().instance = mock_instance = mock.MagicMock()
_ = db_wrapper.instance
mock_instance.assert_called_once_with(settings_dict["INSTANCE"])

def test_property__nodb_connection(self):
db_wrapper = self._make_one(None)
with self.assertRaises(NotImplementedError):
db_wrapper._nodb_connection()

def test_get_connection_params(self):
db_wrapper = self._make_one(self.settings_dict)
params = db_wrapper.get_connection_params()

self.assertEqual(params["project"], self.PROJECT)
self.assertEqual(params["instance_id"], self.INSTANCE_ID)
self.assertEqual(params["database_id"], self.DATABASE_ID)
self.assertEqual(params["user_agent"], self.USER_AGENT)
self.assertEqual(params["option"], self.OPTIONS["option"])

def test_get_new_connection(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.Database = mock_database = mock.MagicMock()
mock_database.connect = mock_connect = mock.MagicMock()
conn_params = {"test_param": "dummy"}
db_wrapper.get_new_connection(conn_params)
mock_connect.assert_called_once_with(**conn_params)

def test_init_connection_state(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.close = mock_close = mock.MagicMock()
db_wrapper.init_connection_state()
mock_close.assert_called_once_with()

def test_create_cursor(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.cursor = mock_cursor = mock.MagicMock()
db_wrapper.create_cursor()
mock_cursor.assert_called_once_with()

def test__set_autocommit(self):
db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.autocommit = False
db_wrapper._set_autocommit(True)
self.assertEqual(mock_connection.autocommit, True)

def test_is_usable(self):
from google.cloud.spanner_dbapi.exceptions import Error

db_wrapper = self._make_one(self.settings_dict)
db_wrapper.connection = None
self.assertFalse(db_wrapper.is_usable())

db_wrapper.connection = mock_connection = mock.MagicMock()
mock_connection.is_closed = True
self.assertFalse(db_wrapper.is_usable())

mock_connection.is_closed = False
self.assertTrue(db_wrapper.is_usable())

mock_connection.cursor = mock.MagicMock(side_effect=Error)
self.assertFalse(db_wrapper.is_usable())
2 changes: 1 addition & 1 deletion version.py
Expand Up @@ -4,4 +4,4 @@
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

__version__ = "2.2.0a1"
__version__ = "3.1.0a1"
c24t marked this conversation as resolved.
Show resolved Hide resolved