diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index f70e7fe669..772ac35032 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -183,6 +183,10 @@ def test_close(self): mock_transaction.rollback = mock_rollback = mock.MagicMock() connection.close() mock_rollback.assert_called_once_with() + connection._transaction = mock.MagicMock() + connection._own_pool = False + connection.close() + self.assertTrue(connection.is_closed) @mock.patch.object(warnings, "warn") def test_commit(self, mock_warn): @@ -379,6 +383,25 @@ def test_run_statement_dont_remember_retried_statements(self): self.assertEqual(len(connection._statements), 0) + def test_run_statement_w_heterogenous_insert_statements(self): + """Check that Connection executed heterogenous insert statements.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + sql = "INSERT INTO T (f1, f2) VALUES (1, 2)" + params = None + param_types = None + + connection = self._make_connection() + + statement = Statement(sql, params, param_types, ResultsChecksum(), True) + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" + ): + connection.run_statement(statement, retried=True) + + self.assertEqual(len(connection._statements), 0) + def test_run_statement_w_homogeneous_insert_statements(self): """Check that Connection executed homogeneous insert statements.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -582,3 +605,132 @@ def test_retry_aborted_retry(self): mock.call(statement, retried=True), ) ) + + def test_retry_transaction_raise_max_internal_retries(self): + """Check retrying raise an error of max internal retries.""" + from google.cloud.spanner_dbapi import connection as conn + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + conn.MAX_INTERNAL_RETRIES = 0 + row = ["field1", "field2"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, checksum, False) + connection._statements.append(statement) + + with self.assertRaises(Exception): + connection.retry_transaction() + + conn.MAX_INTERNAL_RETRIES = 50 + + def test_retry_aborted_retry_without_delay(self): + """ + Check that in case of a retried transaction failed, + the connection will retry it once again. + """ + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) + connection._statements.append(statement) + + metadata_mock = mock.Mock() + metadata_mock.trailing_metadata.return_value = {} + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + side_effect=( + Aborted("Aborted", errors=[metadata_mock]), + ([row], ResultsChecksum()), + ), + ) as retry_mock: + with mock.patch( + "google.cloud.spanner_dbapi.connection._get_retry_delay", + return_value=False, + ): + connection.retry_transaction() + + retry_mock.assert_has_calls( + ( + mock.call(statement, retried=True), + mock.call(statement, retried=True), + ) + ) + + def test_retry_transaction_w_multiple_statement(self): + """Check retrying an aborted transaction.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = Statement("SELECT 1", [], {}, checksum, False) + statement1 = Statement("SELECT 2", [], {}, checksum, False) + connection._statements.append(statement) + connection._statements.append(statement1) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], retried_checkum), + ) as run_mock: + with mock.patch( + "google.cloud.spanner_dbapi.connection._compare_checksums" + ) as compare_mock: + connection.retry_transaction() + + compare_mock.assert_called_with(checksum, retried_checkum) + + run_mock.assert_called_with(statement1, retried=True) + + def test_retry_transaction_w_empty_response(self): + """Check retrying an aborted transaction.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + row = [] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.count = 1 + retried_checkum = ResultsChecksum() + + statement = Statement("SELECT 1", [], {}, checksum, False) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=(row, retried_checkum), + ) as run_mock: + with mock.patch( + "google.cloud.spanner_dbapi.connection._compare_checksums" + ) as compare_mock: + connection.retry_transaction() + + compare_mock.assert_called_with(checksum, retried_checkum) + + run_mock.assert_called_with(statement, retried=True) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index c83dcb5e10..889061cd83 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -140,6 +140,31 @@ def test_execute_autocommit_off(self): self.assertIsInstance(cursor._result_set, mock.MagicMock) self.assertIsInstance(cursor._itr, PeekIterator) + def test_execute_insert_statement_autocommit_off(self): + from google.cloud.spanner_dbapi import parse_utils + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.utils import PeekIterator + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.connection._autocommit = False + cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) + + cursor._checksum = ResultsChecksum() + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_INSERT, + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=(mock.MagicMock(), ResultsChecksum()), + ): + cursor.execute( + sql="INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" + ) + self.assertIsInstance(cursor._result_set, mock.MagicMock) + self.assertIsInstance(cursor._itr, PeekIterator) + def test_execute_statement(self): from google.cloud.spanner_dbapi import parse_utils diff --git a/tests/unit/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py index 4fe94f30a7..76c347d402 100644 --- a/tests/unit/spanner_dbapi/test_utils.py +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -85,3 +85,19 @@ def test_backtick_unicode(self): with self.subTest(sql=sql): got = backtick_unicode(sql) self.assertEqual(got, want) + + @unittest.skipIf(skip_condition, skip_message) + def test_StreamedManyResultSets(self): + from google.cloud.spanner_dbapi.utils import StreamedManyResultSets + + cases = [ + ("iter_from_list", iter([1, 2, 3, 4, 6, 7]), [1, 2, 3, 4, 6, 7]), + ("iter_from_tuple", iter(("a", 12, 0xFF)), ["a", 12, 0xFF]), + ] + + for name, data_in, expected in cases: + with self.subTest(name=name): + stream_result = StreamedManyResultSets() + stream_result._iterators.append(data_in) + actual = list(stream_result) + self.assertEqual(actual, expected)