Skip to content

Commit

Permalink
fix(db_api): move connection validation into a separate method (#543)
Browse files Browse the repository at this point in the history
Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>
  • Loading branch information
Ilya Gurov and larkee committed Sep 7, 2021
1 parent 23b1600 commit 237ae41
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 79 deletions.
38 changes: 26 additions & 12 deletions google/cloud/spanner_dbapi/connection.py
Expand Up @@ -28,7 +28,7 @@
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi.exceptions import InterfaceError
from google.cloud.spanner_dbapi.exceptions import InterfaceError, OperationalError
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
from google.cloud.spanner_dbapi.version import PY_VERSION

Expand Down Expand Up @@ -349,6 +349,30 @@ def run_statement(self, statement, retried=False):
ResultsChecksum() if retried else statement.checksum,
)

def validate(self):
"""
Execute a minimal request to check if the connection
is valid and the related database is reachable.
Raise an exception in case if the connection is closed,
invalid, target database is not found, or the request result
is incorrect.
:raises: :class:`InterfaceError`: if this connection is closed.
:raises: :class:`OperationalError`: if the request result is incorrect.
:raises: :class:`google.cloud.exceptions.NotFound`: if the linked instance
or database doesn't exist.
"""
self._raise_if_closed()

with self.database.snapshot() as snapshot:
result = list(snapshot.execute_sql("SELECT 1"))
if result != [[1]]:
raise OperationalError(
"The checking query (SELECT 1) returned an unexpected result: %s. "
"Expected: [[1]]" % result
)

def __enter__(self):
return self

Expand Down Expand Up @@ -399,9 +423,6 @@ def connect(
:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Google Cloud Spanner
resource.
:raises: :class:`ValueError` in case of given instance/database
doesn't exist.
"""

client_info = ClientInfo(
Expand All @@ -418,14 +439,7 @@ def connect(
)

instance = client.instance(instance_id)
if not instance.exists():
raise ValueError("instance '%s' does not exist." % instance_id)

database = instance.database(database_id, pool=pool)
if not database.exists():
raise ValueError("database '%s' does not exist." % database_id)

conn = Connection(instance, database)
conn = Connection(instance, instance.database(database_id, pool=pool))
if pool is not None:
conn._own_pool = False

Expand Down
7 changes: 7 additions & 0 deletions tests/system/test_dbapi.py
Expand Up @@ -350,3 +350,10 @@ def test_DDL_commit(shared_instance, dbapi_database):

cur.execute("DROP TABLE Singers")
conn.commit()


def test_ping(shared_instance, dbapi_database):
"""Check connection validation method."""
conn = Connection(shared_instance, dbapi_database)
conn.validate()
conn.close()
25 changes: 0 additions & 25 deletions tests/unit/spanner_dbapi/test_connect.py
Expand Up @@ -88,31 +88,6 @@ def test_w_explicit(self, mock_client):
self.assertIs(connection.database, database)
instance.database.assert_called_once_with(DATABASE, pool=pool)

def test_w_instance_not_found(self, mock_client):
from google.cloud.spanner_dbapi import connect

client = mock_client.return_value
instance = client.instance.return_value
instance.exists.return_value = False

with self.assertRaises(ValueError):
connect(INSTANCE, DATABASE)

instance.exists.assert_called_once_with()

def test_w_database_not_found(self, mock_client):
from google.cloud.spanner_dbapi import connect

client = mock_client.return_value
instance = client.instance.return_value
database = instance.database.return_value
database.exists.return_value = False

with self.assertRaises(ValueError):
connect(INSTANCE, DATABASE)

database.exists.assert_called_once_with()

def test_w_credential_file_path(self, mock_client):
from google.cloud.spanner_dbapi import connect
from google.cloud.spanner_dbapi import Connection
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/spanner_dbapi/test_connection.py
Expand Up @@ -624,3 +624,80 @@ def test_retry_transaction_w_empty_response(self):
compare_mock.assert_called_with(checksum, retried_checkum)

run_mock.assert_called_with(statement, retried=True)

def test_validate_ok(self):
def exit_func(self, exc_type, exc_value, traceback):
pass

connection = self._make_connection()

# mock snapshot context manager
snapshot_obj = mock.Mock()
snapshot_obj.execute_sql = mock.Mock(return_value=[[1]])

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method

connection.validate()
snapshot_obj.execute_sql.assert_called_once_with("SELECT 1")

def test_validate_fail(self):
from google.cloud.spanner_dbapi.exceptions import OperationalError

def exit_func(self, exc_type, exc_value, traceback):
pass

connection = self._make_connection()

# mock snapshot context manager
snapshot_obj = mock.Mock()
snapshot_obj.execute_sql = mock.Mock(return_value=[[3]])

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method

with self.assertRaises(OperationalError):
connection.validate()

snapshot_obj.execute_sql.assert_called_once_with("SELECT 1")

def test_validate_error(self):
from google.cloud.exceptions import NotFound

def exit_func(self, exc_type, exc_value, traceback):
pass

connection = self._make_connection()

# mock snapshot context manager
snapshot_obj = mock.Mock()
snapshot_obj.execute_sql = mock.Mock(side_effect=NotFound("Not found"))

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method

with self.assertRaises(NotFound):
connection.validate()

snapshot_obj.execute_sql.assert_called_once_with("SELECT 1")

def test_validate_closed(self):
from google.cloud.spanner_dbapi.exceptions import InterfaceError

connection = self._make_connection()
connection.close()

with self.assertRaises(InterfaceError):
connection.validate()
48 changes: 6 additions & 42 deletions tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -332,13 +332,7 @@ def test_executemany_delete_batch_autocommit(self):

sql = "DELETE FROM table WHERE col1 = %s"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True
transaction = self._transaction_mock()
Expand Down Expand Up @@ -369,13 +363,7 @@ def test_executemany_update_batch_autocommit(self):

sql = "UPDATE table SET col1 = %s WHERE col2 = %s"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True
transaction = self._transaction_mock()
Expand Down Expand Up @@ -418,13 +406,7 @@ def test_executemany_insert_batch_non_autocommit(self):

sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

transaction = self._transaction_mock()

Expand Down Expand Up @@ -461,13 +443,7 @@ def test_executemany_insert_batch_autocommit(self):

sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True

Expand Down Expand Up @@ -510,13 +486,7 @@ def test_executemany_insert_batch_failed(self):
sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""
err_details = "Details here"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True
cursor = connection.cursor()
Expand Down Expand Up @@ -546,13 +516,7 @@ def test_executemany_insert_batch_aborted(self):
sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""
err_details = "Aborted details here"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

transaction1 = mock.Mock(committed=False, rolled_back=False)
transaction1.batch_update = mock.Mock(
Expand Down

0 comments on commit 237ae41

Please sign in to comment.