Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(db_api): move connection validation into a separate method #543

Merged
merged 6 commits into from Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 20 additions & 12 deletions google/cloud/spanner_dbapi/connection.py
Expand Up @@ -28,7 +28,7 @@
from google.cloud.spanner_dbapi.checksum import _compare_checksums
from google.cloud.spanner_dbapi.checksum import ResultsChecksum
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi.exceptions import InterfaceError
from google.cloud.spanner_dbapi.exceptions import InterfaceError, OperationalError
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
from google.cloud.spanner_dbapi.version import PY_VERSION

Expand Down Expand Up @@ -349,6 +349,24 @@ def run_statement(self, statement, retried=False):
ResultsChecksum() if retried else statement.checksum,
)

def validate(self):
"""
Execute a minimal request to check if the connection
is valid and the related database is reachable.

Raise an exception in case if the connection is closed,
invalid, target database is not found, or the request result
is incorrect.

:raises: :class:`InterfaceError`: if this connection is closed.
:raises: :class:`OperationalError`: if the request result is incorrect.
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
"""
self._raise_if_closed()

with self.database.snapshot() as snapshot:
if [[1]] != list(snapshot.execute_sql("SELECT 1")):
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
raise OperationalError("The connection is invalid")
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved

def __enter__(self):
return self

Expand Down Expand Up @@ -399,9 +417,6 @@ def connect(
:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Google Cloud Spanner
resource.

:raises: :class:`ValueError` in case of given instance/database
doesn't exist.
"""

client_info = ClientInfo(
Expand All @@ -418,14 +433,7 @@ def connect(
)

instance = client.instance(instance_id)
if not instance.exists():
raise ValueError("instance '%s' does not exist." % instance_id)

database = instance.database(database_id, pool=pool)
if not database.exists():
raise ValueError("database '%s' does not exist." % database_id)

conn = Connection(instance, database)
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
conn = Connection(instance, instance.database(database_id, pool=pool))
if pool is not None:
conn._own_pool = False

Expand Down
25 changes: 0 additions & 25 deletions tests/unit/spanner_dbapi/test_connect.py
Expand Up @@ -88,31 +88,6 @@ def test_w_explicit(self, mock_client):
self.assertIs(connection.database, database)
instance.database.assert_called_once_with(DATABASE, pool=pool)

def test_w_instance_not_found(self, mock_client):
from google.cloud.spanner_dbapi import connect

client = mock_client.return_value
instance = client.instance.return_value
instance.exists.return_value = False

with self.assertRaises(ValueError):
connect(INSTANCE, DATABASE)

instance.exists.assert_called_once_with()

def test_w_database_not_found(self, mock_client):
from google.cloud.spanner_dbapi import connect

client = mock_client.return_value
instance = client.instance.return_value
database = instance.database.return_value
database.exists.return_value = False

with self.assertRaises(ValueError):
connect(INSTANCE, DATABASE)

database.exists.assert_called_once_with()

def test_w_credential_file_path(self, mock_client):
from google.cloud.spanner_dbapi import connect
from google.cloud.spanner_dbapi import Connection
Expand Down
43 changes: 43 additions & 0 deletions tests/unit/spanner_dbapi/test_connection.py
Expand Up @@ -624,3 +624,46 @@ def test_retry_transaction_w_empty_response(self):
compare_mock.assert_called_with(checksum, retried_checkum)

run_mock.assert_called_with(statement, retried=True)

def test_validate_ok(self):
def exit_func(self, exc_type, exc_value, traceback):
if exc_value:
raise exc_value

connection = self._make_connection()

# mock snapshot context manager
snapshot_obj = mock.Mock()
snapshot_obj.execute_sql = mock.Mock(return_value=[[1]])

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method

connection.validate()

def test_validate_fail(self):
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
from google.cloud.spanner_dbapi.exceptions import OperationalError

def exit_func(self, exc_type, exc_value, traceback):
if exc_value:
raise exc_value

connection = self._make_connection()

# mock snapshot context manager
snapshot_obj = mock.Mock()
snapshot_obj.execute_sql = mock.Mock(return_value=[[3]])

snapshot_ctx = mock.Mock()
snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj)
snapshot_ctx.__exit__ = exit_func
snapshot_method = mock.Mock(return_value=snapshot_ctx)

connection.database.snapshot = snapshot_method

with self.assertRaises(OperationalError):
connection.validate()
48 changes: 6 additions & 42 deletions tests/unit/spanner_dbapi/test_cursor.py
Expand Up @@ -332,13 +332,7 @@ def test_executemany_delete_batch_autocommit(self):

sql = "DELETE FROM table WHERE col1 = %s"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True
transaction = self._transaction_mock()
Expand Down Expand Up @@ -369,13 +363,7 @@ def test_executemany_update_batch_autocommit(self):

sql = "UPDATE table SET col1 = %s WHERE col2 = %s"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True
transaction = self._transaction_mock()
Expand Down Expand Up @@ -418,13 +406,7 @@ def test_executemany_insert_batch_non_autocommit(self):

sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

transaction = self._transaction_mock()

Expand Down Expand Up @@ -461,13 +443,7 @@ def test_executemany_insert_batch_autocommit(self):

sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True

Expand Down Expand Up @@ -510,13 +486,7 @@ def test_executemany_insert_batch_failed(self):
sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""
err_details = "Details here"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

connection.autocommit = True
cursor = connection.cursor()
Expand Down Expand Up @@ -546,13 +516,7 @@ def test_executemany_insert_batch_aborted(self):
sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)"""
err_details = "Aborted details here"

with mock.patch(
"google.cloud.spanner_v1.instance.Instance.exists", return_value=True
):
with mock.patch(
"google.cloud.spanner_v1.database.Database.exists", return_value=True,
):
connection = connect("test-instance", "test-database")
connection = connect("test-instance", "test-database")

transaction1 = mock.Mock(committed=False, rolled_back=False)
transaction1.batch_update = mock.Mock(
Expand Down