diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 4b5a0d9652..707bf617af 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -279,7 +279,7 @@ def fetchall(self): self._checksum.consume_result(row) res.append(row) except Aborted: - self._connection.retry_transaction() + self.connection.retry_transaction() return self.fetchall() return res @@ -310,7 +310,7 @@ def fetchmany(self, size=None): except StopIteration: break except Aborted: - self._connection.retry_transaction() + self.connection.retry_transaction() return self.fetchmany(size) return items diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 9f0510c4ab..c83dcb5e10 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -315,6 +315,22 @@ def test_fetchone(self): self.assertEqual(cursor.fetchone(), lst[i]) self.assertIsNone(cursor.fetchone()) + @unittest.skipIf( + sys.version_info[0] < 3, "Python 2 has an outdated iterator definition" + ) + def test_fetchone_w_autocommit(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.autocommit = True + cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() + lst = [1, 2, 3] + cursor._itr = iter(lst) + for i in range(len(lst)): + self.assertEqual(cursor.fetchone(), lst[i]) + self.assertIsNone(cursor.fetchone()) + def test_fetchmany(self): from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -329,6 +345,21 @@ def test_fetchmany(self): result = cursor.fetchmany(len(lst)) self.assertEqual(result, lst[1:]) + def test_fetchmany_w_autocommit(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.autocommit = True + cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + + self.assertEqual(cursor.fetchmany(), [lst[0]]) + + result = cursor.fetchmany(len(lst)) + self.assertEqual(result, lst[1:]) + def test_fetchall(self): from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -339,6 +370,17 @@ def test_fetchall(self): cursor._itr = iter(lst) self.assertEqual(cursor.fetchall(), lst) + def test_fetchall_w_autocommit(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.autocommit = True + cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + self.assertEqual(cursor.fetchall(), lst) + def test_nextset(self): from google.cloud.spanner_dbapi import exceptions @@ -586,3 +628,212 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self): cursor.fetchone() run_mock.assert_called_with(statement, retried=True) + + def test_fetchall_retry_aborted(self): + """Check that aborted fetch re-executing transaction.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__iter__", + side_effect=(Aborted("Aborted"), iter([])), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + ) as retry_mock: + + cursor.fetchall() + + retry_mock.assert_called_with() + + def test_fetchall_retry_aborted_statements(self): + """Check that retried transaction executing the same statements.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__iter__", + side_effect=(Aborted("Aborted"), iter(row)), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], ResultsChecksum()), + ) as run_mock: + cursor.fetchall() + + run_mock.assert_called_with(statement, retried=True) + + def test_fetchall_retry_aborted_statements_checksums_mismatch(self): + """Check transaction retrying with underlying data being changed.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.exceptions import RetryAborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + row2 = ["updated_field1", "field2"] + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__iter__", + side_effect=(Aborted("Aborted"), iter(row)), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row2], ResultsChecksum()), + ) as run_mock: + + with self.assertRaises(RetryAborted): + cursor.fetchall() + + run_mock.assert_called_with(statement, retried=True) + + def test_fetchmany_retry_aborted(self): + """Check that aborted fetch re-executing transaction.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + ) as retry_mock: + + cursor.fetchmany() + + retry_mock.assert_called_with() + + def test_fetchmany_retry_aborted_statements(self): + """Check that retried transaction executing the same statements.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], ResultsChecksum()), + ) as run_mock: + + cursor.fetchmany(len(row)) + + run_mock.assert_called_with(statement, retried=True) + + def test_fetchmany_retry_aborted_statements_checksums_mismatch(self): + """Check transaction retrying with underlying data being changed.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.exceptions import RetryAborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + row2 = ["updated_field1", "field2"] + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum, False) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row2], ResultsChecksum()), + ) as run_mock: + + with self.assertRaises(RetryAborted): + cursor.fetchmany(len(row)) + + run_mock.assert_called_with(statement, retried=True)