From 9c529f3ef143813d8bb0be0e093c659ea7587eca Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Wed, 28 Jul 2021 20:40:23 -0400 Subject: [PATCH] tests: avoid using real credentials unit tests (#432) Closes #431. --- tests/unit/spanner_dbapi/test_connect.py | 181 +++---- tests/unit/spanner_dbapi/test_connection.py | 504 ++++++++------------ tests/unit/spanner_dbapi/test_cursor.py | 177 ++----- 3 files changed, 323 insertions(+), 539 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index a18781ffd1..96dcb20e01 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -20,6 +20,12 @@ import google.auth.credentials +INSTANCE = "test-instance" +DATABASE = "test-database" +PROJECT = "test-project" +USER_AGENT = "user-agent" + + def _make_credentials(): class _CredentialsWithScopes( google.auth.credentials.Credentials, google.auth.credentials.Scoped @@ -29,138 +35,105 @@ class _CredentialsWithScopes( return mock.Mock(spec=_CredentialsWithScopes) +@mock.patch("google.cloud.spanner_v1.Client") class Test_connect(unittest.TestCase): - def test_connect(self): + def test_w_implicit(self, mock_client): from google.cloud.spanner_dbapi import connect from google.cloud.spanner_dbapi import Connection - PROJECT = "test-project" - USER_AGENT = "user-agent" - CREDENTIALS = _make_credentials() - - with mock.patch("google.cloud.spanner_v1.Client") as client_mock: - connection = connect( - "test-instance", - "test-database", - PROJECT, - CREDENTIALS, - user_agent=USER_AGENT, - ) + client = mock_client.return_value + instance = client.instance.return_value + database = instance.database.return_value - self.assertIsInstance(connection, Connection) - - client_mock.assert_called_once_with( - project=PROJECT, credentials=CREDENTIALS, client_info=mock.ANY - ) - - def test_instance_not_found(self): - from google.cloud.spanner_dbapi import connect + connection = connect(INSTANCE, DATABASE) - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=False, - ) as exists_mock: + self.assertIsInstance(connection, Connection) - with self.assertRaises(ValueError): - connect("test-instance", "test-database") + self.assertIs(connection.instance, instance) + client.instance.assert_called_once_with(INSTANCE) - exists_mock.assert_called_once_with() + self.assertIs(connection.database, database) + instance.database.assert_called_once_with(DATABASE, pool=None) + # Datbase constructs its own pool + self.assertIsNotNone(connection.database._pool) - def test_database_not_found(self): + def test_w_explicit(self, mock_client): + from google.cloud.spanner_v1.pool import AbstractSessionPool from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.version import PY_VERSION - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=False, - ) as exists_mock: - - with self.assertRaises(ValueError): - connect("test-instance", "test-database") - - exists_mock.assert_called_once_with() + credentials = _make_credentials() + pool = mock.create_autospec(AbstractSessionPool) + client = mock_client.return_value + instance = client.instance.return_value + database = instance.database.return_value - def test_connect_instance_id(self): - from google.cloud.spanner_dbapi import connect - from google.cloud.spanner_dbapi import Connection + connection = connect( + INSTANCE, DATABASE, PROJECT, credentials, pool=pool, user_agent=USER_AGENT, + ) - INSTANCE = "test-instance" + self.assertIsInstance(connection, Connection) - with mock.patch( - "google.cloud.spanner_v1.client.Client.instance" - ) as instance_mock: - connection = connect(INSTANCE, "test-database") + mock_client.assert_called_once_with( + project=PROJECT, credentials=credentials, client_info=mock.ANY + ) + client_info = mock_client.call_args_list[0][1]["client_info"] + self.assertEqual(client_info.user_agent, USER_AGENT) + self.assertEqual(client_info.python_version, PY_VERSION) - instance_mock.assert_called_once_with(INSTANCE) + self.assertIs(connection.instance, instance) + client.instance.assert_called_once_with(INSTANCE) - self.assertIsInstance(connection, Connection) + self.assertIs(connection.database, database) + instance.database.assert_called_once_with(DATABASE, pool=pool) - def test_connect_database_id(self): + def test_w_instance_not_found(self, mock_client): from google.cloud.spanner_dbapi import connect - from google.cloud.spanner_dbapi import Connection - - DATABASE = "test-database" - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.database" - ) as database_mock: - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - connection = connect("test-instance", DATABASE) + client = mock_client.return_value + instance = client.instance.return_value + instance.exists.return_value = False - database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) + with self.assertRaises(ValueError): + connect(INSTANCE, DATABASE) - self.assertIsInstance(connection, Connection) + instance.exists.assert_called_once_with() - def test_default_sessions_pool(self): + def test_w_database_not_found(self, mock_client): from google.cloud.spanner_dbapi import connect - with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + client = mock_client.return_value + instance = client.instance.return_value + database = instance.database.return_value + database.exists.return_value = False - self.assertIsNotNone(connection.database._pool) + with self.assertRaises(ValueError): + connect(INSTANCE, DATABASE) - def test_sessions_pool(self): + database.exists.assert_called_once_with() + + def test_w_credential_file_path(self, mock_client): from google.cloud.spanner_dbapi import connect - from google.cloud.spanner_v1.pool import FixedSizePool + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.version import PY_VERSION - database_id = "test-database" - pool = FixedSizePool() + credentials_path = "dummy/file/path.json" - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.database" - ) as database_mock: - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - connect("test-instance", database_id, pool=pool) - database_mock.assert_called_once_with(database_id, pool=pool) + connection = connect( + INSTANCE, + DATABASE, + PROJECT, + credentials=credentials_path, + user_agent=USER_AGENT, + ) - def test_connect_w_credential_file_path(self): - from google.cloud.spanner_dbapi import connect - from google.cloud.spanner_dbapi import Connection + self.assertIsInstance(connection, Connection) - PROJECT = "test-project" - USER_AGENT = "user-agent" - credentials = "dummy/file/path.json" - - with mock.patch( - "google.cloud.spanner_v1.Client.from_service_account_json" - ) as client_mock: - connection = connect( - "test-instance", - "test-database", - PROJECT, - credentials=credentials, - user_agent=USER_AGENT, - ) - - self.assertIsInstance(connection, Connection) - - client_mock.assert_called_once_with( - credentials, project=PROJECT, client_info=mock.ANY - ) + factory = mock_client.from_service_account_json + factory.assert_called_once_with( + credentials_path, project=PROJECT, client_info=mock.ANY, + ) + client_info = factory.call_args_list[0][1]["client_info"] + self.assertEqual(client_info.user_agent, USER_AGENT) + self.assertEqual(client_info.python_version, PY_VERSION) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 772ac35032..48129dcc2f 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -18,6 +18,11 @@ import unittest import warnings +PROJECT = "test-project" +INSTANCE = "test-instance" +DATABASE = "test-database" +USER_AGENT = "user-agent" + def _make_credentials(): from google.auth import credentials @@ -29,78 +34,62 @@ class _CredentialsWithScopes(credentials.Credentials, credentials.Scoped): class TestConnection(unittest.TestCase): - - PROJECT = "test-project" - INSTANCE = "test-instance" - DATABASE = "test-database" - USER_AGENT = "user-agent" - CREDENTIALS = _make_credentials() - def _get_client_info(self): from google.api_core.gapic_v1.client_info import ClientInfo - return ClientInfo(user_agent=self.USER_AGENT) + return ClientInfo(user_agent=USER_AGENT) def _make_connection(self): from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_v1.instance import Instance # We don't need a real Client object to test the constructor - instance = Instance(self.INSTANCE, client=None) - database = instance.database(self.DATABASE) + instance = Instance(INSTANCE, client=None) + database = instance.database(DATABASE) return Connection(instance, database) - def test_autocommit_setter_transaction_not_started(self): + @mock.patch("google.cloud.spanner_dbapi.connection.Connection.commit") + def test_autocommit_setter_transaction_not_started(self, mock_commit): connection = self._make_connection() - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.commit" - ) as mock_commit: - connection.autocommit = True - mock_commit.assert_not_called() - self.assertTrue(connection._autocommit) + connection.autocommit = True - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.commit" - ) as mock_commit: - connection.autocommit = False - mock_commit.assert_not_called() - self.assertFalse(connection._autocommit) + mock_commit.assert_not_called() + self.assertTrue(connection._autocommit) - def test_autocommit_setter_transaction_started(self): + connection.autocommit = False + mock_commit.assert_not_called() + self.assertFalse(connection._autocommit) + + @mock.patch("google.cloud.spanner_dbapi.connection.Connection.commit") + def test_autocommit_setter_transaction_started(self, mock_commit): connection = self._make_connection() + connection._transaction = mock.Mock(committed=False, rolled_back=False) - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.commit" - ) as mock_commit: - connection._transaction = mock.Mock(committed=False, rolled_back=False) + connection.autocommit = True - connection.autocommit = True - mock_commit.assert_called_once() - self.assertTrue(connection._autocommit) + mock_commit.assert_called_once() + self.assertTrue(connection._autocommit) - def test_autocommit_setter_transaction_started_commited_rolled_back(self): + @mock.patch("google.cloud.spanner_dbapi.connection.Connection.commit") + def test_autocommit_setter_transaction_started_commited_rolled_back( + self, mock_commit + ): connection = self._make_connection() - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.commit" - ) as mock_commit: - connection._transaction = mock.Mock(committed=True, rolled_back=False) + connection._transaction = mock.Mock(committed=True, rolled_back=False) - connection.autocommit = True - mock_commit.assert_not_called() - self.assertTrue(connection._autocommit) + connection.autocommit = True + mock_commit.assert_not_called() + self.assertTrue(connection._autocommit) connection.autocommit = False - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.commit" - ) as mock_commit: - connection._transaction = mock.Mock(committed=False, rolled_back=True) + connection._transaction = mock.Mock(committed=False, rolled_back=True) - connection.autocommit = True - mock_commit.assert_not_called() - self.assertTrue(connection._autocommit) + connection.autocommit = True + mock_commit.assert_not_called() + self.assertTrue(connection._autocommit) def test_property_database(self): from google.cloud.spanner_v1.database import Database @@ -116,76 +105,92 @@ def test_property_instance(self): self.assertIsInstance(connection.instance, Instance) self.assertEqual(connection.instance, connection._instance) - def test__session_checkout(self): + @staticmethod + def _make_pool(): + from google.cloud.spanner_v1.pool import AbstractSessionPool + + return mock.create_autospec(AbstractSessionPool) + + @mock.patch("google.cloud.spanner_v1.database.Database") + def test__session_checkout(self, mock_database): from google.cloud.spanner_dbapi import Connection - with mock.patch("google.cloud.spanner_v1.database.Database") as mock_database: - mock_database._pool = mock.MagicMock() - mock_database._pool.get = mock.MagicMock(return_value="db_session_pool") - connection = Connection(self.INSTANCE, mock_database) + pool = self._make_pool() + mock_database._pool = pool + connection = Connection(INSTANCE, mock_database) - connection._session_checkout() - mock_database._pool.get.assert_called_once_with() - self.assertEqual(connection._session, "db_session_pool") + connection._session_checkout() + pool.get.assert_called_once_with() + self.assertEqual(connection._session, pool.get.return_value) - connection._session = "db_session" - connection._session_checkout() - self.assertEqual(connection._session, "db_session") + connection._session = "db_session" + connection._session_checkout() + self.assertEqual(connection._session, "db_session") - def test__release_session(self): + @mock.patch("google.cloud.spanner_v1.database.Database") + def test__release_session(self, mock_database): from google.cloud.spanner_dbapi import Connection - with mock.patch("google.cloud.spanner_v1.database.Database") as mock_database: - mock_database._pool = mock.MagicMock() - mock_database._pool.put = mock.MagicMock() - connection = Connection(self.INSTANCE, mock_database) - connection._session = "session" + pool = self._make_pool() + mock_database._pool = pool + connection = Connection(INSTANCE, mock_database) + connection._session = "session" - connection._release_session() - mock_database._pool.put.assert_called_once_with("session") - self.assertIsNone(connection._session) + connection._release_session() + pool.put.assert_called_once_with("session") + self.assertIsNone(connection._session) def test_transaction_checkout(self): from google.cloud.spanner_dbapi import Connection - connection = Connection(self.INSTANCE, self.DATABASE) - connection._session_checkout = mock_checkout = mock.MagicMock(autospec=True) + connection = Connection(INSTANCE, DATABASE) + mock_checkout = mock.MagicMock(autospec=True) + connection._session_checkout = mock_checkout + connection.transaction_checkout() + mock_checkout.assert_called_once_with() - connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction = mock.MagicMock() mock_transaction.committed = mock_transaction.rolled_back = False + connection._transaction = mock_transaction + self.assertEqual(connection.transaction_checkout(), mock_transaction) connection._autocommit = True self.assertIsNone(connection.transaction_checkout()) - def test_close(self): - from google.cloud.spanner_dbapi import connect, InterfaceError + @mock.patch("google.cloud.spanner_v1.Client") + def test_close(self, mock_client): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_dbapi import InterfaceError - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") self.assertFalse(connection.is_closed) + connection.close() + self.assertTrue(connection.is_closed) with self.assertRaises(InterfaceError): connection.cursor() - connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction = mock.MagicMock() mock_transaction.committed = mock_transaction.rolled_back = False - mock_transaction.rollback = mock_rollback = mock.MagicMock() + connection._transaction = mock_transaction + + mock_rollback = mock.MagicMock() + mock_transaction.rollback = mock_rollback + connection.close() + mock_rollback.assert_called_once_with() + connection._transaction = mock.MagicMock() connection._own_pool = False connection.close() + self.assertTrue(connection.is_closed) @mock.patch.object(warnings, "warn") @@ -193,13 +198,14 @@ def test_commit(self, mock_warn): from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING - connection = Connection(self.INSTANCE, self.DATABASE) + connection = Connection(INSTANCE, DATABASE) with mock.patch( "google.cloud.spanner_dbapi.connection.Connection._release_session" ) as mock_release: connection.commit() - mock_release.assert_not_called() + + mock_release.assert_not_called() connection._transaction = mock_transaction = mock.MagicMock( rolled_back=False, committed=False @@ -210,8 +216,9 @@ def test_commit(self, mock_warn): "google.cloud.spanner_dbapi.connection.Connection._release_session" ) as mock_release: connection.commit() - mock_commit.assert_called_once_with() - mock_release.assert_called_once_with() + + mock_commit.assert_called_once_with() + mock_release.assert_called_once_with() connection._autocommit = True connection.commit() @@ -224,23 +231,27 @@ def test_rollback(self, mock_warn): from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING - connection = Connection(self.INSTANCE, self.DATABASE) + connection = Connection(INSTANCE, DATABASE) with mock.patch( "google.cloud.spanner_dbapi.connection.Connection._release_session" ) as mock_release: connection.rollback() - mock_release.assert_not_called() - connection._transaction = mock_transaction = mock.MagicMock() - mock_transaction.rollback = mock_rollback = mock.MagicMock() + mock_release.assert_not_called() + + mock_transaction = mock.MagicMock() + connection._transaction = mock_transaction + mock_rollback = mock.MagicMock() + mock_transaction.rollback = mock_rollback with mock.patch( "google.cloud.spanner_dbapi.connection.Connection._release_session" ) as mock_release: connection.rollback() - mock_rollback.assert_called_once_with() - mock_release.assert_called_once_with() + + mock_rollback.assert_called_once_with() + mock_release.assert_called_once_with() connection._autocommit = True connection.rollback() @@ -248,101 +259,34 @@ def test_rollback(self, mock_warn): AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 ) - def test_run_prior_DDL_statements(self): + @mock.patch("google.cloud.spanner_v1.database.Database", autospec=True) + def test_run_prior_DDL_statements(self, mock_database): from google.cloud.spanner_dbapi import Connection, InterfaceError - with mock.patch( - "google.cloud.spanner_v1.database.Database", autospec=True - ) as mock_database: - connection = Connection(self.INSTANCE, mock_database) + connection = Connection(INSTANCE, mock_database) - connection.run_prior_DDL_statements() - mock_database.update_ddl.assert_not_called() + connection.run_prior_DDL_statements() + mock_database.update_ddl.assert_not_called() - ddl = ["ddl"] - connection._ddl_statements = ddl + ddl = ["ddl"] + connection._ddl_statements = ddl - connection.run_prior_DDL_statements() - mock_database.update_ddl.assert_called_once_with(ddl) + connection.run_prior_DDL_statements() + mock_database.update_ddl.assert_called_once_with(ddl) - connection.is_closed = True + connection.is_closed = True - with self.assertRaises(InterfaceError): - connection.run_prior_DDL_statements() + with self.assertRaises(InterfaceError): + connection.run_prior_DDL_statements() - def test_context(self): + def test_as_context_manager(self): connection = self._make_connection() with connection as conn: self.assertEqual(conn, connection) self.assertTrue(connection.is_closed) - def test_connect(self): - from google.cloud.spanner_dbapi import Connection, connect - - with mock.patch("google.cloud.spanner_v1.Client"): - with mock.patch( - "google.api_core.gapic_v1.client_info.ClientInfo", - return_value=self._get_client_info(), - ): - connection = connect( - self.INSTANCE, - self.DATABASE, - self.PROJECT, - self.CREDENTIALS, - self.USER_AGENT, - ) - self.assertIsInstance(connection, Connection) - - def test_connect_instance_not_found(self): - from google.cloud.spanner_dbapi import connect - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=False - ): - with self.assertRaises(ValueError): - connect("test-instance", "test-database") - - def test_connect_database_not_found(self): - from google.cloud.spanner_dbapi import connect - - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=False - ): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True - ): - with self.assertRaises(ValueError): - connect("test-instance", "test-database") - - def test_default_sessions_pool(self): - from google.cloud.spanner_dbapi import connect - - with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True - ): - connection = connect("test-instance", "test-database") - - self.assertIsNotNone(connection.database._pool) - - def test_sessions_pool(self): - from google.cloud.spanner_dbapi import connect - from google.cloud.spanner_v1.pool import FixedSizePool - - database_id = "test-database" - pool = FixedSizePool() - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.database" - ) as database_mock: - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True - ): - connect("test-instance", database_id, pool=pool) - database_mock.assert_called_once_with(database_id, pool=pool) - - def test_run_statement_remember_statements(self): + def test_run_statement_wo_retried(self): """Check that Connection remembers executed statements.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Statement @@ -352,19 +296,16 @@ def test_run_statement_remember_statements(self): param_types = {"a1": str} connection = self._make_connection() - + connection.transaction_checkout = mock.Mock() statement = Statement(sql, params, param_types, ResultsChecksum(), False) - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" - ): - connection.run_statement(statement) + connection.run_statement(statement) self.assertEqual(connection._statements[0].sql, sql) self.assertEqual(connection._statements[0].params, params) self.assertEqual(connection._statements[0].param_types, param_types) self.assertIsInstance(connection._statements[0].checksum, ResultsChecksum) - def test_run_statement_dont_remember_retried_statements(self): + def test_run_statement_w_retried(self): """Check that Connection doesn't remember re-executed statements.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Statement @@ -374,12 +315,9 @@ def test_run_statement_dont_remember_retried_statements(self): param_types = {"a1": str} connection = self._make_connection() - + connection.transaction_checkout = mock.Mock() statement = Statement(sql, params, param_types, ResultsChecksum(), False) - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" - ): - connection.run_statement(statement, retried=True) + connection.run_statement(statement, retried=True) self.assertEqual(len(connection._statements), 0) @@ -393,12 +331,10 @@ def test_run_statement_w_heterogenous_insert_statements(self): param_types = None connection = self._make_connection() - + connection.transaction_checkout = mock.Mock() statement = Statement(sql, params, param_types, ResultsChecksum(), True) - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" - ): - connection.run_statement(statement, retried=True) + + connection.run_statement(statement, retried=True) self.assertEqual(len(connection._statements), 0) @@ -412,16 +348,15 @@ def test_run_statement_w_homogeneous_insert_statements(self): param_types = {"f1": str, "f2": str} connection = self._make_connection() - + connection.transaction_checkout = mock.Mock() statement = Statement(sql, params, param_types, ResultsChecksum(), True) - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" - ): - connection.run_statement(statement, retried=True) + + connection.run_statement(statement, retried=True) self.assertEqual(len(connection._statements), 0) - def test_clear_statements_on_commit(self): + @mock.patch("google.cloud.spanner_v1.transaction.Transaction") + def test_commit_clears_statements(self, mock_transaction): """ Check that all the saved statements are cleared, when the transaction is commited. @@ -432,12 +367,12 @@ def test_clear_statements_on_commit(self): self.assertEqual(len(connection._statements), 2) - with mock.patch("google.cloud.spanner_v1.transaction.Transaction.commit"): - connection.commit() + connection.commit() self.assertEqual(len(connection._statements), 0) - def test_clear_statements_on_rollback(self): + @mock.patch("google.cloud.spanner_v1.transaction.Transaction") + def test_rollback_clears_statements(self, mock_transaction): """ Check that all the saved statements are cleared, when the transaction is roll backed. @@ -448,40 +383,36 @@ def test_clear_statements_on_rollback(self): self.assertEqual(len(connection._statements), 2) - with mock.patch("google.cloud.spanner_v1.transaction.Transaction.commit"): - connection.rollback() + connection.rollback() self.assertEqual(len(connection._statements), 0) - def test_retry_transaction(self): + def test_retry_transaction_w_checksum_match(self): """Check retrying an aborted transaction.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Statement row = ["field1", "field2"] connection = self._make_connection() - checksum = ResultsChecksum() checksum.consume_result(row) + retried_checkum = ResultsChecksum() + run_mock = connection.run_statement = mock.Mock() + run_mock.return_value = ([row], retried_checkum) statement = Statement("SELECT 1", [], {}, checksum, False) connection._statements.append(statement) with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row], retried_checkum), - ) as run_mock: - with mock.patch( - "google.cloud.spanner_dbapi.connection._compare_checksums" - ) as compare_mock: - connection.retry_transaction() - - compare_mock.assert_called_with(checksum, retried_checkum) + "google.cloud.spanner_dbapi.connection._compare_checksums" + ) as compare_mock: + connection.retry_transaction() - run_mock.assert_called_with(statement, retried=True) + compare_mock.assert_called_with(checksum, retried_checkum) + run_mock.assert_called_with(statement, retried=True) - def test_retry_transaction_checksum_mismatch(self): + def test_retry_transaction_w_checksum_mismatch(self): """ Check retrying an aborted transaction with results checksums mismatch. @@ -497,18 +428,17 @@ def test_retry_transaction_checksum_mismatch(self): checksum = ResultsChecksum() checksum.consume_result(row) retried_checkum = ResultsChecksum() + run_mock = connection.run_statement = mock.Mock() + run_mock.return_value = ([retried_row], retried_checkum) statement = Statement("SELECT 1", [], {}, checksum, False) connection._statements.append(statement) - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([retried_row], retried_checkum), - ): - with self.assertRaises(RetryAborted): - connection.retry_transaction() + with self.assertRaises(RetryAborted): + connection.retry_transaction() - def test_commit_retry_aborted_statements(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_commit_retry_aborted_statements(self, mock_client): """Check that retried transaction executing the same statements.""" from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -516,13 +446,8 @@ def test_commit_retry_aborted_statements(self): from google.cloud.spanner_dbapi.cursor import Statement row = ["field1", "field2"] - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -530,19 +455,15 @@ def test_commit_retry_aborted_statements(self): statement = Statement("SELECT 1", [], {}, cursor._checksum, False) connection._statements.append(statement) - connection._transaction = mock.Mock(rolled_back=False, committed=False) + mock_transaction = mock.Mock(rolled_back=False, committed=False) + connection._transaction = mock_transaction + mock_transaction.commit.side_effect = [Aborted("Aborted"), None] + run_mock = connection.run_statement = mock.Mock() + run_mock.return_value = ([row], ResultsChecksum()) - with mock.patch.object( - connection._transaction, "commit", side_effect=(Aborted("Aborted"), None), - ): - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row], ResultsChecksum()), - ) as run_mock: - - connection.commit() + connection.commit() - run_mock.assert_called_with(statement, retried=True) + run_mock.assert_called_with(statement, retried=True) def test_retry_transaction_drop_transaction(self): """ @@ -558,7 +479,8 @@ def test_retry_transaction_drop_transaction(self): connection.retry_transaction() self.assertIsNone(connection._transaction) - def test_retry_aborted_retry(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_retry_aborted_retry(self, mock_client): """ Check that in case of a retried transaction failed, the connection will retry it once again. @@ -570,13 +492,7 @@ def test_retry_aborted_retry(self): row = ["field1", "field2"] - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -584,27 +500,19 @@ def test_retry_aborted_retry(self): statement = Statement("SELECT 1", [], {}, cursor._checksum, False) connection._statements.append(statement) - metadata_mock = mock.Mock() metadata_mock.trailing_metadata.return_value = {} + run_mock = connection.run_statement = mock.Mock() + run_mock.side_effect = [ + Aborted("Aborted", errors=[metadata_mock]), + ([row], ResultsChecksum()), + ] - with mock.patch.object( - connection, - "run_statement", - side_effect=( - Aborted("Aborted", errors=[metadata_mock]), - ([row], ResultsChecksum()), - ), - ) as retry_mock: - - connection.retry_transaction() + connection.retry_transaction() - retry_mock.assert_has_calls( - ( - mock.call(statement, retried=True), - mock.call(statement, retried=True), - ) - ) + run_mock.assert_has_calls( + (mock.call(statement, retried=True), mock.call(statement, retried=True),) + ) def test_retry_transaction_raise_max_internal_retries(self): """Check retrying raise an error of max internal retries.""" @@ -627,7 +535,8 @@ def test_retry_transaction_raise_max_internal_retries(self): conn.MAX_INTERNAL_RETRIES = 50 - def test_retry_aborted_retry_without_delay(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_retry_aborted_retry_without_delay(self, mock_client): """ Check that in case of a retried transaction failed, the connection will retry it once again. @@ -639,13 +548,7 @@ def test_retry_aborted_retry_without_delay(self): row = ["field1", "field2"] - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -653,29 +556,20 @@ def test_retry_aborted_retry_without_delay(self): statement = Statement("SELECT 1", [], {}, cursor._checksum, False) connection._statements.append(statement) - metadata_mock = mock.Mock() metadata_mock.trailing_metadata.return_value = {} + run_mock = connection.run_statement = mock.Mock() + run_mock.side_effect = [ + Aborted("Aborted", errors=[metadata_mock]), + ([row], ResultsChecksum()), + ] + connection._get_retry_delay = mock.Mock(return_value=False) - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - side_effect=( - Aborted("Aborted", errors=[metadata_mock]), - ([row], ResultsChecksum()), - ), - ) as retry_mock: - with mock.patch( - "google.cloud.spanner_dbapi.connection._get_retry_delay", - return_value=False, - ): - connection.retry_transaction() - - retry_mock.assert_has_calls( - ( - mock.call(statement, retried=True), - mock.call(statement, retried=True), - ) - ) + connection.retry_transaction() + + run_mock.assert_has_calls( + (mock.call(statement, retried=True), mock.call(statement, retried=True),) + ) def test_retry_transaction_w_multiple_statement(self): """Check retrying an aborted transaction.""" @@ -693,19 +587,17 @@ def test_retry_transaction_w_multiple_statement(self): statement1 = Statement("SELECT 2", [], {}, checksum, False) connection._statements.append(statement) connection._statements.append(statement1) + run_mock = connection.run_statement = mock.Mock() + run_mock.return_value = ([row], retried_checkum) with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row], retried_checkum), - ) as run_mock: - with mock.patch( - "google.cloud.spanner_dbapi.connection._compare_checksums" - ) as compare_mock: - connection.retry_transaction() + "google.cloud.spanner_dbapi.connection._compare_checksums" + ) as compare_mock: + connection.retry_transaction() - compare_mock.assert_called_with(checksum, retried_checkum) + compare_mock.assert_called_with(checksum, retried_checkum) - run_mock.assert_called_with(statement1, retried=True) + run_mock.assert_called_with(statement1, retried=True) def test_retry_transaction_w_empty_response(self): """Check retrying an aborted transaction.""" @@ -721,16 +613,14 @@ def test_retry_transaction_w_empty_response(self): statement = Statement("SELECT 1", [], {}, checksum, False) connection._statements.append(statement) + run_mock = connection.run_statement = mock.Mock() + run_mock.return_value = ([row], retried_checkum) with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=(row, retried_checkum), - ) as run_mock: - with mock.patch( - "google.cloud.spanner_dbapi.connection._compare_checksums" - ) as compare_mock: - connection.retry_transaction() + "google.cloud.spanner_dbapi.connection._compare_checksums" + ) as compare_mock: + connection.retry_transaction() - compare_mock.assert_called_with(checksum, retried_checkum) + compare_mock.assert_called_with(checksum, retried_checkum) - run_mock.assert_called_with(statement, retried=True) + run_mock.assert_called_with(statement, retried=True) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 5b1cf12138..d1a20c2ed2 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -70,16 +70,11 @@ def test_callproc(self): with self.assertRaises(InterfaceError): cursor.callproc(procname=None) - def test_close(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_close(self, mock_client): from google.cloud.spanner_dbapi import connect, InterfaceError - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True - ): - connection = connect(self.INSTANCE, self.DATABASE) + connection = connect(self.INSTANCE, self.DATABASE) cursor = connection.cursor() self.assertFalse(cursor.is_closed) @@ -87,6 +82,7 @@ def test_close(self): cursor.close() self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): cursor.execute("SELECT * FROM database") @@ -276,17 +272,12 @@ def test_execute_internal_server_error(self): with self.assertRaises(OperationalError): cursor.execute(sql="sql") - def test_executemany_on_closed_cursor(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_executemany_on_closed_cursor(self, mock_client): from google.cloud.spanner_dbapi import InterfaceError from google.cloud.spanner_dbapi import connect - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor.close() @@ -294,35 +285,25 @@ def test_executemany_on_closed_cursor(self): with self.assertRaises(InterfaceError): cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ()) - def test_executemany_DLL(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_executemany_DLL(self, mock_client): from google.cloud.spanner_dbapi import connect, ProgrammingError - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() with self.assertRaises(ProgrammingError): cursor.executemany("""DROP DATABASE database_name""", ()) - def test_executemany(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_executemany(self, mock_client): from google.cloud.spanner_dbapi import connect operation = """SELECT * FROM table1 WHERE "col1" = @a1""" params_seq = ((1,), (2,)) - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._result_set = [1, 2, 3] @@ -561,7 +542,8 @@ def test_get_table_column_schema(self): ) self.assertEqual(result, expected) - def test_peek_iterator_aborted(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_peek_iterator_aborted(self, mock_client): """ Checking that an Aborted exception is retried in case it happened while streaming the first element with a PeekIterator. @@ -569,13 +551,7 @@ def test_peek_iterator_aborted(self): from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.connection import connect - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() with mock.patch( @@ -593,7 +569,8 @@ def test_peek_iterator_aborted(self): retry_mock.assert_called_with() - def test_peek_iterator_aborted_autocommit(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_peek_iterator_aborted_autocommit(self, mock_client): """ Checking that an Aborted exception is retried in case it happened while streaming the first element with a PeekIterator in autocommit mode. @@ -601,13 +578,7 @@ def test_peek_iterator_aborted_autocommit(self): from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.connection import connect - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") connection.autocommit = True cursor = connection.cursor() @@ -629,19 +600,14 @@ def test_peek_iterator_aborted_autocommit(self): retry_mock.assert_called_with() - def test_fetchone_retry_aborted(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_fetchone_retry_aborted(self, mock_client): """Check that aborted fetch re-executing transaction.""" from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -658,7 +624,8 @@ def test_fetchone_retry_aborted(self): retry_mock.assert_called_with() - def test_fetchone_retry_aborted_statements(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_fetchone_retry_aborted_statements(self, mock_client): """Check that retried transaction executing the same statements.""" from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -666,13 +633,7 @@ def test_fetchone_retry_aborted_statements(self): from google.cloud.spanner_dbapi.cursor import Statement row = ["field1", "field2"] - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -694,7 +655,8 @@ def test_fetchone_retry_aborted_statements(self): run_mock.assert_called_with(statement, retried=True) - def test_fetchone_retry_aborted_statements_checksums_mismatch(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_fetchone_retry_aborted_statements_checksums_mismatch(self, mock_client): """Check transaction retrying with underlying data being changed.""" from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.exceptions import RetryAborted @@ -705,13 +667,7 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self): row = ["field1", "field2"] row2 = ["updated_field1", "field2"] - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -734,19 +690,14 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self): run_mock.assert_called_with(statement, retried=True) - def test_fetchall_retry_aborted(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_fetchall_retry_aborted(self, mock_client): """Check that aborted fetch re-executing transaction.""" from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -763,7 +714,8 @@ def test_fetchall_retry_aborted(self): retry_mock.assert_called_with() - def test_fetchall_retry_aborted_statements(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_fetchall_retry_aborted_statements(self, mock_client): """Check that retried transaction executing the same statements.""" from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -771,13 +723,7 @@ def test_fetchall_retry_aborted_statements(self): from google.cloud.spanner_dbapi.cursor import Statement row = ["field1", "field2"] - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -798,7 +744,8 @@ def test_fetchall_retry_aborted_statements(self): run_mock.assert_called_with(statement, retried=True) - def test_fetchall_retry_aborted_statements_checksums_mismatch(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_fetchall_retry_aborted_statements_checksums_mismatch(self, mock_client): """Check transaction retrying with underlying data being changed.""" from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.exceptions import RetryAborted @@ -809,13 +756,7 @@ def test_fetchall_retry_aborted_statements_checksums_mismatch(self): row = ["field1", "field2"] row2 = ["updated_field1", "field2"] - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -838,19 +779,14 @@ def test_fetchall_retry_aborted_statements_checksums_mismatch(self): run_mock.assert_called_with(statement, retried=True) - def test_fetchmany_retry_aborted(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_fetchmany_retry_aborted(self, mock_client): """Check that aborted fetch re-executing transaction.""" from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -867,7 +803,8 @@ def test_fetchmany_retry_aborted(self): retry_mock.assert_called_with() - def test_fetchmany_retry_aborted_statements(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_fetchmany_retry_aborted_statements(self, mock_client): """Check that retried transaction executing the same statements.""" from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -875,13 +812,7 @@ def test_fetchmany_retry_aborted_statements(self): from google.cloud.spanner_dbapi.cursor import Statement row = ["field1", "field2"] - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -903,7 +834,8 @@ def test_fetchmany_retry_aborted_statements(self): run_mock.assert_called_with(statement, retried=True) - def test_fetchmany_retry_aborted_statements_checksums_mismatch(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_fetchmany_retry_aborted_statements_checksums_mismatch(self, mock_client): """Check transaction retrying with underlying data being changed.""" from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.exceptions import RetryAborted @@ -914,13 +846,7 @@ def test_fetchmany_retry_aborted_statements_checksums_mismatch(self): row = ["field1", "field2"] row2 = ["updated_field1", "field2"] - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor._checksum = ResultsChecksum() @@ -943,7 +869,8 @@ def test_fetchmany_retry_aborted_statements_checksums_mismatch(self): run_mock.assert_called_with(statement, retried=True) - def test_ddls_with_semicolon(self): + @mock.patch("google.cloud.spanner_v1.Client") + def test_ddls_with_semicolon(self, mock_client): """ Check that one script with several DDL statements separated with semicolons is splitted into several DDLs. @@ -963,13 +890,7 @@ def test_ddls_with_semicolon(self): "DROP TABLE table_name", ] - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", return_value=True, - ): - connection = connect("test-instance", "test-database") + connection = connect("test-instance", "test-database") cursor = connection.cursor() cursor.execute(