diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index a28879faba..3569bab605 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -206,7 +206,12 @@ def execute(self, sql, args=None): (self._result_set, self._checksum,) = self.connection.run_statement( statement ) - self._itr = PeekIterator(self._result_set) + while True: + try: + self._itr = PeekIterator(self._result_set) + break + except Aborted: + self.connection.retry_transaction() return if classification == parse_utils.STMT_NON_UPDATING: @@ -352,7 +357,12 @@ def _handle_DQL(self, sql, params): self._result_set = res # Read the first element so that the StreamedResultSet can # return the metadata after a DQL statement. See issue #155. - self._itr = PeekIterator(self._result_set) + while True: + try: + self._itr = PeekIterator(self._result_set) + break + except Aborted: + self.connection.retry_transaction() # Unfortunately, Spanner doesn't seem to send back # information about the number of rows available. self._row_count = _UNSET_COUNT diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 4d5db01eac..57a3375e49 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -549,6 +549,74 @@ def test_get_table_column_schema(self): ) self.assertEqual(result, expected) + def test_peek_iterator_aborted(self): + """ + Checking that an Aborted exception is retried in case it happened + while streaming the first element with a PeekIterator. + """ + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.connection import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + with mock.patch( + "google.cloud.spanner_dbapi.utils.PeekIterator.__init__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + ) as retry_mock: + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=((1, 2, 3), None), + ): + cursor.execute("SELECT * FROM table_name") + + retry_mock.assert_called_with() + + def test_peek_iterator_aborted_autocommit(self): + """ + Checking that an Aborted exception is retried in case it happened while + streaming the first element with a PeekIterator in autocommit mode. + """ + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.connection import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", return_value=True, + ): + connection = connect("test-instance", "test-database") + + connection.autocommit = True + cursor = connection.cursor() + with mock.patch( + "google.cloud.spanner_dbapi.utils.PeekIterator.__init__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + ) as retry_mock: + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=((1, 2, 3), None), + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.snapshot" + ): + cursor.execute("SELECT * FROM table_name") + + retry_mock.assert_called_with() + def test_fetchone_retry_aborted(self): """Check that aborted fetch re-executing transaction.""" from google.api_core.exceptions import Aborted