From 50580b485342052a59d830a8428cbeaeab8b828c Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 23 Oct 2020 15:31:01 +0300 Subject: [PATCH 01/18] feat: support aborted transactions internal retry --- google/cloud/spanner_dbapi/checksum.py | 72 ++++++++++++++++++++++++ google/cloud/spanner_dbapi/connection.py | 42 ++++++++++++++ google/cloud/spanner_dbapi/cursor.py | 19 +++++-- tests/spanner_dbapi/test_checksum.py | 44 +++++++++++++++ tests/spanner_dbapi/test_connection.py | 57 +++++++++++++++++++ tests/system/test_system.py | 32 ++++++++++- 6 files changed, 259 insertions(+), 7 deletions(-) create mode 100644 google/cloud/spanner_dbapi/checksum.py create mode 100644 tests/spanner_dbapi/test_checksum.py diff --git a/google/cloud/spanner_dbapi/checksum.py b/google/cloud/spanner_dbapi/checksum.py new file mode 100644 index 0000000000..b7c15bcc63 --- /dev/null +++ b/google/cloud/spanner_dbapi/checksum.py @@ -0,0 +1,72 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""API to calculate checksums of SQL statements results.""" + +import hashlib +import pickle + + +class ResultsChecksum: + """Cumulative checksum. + + Used to calculate a total checksum of all the results + returned by operations executed within transaction. + Includes methods for checksums comparison. + These checksums are used while retrying an aborted + transaction to check if the results of a retried transaction + are equal to the results of the original transaction. + """ + + def __init__(self): + self.checksum = hashlib.sha256() + self.count = 0 # counter of consumed results + + def __len__(self): + """Return the number of consumed results. + + :rtype: :class:`int` + :returns: The number of results. + """ + return self.count + + def __eq__(self, other): + """Check if checksums are equal. + + :type other: :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum` + :param other: Another checksum to compare with this one. + """ + return self.checksum.digest() == other.checksum.digest() + + def consume_result(self, result): + """Add the given result into the checksum. + + :type result: Union[int, list] + :param result: Streamed row or row count from an UPDATE operation. + """ + self.checksum.update(pickle.dumps(result)) + self.count += 1 + + +def _compare_checksums(original, retried): + """Compare the given checksums. + + Raise an error if the given checksums have consumed + the same number of results, but are not equal. + + :type original: :class:`~google.cloud.spanner_v1.transaction.ResultsChecksum` + :param original: results checksum of the original transaction. + + :type retried: :class:`~google.cloud.spanner_v1.transaction.ResultsChecksum` + :param retried: results checksum of the retried transaction. + + :raises: :exc:`RuntimeError` in case if checksums are not equal. + """ + if original is not None: + if len(retried) == len(original) and retried != original: + raise RuntimeError( + "The underlying data being changed while retrying an aborted transaction." + ) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 8907e65c03..be05f48c31 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -11,6 +11,7 @@ from google.cloud import spanner_v1 +from .checksum import ResultsChecksum from .cursor import Cursor from .exceptions import InterfaceError @@ -39,6 +40,9 @@ def __init__(self, instance, database): self._ddl_statements = [] self._transaction = None self._session = None + # SQL statements, which were executed + # within the current transaction + self._statements = [] self.is_closed = False self._autocommit = False @@ -178,6 +182,42 @@ def run_prior_DDL_statements(self): return self.__handle_update_ddl(ddl_statements) + def run_statement(self, sql, params, param_types): + """Run single SQL statement in begun transaction. + + This method is never used in autocommit mode. In + !autocommit mode however it remembers every executed + SQL statement with its parameters. + + :type sql: :class:`str` + :param sql: SQL statement to execute. + + :type params: :class:`dict` + :param params: Params to be used by the given statement. + + :type param_types: :class:`dict` + :param param_types: Statement parameters types description. + + :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`, + :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum` + :returns: Streamed result set of the statement and a + checksum of this statement results. + """ + transaction = self.transaction_checkout() + + statement = { + "sql": sql, + "params": params, + "param_types": param_types, + "checksum": ResultsChecksum(), + } + self._statements.append(statement) + + return ( + transaction.execute_sql(sql, params, param_types=param_types), + statement["checksum"], + ) + def list_tables(self): return self.run_sql_in_snapshot( """ @@ -244,6 +284,7 @@ def commit(self): elif self._transaction: self._transaction.commit() self._release_session() + self._statements = [] def rollback(self): """Rollback all the pending transactions.""" @@ -252,6 +293,7 @@ def rollback(self): elif self._transaction: self._transaction.rollback() self._release_session() + self._statements = [] def __enter__(self): return self diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 95eae50e1a..50dc577693 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -67,6 +67,8 @@ def __init__(self, connection): self._row_count = _UNSET_COUNT self._connection = connection self._is_closed = False + # the currently running SQL statement results checksum + self._checksum = None # the number of rows to fetch at a time with fetchmany() self.arraysize = 1 @@ -101,11 +103,9 @@ def execute(self, sql, args=None): self._run_prior_DDL_statements() if not self._connection.autocommit: - transaction = self._connection.transaction_checkout() - sql, params = sql_pyformat_args_to_spanner(sql, args) - self._res = transaction.execute_sql( + self._res, self._checksum = self._connection.run_statement( sql, params, param_types=get_param_types(params) ) self._itr = PeekIterator(self._res) @@ -305,14 +305,19 @@ def fetchone(self): self._raise_if_closed() try: - return next(self) + res = next(self) + self._checksum.consume_result(res) + return res except StopIteration: return None def fetchall(self): self._raise_if_closed() - return list(self.__iter__()) + res = list(self.__iter__()) + for row in res: + self._checksum.consume_result(row) + return res def fetchmany(self, size=None): """ @@ -335,7 +340,9 @@ def fetchmany(self, size=None): items = [] for i in range(size): try: - items.append(tuple(self.__next__())) + res = next(self) + self._checksum.consume_result(res) + items.append(res) except StopIteration: break diff --git a/tests/spanner_dbapi/test_checksum.py b/tests/spanner_dbapi/test_checksum.py new file mode 100644 index 0000000000..73b0fb063f --- /dev/null +++ b/tests/spanner_dbapi/test_checksum.py @@ -0,0 +1,44 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest + +from google.cloud.spanner_dbapi.checksum import ( + _compare_checksums, + ResultsChecksum, +) + + +class Test_compare_checksums(unittest.TestCase): + def test_no_original_checksum(self): + self.assertIsNone(_compare_checksums(None, ResultsChecksum())) + + def test_equal(self): + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(5) + + self.assertIsNone(_compare_checksums(original, retried)) + + def test_less_results(self): + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + + self.assertIsNone(_compare_checksums(original, retried)) + + def test_mismatch(self): + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(2) + + with self.assertRaises(RuntimeError): + _compare_checksums(original, retried) diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py index 24260de12e..ba5f3dc2de 100644 --- a/tests/spanner_dbapi/test_connection.py +++ b/tests/spanner_dbapi/test_connection.py @@ -12,6 +12,7 @@ # import google.cloud.spanner_dbapi.exceptions as dbapi_exceptions from google.cloud.spanner_dbapi import Connection, InterfaceError +from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING from google.cloud.spanner_v1.database import Database from google.cloud.spanner_v1.instance import Instance @@ -77,3 +78,59 @@ def test_instance_property(self): with self.assertRaises(AttributeError): connection.instance = None + + def test_run_statement(self): + """Check that Connection remembers executed statements.""" + statement = """SELECT 23 FROM table WHERE id = @a1""" + params = {"a1": "value"} + param_types = {"a1": str} + + connection = self._make_connection() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" + ): + connection.run_statement(statement, params, param_types) + + self.assertEqual(connection._statements[0]["sql"], statement) + self.assertEqual(connection._statements[0]["params"], params) + self.assertEqual(connection._statements[0]["param_types"], param_types) + self.assertIsInstance( + connection._statements[0]["checksum"], ResultsChecksum + ) + + def test_clear_statements_on_commit(self): + """ + Check that all the saved statements are + cleared, when the transaction is commited. + """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch( + "google.cloud.spanner_v1.transaction.Transaction.commit" + ): + connection.commit() + + self.assertEqual(len(connection._statements), 0) + + def test_clear_statements_on_rollback(self): + """ + Check that all the saved statements are + cleared, when the transaction is roll backed. + """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch( + "google.cloud.spanner_v1.transaction.Transaction.commit" + ): + connection.rollback() + + self.assertEqual(len(connection._statements), 0) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index f3ee345e15..158092b31c 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -4,8 +4,10 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import unittest +import hashlib +import pickle import os +import unittest from google.api_core import exceptions @@ -289,6 +291,34 @@ def test_rollback_on_connection_closing(self): cursor.close() conn.close() + def test_results_checksum(self): + """Test that results checksum is calculated properly.""" + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES + (1, 'first-name', 'last-name', 'test.email@domen.ru'), + (2, 'first-name2', 'last-name2', 'test.email2@domen.ru') + """ + ) + self.assertEqual(len(conn._statements), 1) + conn.commit() + + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + + self.assertEqual(len(conn._statements), 1) + conn.commit() + + checksum = hashlib.sha256() + checksum.update(pickle.dumps(got_rows[0])) + checksum.update(pickle.dumps(got_rows[1])) + + self.assertEqual(cursor._checksum.checksum.digest(), checksum.digest()) + def clear_table(transaction): """Clear the test table.""" From 481db2a379fca749377a7cc91fee889a62c656c0 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Mon, 26 Oct 2020 12:59:23 +0300 Subject: [PATCH 02/18] add a transaction retry method --- google/cloud/spanner_dbapi/checksum.py | 11 ++-- google/cloud/spanner_dbapi/connection.py | 50 +++++++++------- google/cloud/spanner_dbapi/cursor.py | 27 +++++++-- tests/spanner_dbapi/test_checksum.py | 16 ++++-- tests/spanner_dbapi/test_connection.py | 72 +++++++++++++++++++++++- 5 files changed, 139 insertions(+), 37 deletions(-) diff --git a/google/cloud/spanner_dbapi/checksum.py b/google/cloud/spanner_dbapi/checksum.py index b7c15bcc63..27ccacd094 100644 --- a/google/cloud/spanner_dbapi/checksum.py +++ b/google/cloud/spanner_dbapi/checksum.py @@ -9,6 +9,8 @@ import hashlib import pickle +from google.api_core.exceptions import Aborted + class ResultsChecksum: """Cumulative checksum. @@ -65,8 +67,7 @@ def _compare_checksums(original, retried): :raises: :exc:`RuntimeError` in case if checksums are not equal. """ - if original is not None: - if len(retried) == len(original) and retried != original: - raise RuntimeError( - "The underlying data being changed while retrying an aborted transaction." - ) + if len(retried) == len(original) and retried != original: + raise Aborted( + "The transaction was aborted and could not be retried due to a concurrent modification." + ) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index be05f48c31..a8e845a69a 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -6,12 +6,12 @@ """Cloud Spanner DB connection object.""" -from collections import namedtuple import warnings +from collections import namedtuple from google.cloud import spanner_v1 -from .checksum import ResultsChecksum +from .checksum import _compare_checksums, ResultsChecksum from .cursor import Cursor from .exceptions import InterfaceError @@ -108,6 +108,25 @@ def _release_session(self): self.database._pool.put(self._session) self._session = None + def retry_transaction(self): + """Retry the aborted transaction. + + All the statements executed in the transaction + will be re-executed. Results checksums of the original + statements and the retried ones will be compared. + + :raises: :class:`google.api_core.exceptions.Aborted` + If results checksum of the retried statement is + not equal to the checksum of the original one. + """ + for statement in self._statements: + res_iter, retried_checksum = self.run_statement( + statement, retried=True + ) + for res in res_iter: + retried_checksum.consume_result(res) + _compare_checksums(statement["checksum"], retried_checksum) + def transaction_checkout(self): """Get a Cloud Spanner transaction. @@ -182,21 +201,15 @@ def run_prior_DDL_statements(self): return self.__handle_update_ddl(ddl_statements) - def run_statement(self, sql, params, param_types): + def run_statement(self, statement, retried=False): """Run single SQL statement in begun transaction. This method is never used in autocommit mode. In !autocommit mode however it remembers every executed SQL statement with its parameters. - :type sql: :class:`str` - :param sql: SQL statement to execute. - - :type params: :class:`dict` - :param params: Params to be used by the given statement. - - :type param_types: :class:`dict` - :param param_types: Statement parameters types description. + :type statement: :class:`dict` + :param statement: SQL statement to execute. :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`, :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum` @@ -204,18 +217,15 @@ def run_statement(self, sql, params, param_types): checksum of this statement results. """ transaction = self.transaction_checkout() - - statement = { - "sql": sql, - "params": params, - "param_types": param_types, - "checksum": ResultsChecksum(), - } self._statements.append(statement) return ( - transaction.execute_sql(sql, params, param_types=param_types), - statement["checksum"], + transaction.execute_sql( + statement["sql"], + statement["params"], + param_types=statement["param_types"], + ), + ResultsChecksum() if retried else statement["checksum"], ) def list_tables(self): diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 50dc577693..5b0e80c5c5 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -7,6 +7,7 @@ """Database cursor API.""" from google.api_core.exceptions import ( + Aborted, AlreadyExists, FailedPrecondition, InternalServerError, @@ -14,6 +15,7 @@ ) from google.cloud.spanner_v1 import param_types +from .checksum import ResultsChecksum from .exceptions import ( IntegrityError, InterfaceError, @@ -105,8 +107,14 @@ def execute(self, sql, args=None): if not self._connection.autocommit: sql, params = sql_pyformat_args_to_spanner(sql, args) + statement = { + "sql": sql, + "params": params, + "param_types": get_param_types(params), + "checksum": ResultsChecksum(), + } self._res, self._checksum = self._connection.run_statement( - sql, params, param_types=get_param_types(params) + statement ) self._itr = PeekIterator(self._res) return @@ -309,14 +317,21 @@ def fetchone(self): self._checksum.consume_result(res) return res except StopIteration: - return None + return + except Aborted: + self._connection.retry_transaction() def fetchall(self): self._raise_if_closed() - res = list(self.__iter__()) - for row in res: - self._checksum.consume_result(row) + res = [] + try: + for row in self.__iter__(): + self._checksum.consume_result(row) + res.append(row) + except Aborted: + self._connection.retry_transaction() + return res def fetchmany(self, size=None): @@ -345,6 +360,8 @@ def fetchmany(self, size=None): items.append(res) except StopIteration: break + except Aborted: + self._connection.retry_transaction() return items diff --git a/tests/spanner_dbapi/test_checksum.py b/tests/spanner_dbapi/test_checksum.py index 73b0fb063f..2f4b74216a 100644 --- a/tests/spanner_dbapi/test_checksum.py +++ b/tests/spanner_dbapi/test_checksum.py @@ -6,6 +6,7 @@ import unittest +from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ( _compare_checksums, ResultsChecksum, @@ -13,9 +14,6 @@ class Test_compare_checksums(unittest.TestCase): - def test_no_original_checksum(self): - self.assertIsNone(_compare_checksums(None, ResultsChecksum())) - def test_equal(self): original = ResultsChecksum() original.consume_result(5) @@ -33,6 +31,16 @@ def test_less_results(self): self.assertIsNone(_compare_checksums(original, retried)) + def test_more_results(self): + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(5) + retried.consume_result(2) + + self.assertIsNone(_compare_checksums(original, retried)) + def test_mismatch(self): original = ResultsChecksum() original.consume_result(5) @@ -40,5 +48,5 @@ def test_mismatch(self): retried = ResultsChecksum() retried.consume_result(2) - with self.assertRaises(RuntimeError): + with self.assertRaises(Aborted): _compare_checksums(original, retried) diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py index ba5f3dc2de..0ae01225ef 100644 --- a/tests/spanner_dbapi/test_connection.py +++ b/tests/spanner_dbapi/test_connection.py @@ -11,6 +11,7 @@ # import google.cloud.spanner_dbapi.exceptions as dbapi_exceptions +from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi import Connection, InterfaceError from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING @@ -81,18 +82,25 @@ def test_instance_property(self): def test_run_statement(self): """Check that Connection remembers executed statements.""" - statement = """SELECT 23 FROM table WHERE id = @a1""" + sql = """SELECT 23 FROM table WHERE id = @a1""" params = {"a1": "value"} param_types = {"a1": str} connection = self._make_connection() + statement = { + "sql": sql, + "params": params, + "param_types": param_types, + "checksum": ResultsChecksum(), + } + with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" ): - connection.run_statement(statement, params, param_types) + connection.run_statement(statement) - self.assertEqual(connection._statements[0]["sql"], statement) + self.assertEqual(connection._statements[0]["sql"], sql) self.assertEqual(connection._statements[0]["params"], params) self.assertEqual(connection._statements[0]["param_types"], param_types) self.assertIsInstance( @@ -134,3 +142,61 @@ def test_clear_statements_on_rollback(self): connection.rollback() self.assertEqual(len(connection._statements), 0) + + def test_retry_transaction(self): + """Check retrying an aborted transaction.""" + row = ["field1", "field2"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], retried_checkum), + ) as run_mock: + with mock.patch( + "google.cloud.spanner_dbapi.connection._compare_checksums" + ) as compare_mock: + connection.retry_transaction() + + compare_mock.assert_called_with(checksum, retried_checkum) + + run_mock.assert_called_with(statement, retried=True) + + def test_retry_transaction_checksum_mismatch(self): + """ + Check retrying an aborted transaction + with results checksums mismatch. + """ + row = ["field1", "field2"] + retried_row = ["field3", "field4"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([retried_row], retried_checkum), + ): + with self.assertRaises(Aborted): + connection.retry_transaction() From 313e16df8060c6348b8e72b04e3e1201fd773038 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 27 Oct 2020 12:58:30 +0300 Subject: [PATCH 03/18] add retrying all the statements up to failed one --- google/cloud/spanner_dbapi/checksum.py | 8 +- google/cloud/spanner_dbapi/connection.py | 18 +++- google/cloud/spanner_dbapi/cursor.py | 3 + tests/spanner_dbapi/test_checksum.py | 6 +- tests/spanner_dbapi/test_cursor.py | 108 +++++++++++++++++++++++ 5 files changed, 135 insertions(+), 8 deletions(-) diff --git a/google/cloud/spanner_dbapi/checksum.py b/google/cloud/spanner_dbapi/checksum.py index 27ccacd094..91eafa2ad2 100644 --- a/google/cloud/spanner_dbapi/checksum.py +++ b/google/cloud/spanner_dbapi/checksum.py @@ -56,8 +56,8 @@ def consume_result(self, result): def _compare_checksums(original, retried): """Compare the given checksums. - Raise an error if the given checksums have consumed - the same number of results, but are not equal. + Raise an error if the given checksums has + different length, or are not equal. :type original: :class:`~google.cloud.spanner_v1.transaction.ResultsChecksum` :param original: results checksum of the original transaction. @@ -65,9 +65,9 @@ def _compare_checksums(original, retried): :type retried: :class:`~google.cloud.spanner_v1.transaction.ResultsChecksum` :param retried: results checksum of the retried transaction. - :raises: :exc:`RuntimeError` in case if checksums are not equal. + :raises: :exc:`google.api_core.exceptions.Aborted` in case if checksums are not equal. """ - if len(retried) == len(original) and retried != original: + if len(retried) != len(original) or retried != original: raise Aborted( "The transaction was aborted and could not be retried due to a concurrent modification." ) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index a8e845a69a..ef16aa6807 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -123,8 +123,22 @@ def retry_transaction(self): res_iter, retried_checksum = self.run_statement( statement, retried=True ) - for res in res_iter: - retried_checksum.consume_result(res) + # executing all the completed statements + if statement != self._statements[-1]: + for res in res_iter: + retried_checksum.consume_result(res) + + _compare_checksums(statement["checksum"], retried_checksum) + # executing the failed statement + else: + # streaming up to the failed result + while len(retried_checksum) < len(statement["checksum"]): + try: + res = next(iter(res_iter)) + retried_checksum.consume_result(res) + except StopIteration: + break + _compare_checksums(statement["checksum"], retried_checksum) def transaction_checkout(self): diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 5b0e80c5c5..2cc6dee349 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -320,6 +320,7 @@ def fetchone(self): return except Aborted: self._connection.retry_transaction() + return self.fetchone() def fetchall(self): self._raise_if_closed() @@ -331,6 +332,7 @@ def fetchall(self): res.append(row) except Aborted: self._connection.retry_transaction() + return self.fetchall() return res @@ -362,6 +364,7 @@ def fetchmany(self, size=None): break except Aborted: self._connection.retry_transaction() + return self.fetchmany(size) return items diff --git a/tests/spanner_dbapi/test_checksum.py b/tests/spanner_dbapi/test_checksum.py index 2f4b74216a..c6c2b08afb 100644 --- a/tests/spanner_dbapi/test_checksum.py +++ b/tests/spanner_dbapi/test_checksum.py @@ -29,7 +29,8 @@ def test_less_results(self): retried = ResultsChecksum() - self.assertIsNone(_compare_checksums(original, retried)) + with self.assertRaises(Aborted): + _compare_checksums(original, retried) def test_more_results(self): original = ResultsChecksum() @@ -39,7 +40,8 @@ def test_more_results(self): retried.consume_result(5) retried.consume_result(2) - self.assertIsNone(_compare_checksums(original, retried)) + with self.assertRaises(Aborted): + _compare_checksums(original, retried) def test_mismatch(self): original = ResultsChecksum() diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py index 673a95d3e5..4baa409070 100644 --- a/tests/spanner_dbapi/test_cursor.py +++ b/tests/spanner_dbapi/test_cursor.py @@ -9,7 +9,9 @@ import unittest from unittest import mock +from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi import connect, InterfaceError +from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import ColumnInfo @@ -97,6 +99,112 @@ def test_executemany(self): (mock.call(operation, (1,)), mock.call(operation, (2,))) ) + def test_fetchone_retry_aborted(self): + """Check that aborted fetch re-executing transaction.""" + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + ) as retry_mock: + + cursor.fetchone() + + retry_mock.assert_called_once() + + def test_fetchone_retry_aborted_statements(self): + """Check that retried transaction executing the same statements.""" + row = ["field1", "field2"] + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": cursor._checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], ResultsChecksum()), + ) as run_mock: + + cursor.fetchone() + + run_mock.assert_called_with(statement, retried=True) + + def test_fetchone_retry_aborted_statements_checksums_mismatch(self): + """Check transaction retrying with underlying data being changed.""" + row = ["field1", "field2"] + row2 = ["updated_field1", "field2"] + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": cursor._checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=Aborted("Aborted"), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row2], ResultsChecksum()), + ) as run_mock: + + with self.assertRaises(Aborted): + cursor.fetchone() + + run_mock.assert_called_with(statement, retried=True) + class TestColumns(unittest.TestCase): def test_ctor(self): From 4edf6c1663e2e70809f375f2e886043cbd1eff2c Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 27 Oct 2020 13:33:25 +0300 Subject: [PATCH 04/18] resolve conflicts --- google/cloud/spanner_dbapi/connection.py | 34 +++++++++++++++++++-- google/cloud/spanner_dbapi/cursor.py | 39 ++++++++++++------------ tests/spanner_dbapi/test_connect.py | 31 ++++++------------- 3 files changed, 61 insertions(+), 43 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 87fccccd62..de70b7dbe0 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -7,11 +7,12 @@ """DB-API Connection for the Google Cloud Spanner.""" import warnings -from collections import namedtuple from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud import spanner_v1 as spanner +from google.cloud.spanner_dbapi.checksum import _compare_checksums +from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Cursor from google.cloud.spanner_dbapi.exceptions import InterfaceError from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT @@ -229,6 +230,33 @@ def run_prior_DDL_statements(self): return self.database.update_ddl(ddl_statements).result() + def run_statement(self, statement, retried=False): + """Run single SQL statement in begun transaction. + + This method is never used in autocommit mode. In + !autocommit mode however it remembers every executed + SQL statement with its parameters. + + :type statement: :class:`dict` + :param statement: SQL statement to execute. + + :rtype: :class:`google.cloud.spanner_v1.streamed.StreamedResultSet`, + :class:`google.cloud.spanner_dbapi.checksum.ResultsChecksum` + :returns: Streamed result set of the statement and a + checksum of this statement results. + """ + transaction = self.transaction_checkout() + self._statements.append(statement) + + return ( + transaction.execute_sql( + statement["sql"], + statement["params"], + param_types=statement["param_types"], + ), + ResultsChecksum() if retried else statement["checksum"], + ) + def __enter__(self): return self @@ -282,11 +310,11 @@ def connect( """ client_info = ClientInfo( - user_agent=user_agent or DEFAULT_USER_AGENT, python_version=PY_VERSION, + user_agent=user_agent or DEFAULT_USER_AGENT, python_version=PY_VERSION ) client = spanner.Client( - project=project, credentials=credentials, client_info=client_info, + project=project, credentials=credentials, client_info=client_info ) instance = client.instance(instance_id) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 2f1df7a933..a28c9f607f 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -6,6 +6,7 @@ """Database cursor for Google Cloud Spanner DB-API.""" +from google.api_core.exceptions import Aborted from google.api_core.exceptions import AlreadyExists from google.api_core.exceptions import FailedPrecondition from google.api_core.exceptions import InternalServerError @@ -14,7 +15,7 @@ from collections import namedtuple from google.cloud import spanner_v1 as spanner - +from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.exceptions import IntegrityError from google.cloud.spanner_dbapi.exceptions import InterfaceError from google.cloud.spanner_dbapi.exceptions import OperationalError @@ -26,6 +27,7 @@ from google.cloud.spanner_dbapi import parse_utils from google.cloud.spanner_dbapi.parse_utils import get_param_types +from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner from google.cloud.spanner_dbapi.utils import PeekIterator _UNSET_COUNT = -1 @@ -46,6 +48,8 @@ def __init__(self, connection): self._row_count = _UNSET_COUNT self.connection = connection self._is_closed = False + # the currently running SQL statement results checksum + self._checksum = None # the number of rows to fetch at a time with fetchmany() self.arraysize = 1 @@ -158,16 +162,18 @@ def execute(self, sql, args=None): self.connection.run_prior_DDL_statements() if not self.connection.autocommit: - transaction = self.connection.transaction_checkout() - - sql, params = parse_utils.sql_pyformat_args_to_spanner( - sql, args - ) - - self._result_set = transaction.execute_sql( - sql, params, param_types=get_param_types(params) + sql, params = sql_pyformat_args_to_spanner(sql, args) + + statement = { + "sql": sql, + "params": params, + "param_types": get_param_types(params), + "checksum": ResultsChecksum(), + } + self._res, self._checksum = self.connection.run_statement( + statement ) - self._itr = PeekIterator(self._result_set) + self._itr = PeekIterator(self._res) return if classification == parse_utils.STMT_NON_UPDATING: @@ -213,10 +219,13 @@ def fetchone(self): except StopIteration: return except Aborted: - self._connection.retry_transaction() + self.connection.retry_transaction() return self.fetchone() def fetchall(self): + """Fetch all (remaining) rows of a query result, returning them as + a sequence of sequences. + """ self._raise_if_closed() res = [] @@ -260,14 +269,6 @@ def fetchmany(self, size=None): return items - def fetchall(self): - """Fetch all (remaining) rows of a query result, returning them as - a sequence of sequences. - """ - self._raise_if_closed() - - return list(self.__iter__()) - def nextset(self): """A no-op, raising an error if the cursor or connection is closed.""" self._raise_if_closed() diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py index fb4d89c373..857e2c66a0 100644 --- a/tests/spanner_dbapi/test_connect.py +++ b/tests/spanner_dbapi/test_connect.py @@ -29,31 +29,20 @@ def test_connect(self): PROJECT = "test-project" USER_AGENT = "user-agent" CREDENTIALS = _make_credentials() - CLIENT_INFO = ClientInfo(user_agent=USER_AGENT) - with mock.patch( - "google.cloud.spanner_dbapi.spanner_v1.Client" - ) as client_mock: - with mock.patch( - "google.cloud.spanner_dbapi.google_client_info", - return_value=CLIENT_INFO, - ) as client_info_mock: - - connection = connect( - "test-instance", - "test-database", - PROJECT, - CREDENTIALS, - user_agent=USER_AGENT, - ) + with mock.patch("google.cloud.spanner_v1.Client") as client_mock: + connection = connect( + "test-instance", + "test-database", + PROJECT, + CREDENTIALS, + user_agent=USER_AGENT, + ) - self.assertIsInstance(connection, Connection) - client_info_mock.assert_called_once_with(USER_AGENT) + self.assertIsInstance(connection, Connection) client_mock.assert_called_once_with( - project=PROJECT, - credentials=CREDENTIALS, - client_info=CLIENT_INFO, + project=PROJECT, credentials=CREDENTIALS, client_info=mock.ANY ) def test_instance_not_found(self): From f87129ab5add0766b0981c4959c9c55123dc7134 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 27 Oct 2020 13:45:10 +0300 Subject: [PATCH 05/18] move unit tests --- tests/spanner_dbapi/test_checksum.py | 54 -------- tests/spanner_dbapi/test_connect.py | 124 ------------------ tests/unit/spanner_dbapi/test_connection.py | 134 +++++++++++++++++++- 3 files changed, 131 insertions(+), 181 deletions(-) delete mode 100644 tests/spanner_dbapi/test_checksum.py delete mode 100644 tests/spanner_dbapi/test_connect.py diff --git a/tests/spanner_dbapi/test_checksum.py b/tests/spanner_dbapi/test_checksum.py deleted file mode 100644 index c6c2b08afb..0000000000 --- a/tests/spanner_dbapi/test_checksum.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import unittest - -from google.api_core.exceptions import Aborted -from google.cloud.spanner_dbapi.checksum import ( - _compare_checksums, - ResultsChecksum, -) - - -class Test_compare_checksums(unittest.TestCase): - def test_equal(self): - original = ResultsChecksum() - original.consume_result(5) - - retried = ResultsChecksum() - retried.consume_result(5) - - self.assertIsNone(_compare_checksums(original, retried)) - - def test_less_results(self): - original = ResultsChecksum() - original.consume_result(5) - - retried = ResultsChecksum() - - with self.assertRaises(Aborted): - _compare_checksums(original, retried) - - def test_more_results(self): - original = ResultsChecksum() - original.consume_result(5) - - retried = ResultsChecksum() - retried.consume_result(5) - retried.consume_result(2) - - with self.assertRaises(Aborted): - _compare_checksums(original, retried) - - def test_mismatch(self): - original = ResultsChecksum() - original.consume_result(5) - - retried = ResultsChecksum() - retried.consume_result(2) - - with self.assertRaises(Aborted): - _compare_checksums(original, retried) diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py deleted file mode 100644 index 857e2c66a0..0000000000 --- a/tests/spanner_dbapi/test_connect.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""connect() module function unit tests.""" - -import unittest -from unittest import mock - -import google.auth.credentials -from google.api_core.gapic_v1.client_info import ClientInfo -from google.cloud.spanner_dbapi import connect, Connection -from google.cloud.spanner_v1.pool import FixedSizePool - - -def _make_credentials(): - class _CredentialsWithScopes( - google.auth.credentials.Credentials, google.auth.credentials.Scoped - ): - pass - - return mock.Mock(spec=_CredentialsWithScopes) - - -class Test_connect(unittest.TestCase): - def test_connect(self): - PROJECT = "test-project" - USER_AGENT = "user-agent" - CREDENTIALS = _make_credentials() - - with mock.patch("google.cloud.spanner_v1.Client") as client_mock: - connection = connect( - "test-instance", - "test-database", - PROJECT, - CREDENTIALS, - user_agent=USER_AGENT, - ) - - self.assertIsInstance(connection, Connection) - - client_mock.assert_called_once_with( - project=PROJECT, credentials=CREDENTIALS, client_info=mock.ANY - ) - - def test_instance_not_found(self): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=False, - ) as exists_mock: - - with self.assertRaises(ValueError): - connect("test-instance", "test-database") - - exists_mock.assert_called_once_with() - - def test_database_not_found(self): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=False, - ) as exists_mock: - - with self.assertRaises(ValueError): - connect("test-instance", "test-database") - - exists_mock.assert_called_once_with() - - def test_connect_instance_id(self): - INSTANCE = "test-instance" - - with mock.patch( - "google.cloud.spanner_v1.client.Client.instance" - ) as instance_mock: - connection = connect(INSTANCE, "test-database") - - instance_mock.assert_called_once_with(INSTANCE) - - self.assertIsInstance(connection, Connection) - - def test_connect_database_id(self): - DATABASE = "test-database" - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.database" - ) as database_mock: - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - connection = connect("test-instance", DATABASE) - - database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) - - self.assertIsInstance(connection, Connection) - - def test_default_sessions_pool(self): - with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - self.assertIsNotNone(connection.database._pool) - - def test_sessions_pool(self): - database_id = "test-database" - pool = FixedSizePool() - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.database" - ) as database_mock: - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - connect("test-instance", database_id, pool=pool) - database_mock.assert_called_once_with(database_id, pool=pool) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index d545472c57..d4f702368d 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -80,7 +80,7 @@ def test__session_checkout(self): from google.cloud.spanner_dbapi import Connection with mock.patch( - "google.cloud.spanner_v1.database.Database", + "google.cloud.spanner_v1.database.Database" ) as mock_database: mock_database._pool = mock.MagicMock() mock_database._pool.get = mock.MagicMock( @@ -100,7 +100,7 @@ def test__release_session(self): from google.cloud.spanner_dbapi import Connection with mock.patch( - "google.cloud.spanner_v1.database.Database", + "google.cloud.spanner_v1.database.Database" ) as mock_database: mock_database._pool = mock.MagicMock() mock_database._pool.put = mock.MagicMock() @@ -220,7 +220,7 @@ def test_run_prior_DDL_statements(self): from google.cloud.spanner_dbapi import Connection, InterfaceError with mock.patch( - "google.cloud.spanner_v1.database.Database", autospec=True, + "google.cloud.spanner_v1.database.Database", autospec=True ) as mock_database: connection = Connection(self.INSTANCE, mock_database) @@ -316,3 +316,131 @@ def test_sessions_pool(self): ): connect("test-instance", database_id, pool=pool) database_mock.assert_called_once_with(database_id, pool=pool) + + def test_run_statement(self): + """Check that Connection remembers executed statements.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + + sql = """SELECT 23 FROM table WHERE id = @a1""" + params = {"a1": "value"} + param_types = {"a1": str} + + connection = self._make_connection() + + statement = { + "sql": sql, + "params": params, + "param_types": param_types, + "checksum": ResultsChecksum(), + } + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" + ): + connection.run_statement(statement) + + self.assertEqual(connection._statements[0]["sql"], sql) + self.assertEqual(connection._statements[0]["params"], params) + self.assertEqual(connection._statements[0]["param_types"], param_types) + self.assertIsInstance( + connection._statements[0]["checksum"], ResultsChecksum + ) + + def test_clear_statements_on_commit(self): + """ + Check that all the saved statements are + cleared, when the transaction is commited. + """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch( + "google.cloud.spanner_v1.transaction.Transaction.commit" + ): + connection.commit() + + self.assertEqual(len(connection._statements), 0) + + def test_clear_statements_on_rollback(self): + """ + Check that all the saved statements are + cleared, when the transaction is roll backed. + """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch( + "google.cloud.spanner_v1.transaction.Transaction.commit" + ): + connection.rollback() + + self.assertEqual(len(connection._statements), 0) + + def test_retry_transaction(self): + """Check retrying an aborted transaction.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + + row = ["field1", "field2"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], retried_checkum), + ) as run_mock: + with mock.patch( + "google.cloud.spanner_dbapi.connection._compare_checksums" + ) as compare_mock: + connection.retry_transaction() + + compare_mock.assert_called_with(checksum, retried_checkum) + + run_mock.assert_called_with(statement, retried=True) + + def test_retry_transaction_checksum_mismatch(self): + """ + Check retrying an aborted transaction + with results checksums mismatch. + """ + from google.cloud.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + + row = ["field1", "field2"] + retried_row = ["field3", "field4"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([retried_row], retried_checkum), + ): + with self.assertRaises(Aborted): + connection.retry_transaction() From 49e17bea9304ce054f3a00963d43e067152375c9 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 27 Oct 2020 13:56:37 +0300 Subject: [PATCH 06/18] fix tests --- google/cloud/spanner_dbapi/cursor.py | 4 +- tests/unit/spanner_dbapi/test_checksum.py | 54 ++++++++ tests/unit/spanner_dbapi/test_connect.py | 124 ++++++++++++++++++ tests/unit/spanner_dbapi/test_connection.py | 2 +- tests/unit/spanner_dbapi/test_cursor.py | 135 +++++++++++++++++++- 5 files changed, 312 insertions(+), 7 deletions(-) create mode 100644 tests/unit/spanner_dbapi/test_checksum.py create mode 100644 tests/unit/spanner_dbapi/test_connect.py diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index a28c9f607f..930335fb0b 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -170,10 +170,10 @@ def execute(self, sql, args=None): "param_types": get_param_types(params), "checksum": ResultsChecksum(), } - self._res, self._checksum = self.connection.run_statement( + self._result_set, self._checksum = self.connection.run_statement( statement ) - self._itr = PeekIterator(self._res) + self._itr = PeekIterator(self._result_set) return if classification == parse_utils.STMT_NON_UPDATING: diff --git a/tests/unit/spanner_dbapi/test_checksum.py b/tests/unit/spanner_dbapi/test_checksum.py new file mode 100644 index 0000000000..c6c2b08afb --- /dev/null +++ b/tests/unit/spanner_dbapi/test_checksum.py @@ -0,0 +1,54 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +import unittest + +from google.api_core.exceptions import Aborted +from google.cloud.spanner_dbapi.checksum import ( + _compare_checksums, + ResultsChecksum, +) + + +class Test_compare_checksums(unittest.TestCase): + def test_equal(self): + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(5) + + self.assertIsNone(_compare_checksums(original, retried)) + + def test_less_results(self): + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + + with self.assertRaises(Aborted): + _compare_checksums(original, retried) + + def test_more_results(self): + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(5) + retried.consume_result(2) + + with self.assertRaises(Aborted): + _compare_checksums(original, retried) + + def test_mismatch(self): + original = ResultsChecksum() + original.consume_result(5) + + retried = ResultsChecksum() + retried.consume_result(2) + + with self.assertRaises(Aborted): + _compare_checksums(original, retried) diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py new file mode 100644 index 0000000000..857e2c66a0 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -0,0 +1,124 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""connect() module function unit tests.""" + +import unittest +from unittest import mock + +import google.auth.credentials +from google.api_core.gapic_v1.client_info import ClientInfo +from google.cloud.spanner_dbapi import connect, Connection +from google.cloud.spanner_v1.pool import FixedSizePool + + +def _make_credentials(): + class _CredentialsWithScopes( + google.auth.credentials.Credentials, google.auth.credentials.Scoped + ): + pass + + return mock.Mock(spec=_CredentialsWithScopes) + + +class Test_connect(unittest.TestCase): + def test_connect(self): + PROJECT = "test-project" + USER_AGENT = "user-agent" + CREDENTIALS = _make_credentials() + + with mock.patch("google.cloud.spanner_v1.Client") as client_mock: + connection = connect( + "test-instance", + "test-database", + PROJECT, + CREDENTIALS, + user_agent=USER_AGENT, + ) + + self.assertIsInstance(connection, Connection) + + client_mock.assert_called_once_with( + project=PROJECT, credentials=CREDENTIALS, client_info=mock.ANY + ) + + def test_instance_not_found(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=False, + ) as exists_mock: + + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + exists_mock.assert_called_once_with() + + def test_database_not_found(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=False, + ) as exists_mock: + + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + exists_mock.assert_called_once_with() + + def test_connect_instance_id(self): + INSTANCE = "test-instance" + + with mock.patch( + "google.cloud.spanner_v1.client.Client.instance" + ) as instance_mock: + connection = connect(INSTANCE, "test-database") + + instance_mock.assert_called_once_with(INSTANCE) + + self.assertIsInstance(connection, Connection) + + def test_connect_database_id(self): + DATABASE = "test-database" + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connection = connect("test-instance", DATABASE) + + database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) + + self.assertIsInstance(connection, Connection) + + def test_default_sessions_pool(self): + with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + self.assertIsNotNone(connection.database._pool) + + def test_sessions_pool(self): + database_id = "test-database" + pool = FixedSizePool() + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connect("test-instance", database_id, pool=pool) + database_mock.assert_called_once_with(database_id, pool=pool) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index d4f702368d..36096807b0 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -419,7 +419,7 @@ def test_retry_transaction_checksum_mismatch(self): Check retrying an aborted transaction with results checksums mismatch. """ - from google.cloud.api_core.exceptions import Aborted + from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum row = ["field1", "field2"] diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 09288df94e..5b0e8eda0f 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -94,7 +94,7 @@ def test_do_execute_update(self): def run_helper(ret_value): transaction.execute_update.return_value = ret_value res = cursor._do_execute_update( - transaction=transaction, sql="sql", params=None, + transaction=transaction, sql="sql", params=None ) return res @@ -286,17 +286,25 @@ def test_executemany(self): ) def test_fetchone(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() lst = [1, 2, 3] cursor._itr = iter(lst) + for i in range(len(lst)): self.assertEqual(cursor.fetchone(), lst[i]) + self.assertIsNone(cursor.fetchone()) def test_fetchmany(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() lst = [(1,), (2,), (3,)] cursor._itr = iter(lst) @@ -306,8 +314,11 @@ def test_fetchmany(self): self.assertEqual(result, lst[1:]) def test_fetchall(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) + cursor._checksum = ResultsChecksum() lst = [(1,), (2,), (3,)] cursor._itr = iter(lst) self.assertEqual(cursor.fetchall(), lst) @@ -442,9 +453,7 @@ def test_get_table_column_schema(self): spanner_type = "spanner_type" rows = [(column_name, is_nullable, spanner_type)] expected = { - column_name: ColumnDetails( - null_ok=True, spanner_type=spanner_type, - ) + column_name: ColumnDetails(null_ok=True, spanner_type=spanner_type) } with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", @@ -458,3 +467,121 @@ def test_get_table_column_schema(self): param_types={"table_name": param_types.STRING}, ) self.assertEqual(result, expected) + + def test_fetchone_retry_aborted(self): + """Check that aborted fetch re-executing transaction.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + ) as retry_mock: + + cursor.fetchone() + + retry_mock.assert_called_once() + + def test_fetchone_retry_aborted_statements(self): + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + + """Check that retried transaction executing the same statements.""" + row = ["field1", "field2"] + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": cursor._checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], ResultsChecksum()), + ) as run_mock: + + cursor.fetchone() + + run_mock.assert_called_with(statement, retried=True) + + def test_fetchone_retry_aborted_statements_checksums_mismatch(self): + """Check transaction retrying with underlying data being changed.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + + row = ["field1", "field2"] + row2 = ["updated_field1", "field2"] + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": cursor._checksum, + } + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=Aborted("Aborted"), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row2], ResultsChecksum()), + ) as run_mock: + + with self.assertRaises(Aborted): + cursor.fetchone() + + run_mock.assert_called_with(statement, retried=True) From a8158b3da98aa3b5805dd7cbbf15ea049b39008c Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Wed, 28 Oct 2020 10:58:38 +0300 Subject: [PATCH 07/18] small fixes --- google/cloud/spanner_dbapi/checksum.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/google/cloud/spanner_dbapi/checksum.py b/google/cloud/spanner_dbapi/checksum.py index 91eafa2ad2..22aaa11548 100644 --- a/google/cloud/spanner_dbapi/checksum.py +++ b/google/cloud/spanner_dbapi/checksum.py @@ -56,18 +56,17 @@ def consume_result(self, result): def _compare_checksums(original, retried): """Compare the given checksums. - Raise an error if the given checksums has - different length, or are not equal. + Raise an error if the given checksums are not equal. - :type original: :class:`~google.cloud.spanner_v1.transaction.ResultsChecksum` + :type original: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum` :param original: results checksum of the original transaction. - :type retried: :class:`~google.cloud.spanner_v1.transaction.ResultsChecksum` + :type retried: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum` :param retried: results checksum of the retried transaction. :raises: :exc:`google.api_core.exceptions.Aborted` in case if checksums are not equal. """ - if len(retried) != len(original) or retried != original: + if retried != original: raise Aborted( "The transaction was aborted and could not be retried due to a concurrent modification." ) From ccf5385284f97ddc824dcfa9d5814336ac44ca20 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Wed, 28 Oct 2020 11:16:07 +0300 Subject: [PATCH 08/18] fix lint and unit tests --- google/cloud/spanner_dbapi/cursor.py | 7 ++++--- tests/unit/spanner_dbapi/test_checksum.py | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 930335fb0b..1c7b99c912 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -170,9 +170,10 @@ def execute(self, sql, args=None): "param_types": get_param_types(params), "checksum": ResultsChecksum(), } - self._result_set, self._checksum = self.connection.run_statement( - statement - ) + ( + self._result_set, + self._checksum, + ) = self.connection.run_statement(statement) self._itr = PeekIterator(self._result_set) return diff --git a/tests/unit/spanner_dbapi/test_checksum.py b/tests/unit/spanner_dbapi/test_checksum.py index c6c2b08afb..486ba0c036 100644 --- a/tests/unit/spanner_dbapi/test_checksum.py +++ b/tests/unit/spanner_dbapi/test_checksum.py @@ -7,14 +7,13 @@ import unittest from google.api_core.exceptions import Aborted -from google.cloud.spanner_dbapi.checksum import ( - _compare_checksums, - ResultsChecksum, -) class Test_compare_checksums(unittest.TestCase): def test_equal(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + original = ResultsChecksum() original.consume_result(5) @@ -24,6 +23,9 @@ def test_equal(self): self.assertIsNone(_compare_checksums(original, retried)) def test_less_results(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + original = ResultsChecksum() original.consume_result(5) @@ -33,6 +35,9 @@ def test_less_results(self): _compare_checksums(original, retried) def test_more_results(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + original = ResultsChecksum() original.consume_result(5) @@ -44,6 +49,9 @@ def test_more_results(self): _compare_checksums(original, retried) def test_mismatch(self): + from google.cloud.spanner_dbapi.checksum import _compare_checksums + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + original = ResultsChecksum() original.consume_result(5) From a537628e152efff3e064a47cfb2f878874fc3123 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Wed, 28 Oct 2020 11:27:31 +0300 Subject: [PATCH 09/18] fix lint and unit tests --- tests/unit/spanner_dbapi/test_connect.py | 1 - tests/unit/spanner_dbapi/test_cursor.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index 857e2c66a0..f4506160b0 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -10,7 +10,6 @@ from unittest import mock import google.auth.credentials -from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud.spanner_dbapi import connect, Connection from google.cloud.spanner_v1.pool import FixedSizePool diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 5b0e8eda0f..1663974cf3 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -497,7 +497,7 @@ def test_fetchone_retry_aborted(self): cursor.fetchone() - retry_mock.assert_called_once() + retry_mock.assert_called_with() def test_fetchone_retry_aborted_statements(self): from google.api_core.exceptions import Aborted From 0b1a6410a8d809f541eaf437bbf89fd0759dc7fc Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Wed, 28 Oct 2020 11:46:20 +0300 Subject: [PATCH 10/18] fix imports --- tests/unit/spanner_dbapi/test_connect.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index f4506160b0..5d545d7d90 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -10,8 +10,6 @@ from unittest import mock import google.auth.credentials -from google.cloud.spanner_dbapi import connect, Connection -from google.cloud.spanner_v1.pool import FixedSizePool def _make_credentials(): @@ -25,6 +23,9 @@ class _CredentialsWithScopes( class Test_connect(unittest.TestCase): def test_connect(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_dbapi import Connection + PROJECT = "test-project" USER_AGENT = "user-agent" CREDENTIALS = _make_credentials() @@ -45,6 +46,8 @@ def test_connect(self): ) def test_instance_not_found(self): + from google.cloud.spanner_dbapi import connect + with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=False, @@ -56,6 +59,8 @@ def test_instance_not_found(self): exists_mock.assert_called_once_with() def test_database_not_found(self): + from google.cloud.spanner_dbapi import connect + with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", return_value=True, @@ -71,6 +76,9 @@ def test_database_not_found(self): exists_mock.assert_called_once_with() def test_connect_instance_id(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_dbapi import Connection + INSTANCE = "test-instance" with mock.patch( @@ -83,6 +91,9 @@ def test_connect_instance_id(self): self.assertIsInstance(connection, Connection) def test_connect_database_id(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_dbapi import Connection + DATABASE = "test-database" with mock.patch( @@ -99,6 +110,8 @@ def test_connect_database_id(self): self.assertIsInstance(connection, Connection) def test_default_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", @@ -109,6 +122,9 @@ def test_default_sessions_pool(self): self.assertIsNotNone(connection.database._pool) def test_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_v1.pool import FixedSizePool + database_id = "test-database" pool = FixedSizePool() From 0e2ca3eeaa5584bcb7cccc31d5c533d3f08f1c70 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Wed, 28 Oct 2020 16:35:13 +0300 Subject: [PATCH 11/18] add aborted commit retry --- google/cloud/spanner_dbapi/connection.py | 11 ++++-- tests/unit/spanner_dbapi/test_connection.py | 44 +++++++++++++++++++++ tests/unit/spanner_dbapi/test_cursor.py | 2 +- 3 files changed, 53 insertions(+), 4 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index de70b7dbe0..67d73c9b8a 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -8,6 +8,7 @@ import warnings +from google.api_core.exceptions import Aborted from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud import spanner_v1 as spanner @@ -198,9 +199,13 @@ def commit(self): if self._autocommit: warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) elif self._transaction: - self._transaction.commit() - self._release_session() - self._statements = [] + try: + self._transaction.commit() + self._release_session() + self._statements = [] + except Aborted: + self.retry_transaction() + self.commit() def rollback(self): """Rolls back any pending transaction. diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 36096807b0..cc7def0287 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -444,3 +444,47 @@ def test_retry_transaction_checksum_mismatch(self): ): with self.assertRaises(Aborted): connection.retry_transaction() + + def test_commit_retry_aborted_statements(self): + """Check that retried transaction executing the same statements.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + + row = ["field1", "field2"] + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = { + "sql": "SELECT 1", + "params": [], + "param_types": {}, + "checksum": cursor._checksum, + } + connection._statements.append(statement) + connection._transaction = mock.Mock() + + with mock.patch.object( + connection._transaction, + "commit", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], ResultsChecksum()), + ) as run_mock: + + connection.commit() + + run_mock.assert_called_with(statement, retried=True) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 1663974cf3..04046ad7d5 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -500,11 +500,11 @@ def test_fetchone_retry_aborted(self): retry_mock.assert_called_with() def test_fetchone_retry_aborted_statements(self): + """Check that retried transaction executing the same statements.""" from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect - """Check that retried transaction executing the same statements.""" row = ["field1", "field2"] with mock.patch( "google.cloud.spanner_v1.instance.Instance.exists", From a4ffab5f23e1177c59359b6aa30c74e95a2d912a Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Thu, 29 Oct 2020 12:00:07 +0300 Subject: [PATCH 12/18] add retry inside retry, cleanup tests --- google/cloud/spanner_dbapi/connection.py | 22 +++++--- google/cloud/spanner_dbapi/cursor.py | 15 ++--- tests/unit/spanner_dbapi/test_connection.py | 61 +++++++++++---------- tests/unit/spanner_dbapi/test_cursor.py | 26 +++++---- 4 files changed, 68 insertions(+), 56 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 67d73c9b8a..5c877e554f 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -132,18 +132,18 @@ def retry_transaction(self): for res in res_iter: retried_checksum.consume_result(res) - _compare_checksums(statement["checksum"], retried_checksum) + _compare_checksums(statement.checksum, retried_checksum) # executing the failed statement else: # streaming up to the failed result - while len(retried_checksum) < len(statement["checksum"]): + while len(retried_checksum) < len(statement.checksum): try: res = next(iter(res_iter)) retried_checksum.consume_result(res) except StopIteration: break - _compare_checksums(statement["checksum"], retried_checksum) + _compare_checksums(statement.checksum, retried_checksum) def transaction_checkout(self): """Get a Cloud Spanner transaction. @@ -204,7 +204,13 @@ def commit(self): self._release_session() self._statements = [] except Aborted: - self.retry_transaction() + while True: + try: + self.retry_transaction() + break + except Aborted: + pass + self.commit() def rollback(self): @@ -255,11 +261,11 @@ def run_statement(self, statement, retried=False): return ( transaction.execute_sql( - statement["sql"], - statement["params"], - param_types=statement["param_types"], + statement.sql, + statement.params, + param_types=statement.param_types, ), - ResultsChecksum() if retried else statement["checksum"], + ResultsChecksum() if retried else statement.checksum, ) def __enter__(self): diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 1c7b99c912..f953634a8c 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -33,6 +33,7 @@ _UNSET_COUNT = -1 ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) +Statement = namedtuple("Statement", "sql, params, param_types, checksum") class Cursor(object): @@ -164,12 +165,12 @@ def execute(self, sql, args=None): if not self.connection.autocommit: sql, params = sql_pyformat_args_to_spanner(sql, args) - statement = { - "sql": sql, - "params": params, - "param_types": get_param_types(params), - "checksum": ResultsChecksum(), - } + statement = Statement( + sql, + params, + get_param_types(params), + ResultsChecksum(), + ) ( self._result_set, self._checksum, @@ -231,7 +232,7 @@ def fetchall(self): res = [] try: - for row in self.__iter__(): + for row in self: self._checksum.consume_result(row) res.append(row) except Aborted: diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index cc7def0287..4b388def9f 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -320,6 +320,7 @@ def test_sessions_pool(self): def test_run_statement(self): """Check that Connection remembers executed statements.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement sql = """SELECT 23 FROM table WHERE id = @a1""" params = {"a1": "value"} @@ -327,23 +328,22 @@ def test_run_statement(self): connection = self._make_connection() - statement = { - "sql": sql, - "params": params, - "param_types": param_types, - "checksum": ResultsChecksum(), - } - + statement = Statement( + sql, + params, + param_types, + ResultsChecksum(), + ) with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" ): connection.run_statement(statement) - self.assertEqual(connection._statements[0]["sql"], sql) - self.assertEqual(connection._statements[0]["params"], params) - self.assertEqual(connection._statements[0]["param_types"], param_types) + self.assertEqual(connection._statements[0].sql, sql) + self.assertEqual(connection._statements[0].params, params) + self.assertEqual(connection._statements[0].param_types, param_types) self.assertIsInstance( - connection._statements[0]["checksum"], ResultsChecksum + connection._statements[0].checksum, ResultsChecksum ) def test_clear_statements_on_commit(self): @@ -385,6 +385,7 @@ def test_clear_statements_on_rollback(self): def test_retry_transaction(self): """Check retrying an aborted transaction.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement row = ["field1", "field2"] connection = self._make_connection() @@ -393,12 +394,12 @@ def test_retry_transaction(self): checksum.consume_result(row) retried_checkum = ResultsChecksum() - statement = { - "sql": "SELECT 1", - "params": [], - "param_types": {}, - "checksum": checksum, - } + statement = Statement( + "SELECT 1", + [], + {}, + checksum, + ) connection._statements.append(statement) with mock.patch( @@ -421,6 +422,7 @@ def test_retry_transaction_checksum_mismatch(self): """ from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement row = ["field1", "field2"] retried_row = ["field3", "field4"] @@ -430,12 +432,12 @@ def test_retry_transaction_checksum_mismatch(self): checksum.consume_result(row) retried_checkum = ResultsChecksum() - statement = { - "sql": "SELECT 1", - "params": [], - "param_types": {}, - "checksum": checksum, - } + statement = Statement( + "SELECT 1", + [], + {}, + checksum, + ) connection._statements.append(statement) with mock.patch( @@ -450,6 +452,7 @@ def test_commit_retry_aborted_statements(self): from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement row = ["field1", "field2"] with mock.patch( @@ -466,12 +469,12 @@ def test_commit_retry_aborted_statements(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = { - "sql": "SELECT 1", - "params": [], - "param_types": {}, - "checksum": cursor._checksum, - } + statement = Statement( + "SELECT 1", + [], + {}, + cursor._checksum, + ) connection._statements.append(statement) connection._transaction = mock.Mock() diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 04046ad7d5..5f3ca0b63b 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -504,6 +504,7 @@ def test_fetchone_retry_aborted_statements(self): from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement row = ["field1", "field2"] with mock.patch( @@ -520,12 +521,12 @@ def test_fetchone_retry_aborted_statements(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = { - "sql": "SELECT 1", - "params": [], - "param_types": {}, - "checksum": cursor._checksum, - } + statement = Statement( + "SELECT 1", + [], + {}, + cursor._checksum, + ) connection._statements.append(statement) with mock.patch( @@ -546,6 +547,7 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self): from google.api_core.exceptions import Aborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement row = ["field1", "field2"] row2 = ["updated_field1", "field2"] @@ -564,12 +566,12 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = { - "sql": "SELECT 1", - "params": [], - "param_types": {}, - "checksum": cursor._checksum, - } + statement = Statement( + "SELECT 1", + [], + {}, + cursor._checksum, + ) connection._statements.append(statement) with mock.patch( From fc890c2d9e933d417963ac91fc8f439146fd9ff1 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Thu, 29 Oct 2020 12:16:29 +0300 Subject: [PATCH 13/18] lint fix --- google/cloud/spanner_dbapi/cursor.py | 5 +--- tests/unit/spanner_dbapi/test_connection.py | 28 +++------------------ tests/unit/spanner_dbapi/test_cursor.py | 14 ++--------- 3 files changed, 7 insertions(+), 40 deletions(-) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index f953634a8c..20d241f2d7 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -166,10 +166,7 @@ def execute(self, sql, args=None): sql, params = sql_pyformat_args_to_spanner(sql, args) statement = Statement( - sql, - params, - get_param_types(params), - ResultsChecksum(), + sql, params, get_param_types(params), ResultsChecksum(), ) ( self._result_set, diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 4b388def9f..df19b257e4 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -328,12 +328,7 @@ def test_run_statement(self): connection = self._make_connection() - statement = Statement( - sql, - params, - param_types, - ResultsChecksum(), - ) + statement = Statement(sql, params, param_types, ResultsChecksum(),) with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" ): @@ -394,12 +389,7 @@ def test_retry_transaction(self): checksum.consume_result(row) retried_checkum = ResultsChecksum() - statement = Statement( - "SELECT 1", - [], - {}, - checksum, - ) + statement = Statement("SELECT 1", [], {}, checksum,) connection._statements.append(statement) with mock.patch( @@ -432,12 +422,7 @@ def test_retry_transaction_checksum_mismatch(self): checksum.consume_result(row) retried_checkum = ResultsChecksum() - statement = Statement( - "SELECT 1", - [], - {}, - checksum, - ) + statement = Statement("SELECT 1", [], {}, checksum,) connection._statements.append(statement) with mock.patch( @@ -469,12 +454,7 @@ def test_commit_retry_aborted_statements(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = Statement( - "SELECT 1", - [], - {}, - cursor._checksum, - ) + statement = Statement("SELECT 1", [], {}, cursor._checksum,) connection._statements.append(statement) connection._transaction = mock.Mock() diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 5f3ca0b63b..6478a73e94 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -521,12 +521,7 @@ def test_fetchone_retry_aborted_statements(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = Statement( - "SELECT 1", - [], - {}, - cursor._checksum, - ) + statement = Statement("SELECT 1", [], {}, cursor._checksum,) connection._statements.append(statement) with mock.patch( @@ -566,12 +561,7 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self): cursor._checksum = ResultsChecksum() cursor._checksum.consume_result(row) - statement = Statement( - "SELECT 1", - [], - {}, - cursor._checksum, - ) + statement = Statement("SELECT 1", [], {}, cursor._checksum,) connection._statements.append(statement) with mock.patch( From d204836703fefdf17c36e9f8e2d9aae6f6ac305c Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 30 Oct 2020 12:35:09 +0300 Subject: [PATCH 14/18] use new transaction on retry --- google/cloud/spanner_dbapi/checksum.py | 6 +-- google/cloud/spanner_dbapi/connection.py | 13 ++++--- google/cloud/spanner_dbapi/cursor.py | 24 ++++++++++-- google/cloud/spanner_dbapi/exceptions.py | 10 +++++ tests/unit/spanner_dbapi/test_checksum.py | 8 ++-- tests/unit/spanner_dbapi/test_connection.py | 18 ++++++++- tests/unit/spanner_dbapi/test_cursor.py | 41 ++++++++++++++++++++- 7 files changed, 101 insertions(+), 19 deletions(-) diff --git a/google/cloud/spanner_dbapi/checksum.py b/google/cloud/spanner_dbapi/checksum.py index 22aaa11548..798518db36 100644 --- a/google/cloud/spanner_dbapi/checksum.py +++ b/google/cloud/spanner_dbapi/checksum.py @@ -9,7 +9,7 @@ import hashlib import pickle -from google.api_core.exceptions import Aborted +from google.cloud.spanner_dbapi.exceptions import AbortedRetried class ResultsChecksum: @@ -64,9 +64,9 @@ def _compare_checksums(original, retried): :type retried: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum` :param retried: results checksum of the retried transaction. - :raises: :exc:`google.api_core.exceptions.Aborted` in case if checksums are not equal. + :raises: :exc:`google.cloud.spanner_dbapi.exceptions.AbortedRetried` in case if checksums are not equal. """ if retried != original: - raise Aborted( + raise AbortedRetried( "The transaction was aborted and could not be retried due to a concurrent modification." ) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 5c877e554f..6ef3428ea8 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -115,14 +115,16 @@ def _release_session(self): def retry_transaction(self): """Retry the aborted transaction. - All the statements executed in the transaction - will be re-executed. Results checksums of the original - statements and the retried ones will be compared. + All the statements executed in the original transaction + will be re-executed in new one. Results checksums of the + original statements and the retried ones will be compared. - :raises: :class:`google.api_core.exceptions.Aborted` + :raises: :class:`google.cloud.spanner_dbapi.exceptions.AbortedRetried` If results checksum of the retried statement is not equal to the checksum of the original one. """ + self._transaction = None + for statement in self._statements: res_iter, retried_checksum = self.run_statement( statement, retried=True @@ -135,7 +137,8 @@ def retry_transaction(self): _compare_checksums(statement.checksum, retried_checksum) # executing the failed statement else: - # streaming up to the failed result + # streaming up to the failed result or + # to the end of the streaming iterator while len(retried_checksum) < len(statement.checksum): try: res = next(iter(res_iter)) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 20d241f2d7..5fa0083d30 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -218,7 +218,13 @@ def fetchone(self): except StopIteration: return except Aborted: - self.connection.retry_transaction() + while True: + try: + self.connection.retry_transaction() + break + except Aborted: + pass + return self.fetchone() def fetchall(self): @@ -233,7 +239,13 @@ def fetchall(self): self._checksum.consume_result(row) res.append(row) except Aborted: - self._connection.retry_transaction() + while True: + try: + self._connection.retry_transaction() + break + except Aborted: + pass + return self.fetchall() return res @@ -263,7 +275,13 @@ def fetchmany(self, size=None): except StopIteration: break except Aborted: - self._connection.retry_transaction() + while True: + try: + self._connection.retry_transaction() + break + except Aborted: + pass + return self.fetchmany(size) return items diff --git a/google/cloud/spanner_dbapi/exceptions.py b/google/cloud/spanner_dbapi/exceptions.py index b21be2c949..fbc43b0854 100644 --- a/google/cloud/spanner_dbapi/exceptions.py +++ b/google/cloud/spanner_dbapi/exceptions.py @@ -92,3 +92,13 @@ class NotSupportedError(DatabaseError): """ pass + + +class AbortedRetried(OperationalError): + """ + Error for case of no aborted transaction retry + is available, because of underlying data being + changed during a retry. + """ + + pass diff --git a/tests/unit/spanner_dbapi/test_checksum.py b/tests/unit/spanner_dbapi/test_checksum.py index 486ba0c036..5e34194d64 100644 --- a/tests/unit/spanner_dbapi/test_checksum.py +++ b/tests/unit/spanner_dbapi/test_checksum.py @@ -6,7 +6,7 @@ import unittest -from google.api_core.exceptions import Aborted +from google.cloud.spanner_dbapi.exceptions import AbortedRetried class Test_compare_checksums(unittest.TestCase): @@ -31,7 +31,7 @@ def test_less_results(self): retried = ResultsChecksum() - with self.assertRaises(Aborted): + with self.assertRaises(AbortedRetried): _compare_checksums(original, retried) def test_more_results(self): @@ -45,7 +45,7 @@ def test_more_results(self): retried.consume_result(5) retried.consume_result(2) - with self.assertRaises(Aborted): + with self.assertRaises(AbortedRetried): _compare_checksums(original, retried) def test_mismatch(self): @@ -58,5 +58,5 @@ def test_mismatch(self): retried = ResultsChecksum() retried.consume_result(2) - with self.assertRaises(Aborted): + with self.assertRaises(AbortedRetried): _compare_checksums(original, retried) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index df19b257e4..a6d1a8088c 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -410,7 +410,7 @@ def test_retry_transaction_checksum_mismatch(self): Check retrying an aborted transaction with results checksums mismatch. """ - from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.exceptions import AbortedRetried from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Statement @@ -429,7 +429,7 @@ def test_retry_transaction_checksum_mismatch(self): "google.cloud.spanner_dbapi.connection.Connection.run_statement", return_value=([retried_row], retried_checkum), ): - with self.assertRaises(Aborted): + with self.assertRaises(AbortedRetried): connection.retry_transaction() def test_commit_retry_aborted_statements(self): @@ -471,3 +471,17 @@ def test_commit_retry_aborted_statements(self): connection.commit() run_mock.assert_called_with(statement, retried=True) + + def test_retry_transaction_drop_transaction(self): + """ + Check that before retrying an aborted transaction + connection drops the original aborted transaction. + """ + connection = self._make_connection() + transaction_mock = mock.Mock() + connection._transaction = transaction_mock + + # as we didn't set any statements, the method + # will only drop the transaction object + connection.retry_transaction() + self.assertIsNone(connection._transaction) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 6478a73e94..315bbc9f23 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -540,6 +540,7 @@ def test_fetchone_retry_aborted_statements(self): def test_fetchone_retry_aborted_statements_checksums_mismatch(self): """Check transaction retrying with underlying data being changed.""" from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.exceptions import AbortedRetried from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect from google.cloud.spanner_dbapi.cursor import Statement @@ -566,14 +567,50 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self): with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__next__", - side_effect=Aborted("Aborted"), + side_effect=(Aborted("Aborted"), None), ): with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", return_value=([row2], ResultsChecksum()), ) as run_mock: - with self.assertRaises(Aborted): + with self.assertRaises(AbortedRetried): cursor.fetchone() run_mock.assert_called_with(statement, retried=True) + + def test_fetchone_retry_aborted_retry(self): + """ + Check that in case of a retried transaction failed, + the connection will retry it once again. + """ + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch.object( + connection, + "retry_transaction", + side_effect=(Aborted("Aborted"), None), + ) as retry_mock: + + cursor.fetchone() + + retry_mock.assert_has_calls((mock.call(), mock.call())) From 9ea2a01728edbda6c867eaf8d867a42bb0ff21ac Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 30 Oct 2020 12:42:55 +0300 Subject: [PATCH 15/18] fix imports --- tests/unit/spanner_dbapi/test_checksum.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/spanner_dbapi/test_checksum.py b/tests/unit/spanner_dbapi/test_checksum.py index 5e34194d64..38e6d7ea77 100644 --- a/tests/unit/spanner_dbapi/test_checksum.py +++ b/tests/unit/spanner_dbapi/test_checksum.py @@ -6,8 +6,6 @@ import unittest -from google.cloud.spanner_dbapi.exceptions import AbortedRetried - class Test_compare_checksums(unittest.TestCase): def test_equal(self): @@ -25,6 +23,7 @@ def test_equal(self): def test_less_results(self): from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.exceptions import AbortedRetried original = ResultsChecksum() original.consume_result(5) @@ -37,6 +36,7 @@ def test_less_results(self): def test_more_results(self): from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.exceptions import AbortedRetried original = ResultsChecksum() original.consume_result(5) @@ -51,6 +51,7 @@ def test_more_results(self): def test_mismatch(self): from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.exceptions import AbortedRetried original = ResultsChecksum() original.consume_result(5) From 7e70d8633e23da02613ab0d1d4f7aae425788b5b Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Fri, 30 Oct 2020 14:13:50 +0300 Subject: [PATCH 16/18] add retrying limit --- google/cloud/spanner_dbapi/connection.py | 23 +++++++++- google/cloud/spanner_dbapi/cursor.py | 24 ++-------- tests/unit/spanner_dbapi/test_connection.py | 50 +++++++++++++++++++++ tests/unit/spanner_dbapi/test_cursor.py | 36 --------------- 4 files changed, 75 insertions(+), 58 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 6ef3428ea8..581ea32fce 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -6,11 +6,13 @@ """DB-API Connection for the Google Cloud Spanner.""" +import time import warnings from google.api_core.exceptions import Aborted from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud import spanner_v1 as spanner +from google.cloud.spanner_v1.session import _get_retry_delay from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum @@ -21,6 +23,7 @@ AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode" +MAX_INTERNAL_RETRIES = 50 class Connection: @@ -123,8 +126,26 @@ def retry_transaction(self): If results checksum of the retried statement is not equal to the checksum of the original one. """ - self._transaction = None + attempt = 0 + while True: + self._transaction = None + attempt += 1 + if attempt > MAX_INTERNAL_RETRIES: + raise + try: + self._rerun_previous_statements() + break + except Aborted as exc: + delay = _get_retry_delay(exc.errors[0], attempt) + if delay: + time.sleep(delay) + + def _rerun_previous_statements(self): + """ + Helper to run all the remembered statements + from the last transaction. + """ for statement in self._statements: res_iter, retried_checksum = self.run_statement( statement, retried=True diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 5fa0083d30..20d241f2d7 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -218,13 +218,7 @@ def fetchone(self): except StopIteration: return except Aborted: - while True: - try: - self.connection.retry_transaction() - break - except Aborted: - pass - + self.connection.retry_transaction() return self.fetchone() def fetchall(self): @@ -239,13 +233,7 @@ def fetchall(self): self._checksum.consume_result(row) res.append(row) except Aborted: - while True: - try: - self._connection.retry_transaction() - break - except Aborted: - pass - + self._connection.retry_transaction() return self.fetchall() return res @@ -275,13 +263,7 @@ def fetchmany(self, size=None): except StopIteration: break except Aborted: - while True: - try: - self._connection.retry_transaction() - break - except Aborted: - pass - + self._connection.retry_transaction() return self.fetchmany(size) return items diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index a6d1a8088c..e2bac2a66e 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -485,3 +485,53 @@ def test_retry_transaction_drop_transaction(self): # will only drop the transaction object connection.retry_transaction() self.assertIsNone(connection._transaction) + + def test_retry_aborted_retry(self): + """ + Check that in case of a retried transaction failed, + the connection will retry it once again. + """ + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum,) + connection._statements.append(statement) + + metadata_mock = mock.Mock() + metadata_mock.trailing_metadata.return_value = {} + + with mock.patch.object( + connection, + "run_statement", + side_effect=( + Aborted("Aborted", errors=[metadata_mock]), + ([row], ResultsChecksum()), + ), + ) as retry_mock: + + connection.retry_transaction() + + retry_mock.assert_has_calls( + ( + mock.call(statement, retried=True), + mock.call(statement, retried=True), + ) + ) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 315bbc9f23..edbecf90f3 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -578,39 +578,3 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self): cursor.fetchone() run_mock.assert_called_with(statement, retried=True) - - def test_fetchone_retry_aborted_retry(self): - """ - Check that in case of a retried transaction failed, - the connection will retry it once again. - """ - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - - with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor.__next__", - side_effect=(Aborted("Aborted"), None), - ): - with mock.patch.object( - connection, - "retry_transaction", - side_effect=(Aborted("Aborted"), None), - ) as retry_mock: - - cursor.fetchone() - - retry_mock.assert_has_calls((mock.call(), mock.call())) From 450b91bd434aa2a1f7152079655e6ad506d7e5b1 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Thu, 5 Nov 2020 11:16:35 +0300 Subject: [PATCH 17/18] rename the exception, fix remembering retried statements --- google/cloud/spanner_dbapi/checksum.py | 6 +- google/cloud/spanner_dbapi/connection.py | 5 +- google/cloud/spanner_dbapi/exceptions.py | 2 +- tests/unit/spanner_dbapi/test_checksum.py | 12 +- tests/unit/spanner_dbapi/test_connection.py | 238 ++++++++++++++++++++ tests/unit/spanner_dbapi/test_cursor.py | 4 +- 6 files changed, 253 insertions(+), 14 deletions(-) diff --git a/google/cloud/spanner_dbapi/checksum.py b/google/cloud/spanner_dbapi/checksum.py index 798518db36..3cae7cfb62 100644 --- a/google/cloud/spanner_dbapi/checksum.py +++ b/google/cloud/spanner_dbapi/checksum.py @@ -9,7 +9,7 @@ import hashlib import pickle -from google.cloud.spanner_dbapi.exceptions import AbortedRetried +from google.cloud.spanner_dbapi.exceptions import RetryAborted class ResultsChecksum: @@ -64,9 +64,9 @@ def _compare_checksums(original, retried): :type retried: :class:`~google.cloud.spanner_dbapi.checksum.ResultsChecksum` :param retried: results checksum of the retried transaction. - :raises: :exc:`google.cloud.spanner_dbapi.exceptions.AbortedRetried` in case if checksums are not equal. + :raises: :exc:`google.cloud.spanner_dbapi.exceptions.RetryAborted` in case if checksums are not equal. """ if retried != original: - raise AbortedRetried( + raise RetryAborted( "The transaction was aborted and could not be retried due to a concurrent modification." ) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index a21a63d2cc..6bb1574ce3 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -126,7 +126,7 @@ def retry_transaction(self): will be re-executed in new one. Results checksums of the original statements and the retried ones will be compared. - :raises: :class:`google.cloud.spanner_dbapi.exceptions.AbortedRetried` + :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted` If results checksum of the retried statement is not equal to the checksum of the original one. """ @@ -288,7 +288,8 @@ def run_statement(self, statement, retried=False): checksum of this statement results. """ transaction = self.transaction_checkout() - self._statements.append(statement) + if not retried: + self._statements.append(statement) return ( transaction.execute_sql( diff --git a/google/cloud/spanner_dbapi/exceptions.py b/google/cloud/spanner_dbapi/exceptions.py index fbc43b0854..2b021f6b98 100644 --- a/google/cloud/spanner_dbapi/exceptions.py +++ b/google/cloud/spanner_dbapi/exceptions.py @@ -94,7 +94,7 @@ class NotSupportedError(DatabaseError): pass -class AbortedRetried(OperationalError): +class RetryAborted(OperationalError): """ Error for case of no aborted transaction retry is available, because of underlying data being diff --git a/tests/unit/spanner_dbapi/test_checksum.py b/tests/unit/spanner_dbapi/test_checksum.py index 38e6d7ea77..3e7780bd6e 100644 --- a/tests/unit/spanner_dbapi/test_checksum.py +++ b/tests/unit/spanner_dbapi/test_checksum.py @@ -23,20 +23,20 @@ def test_equal(self): def test_less_results(self): from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.exceptions import AbortedRetried + from google.cloud.spanner_dbapi.exceptions import RetryAborted original = ResultsChecksum() original.consume_result(5) retried = ResultsChecksum() - with self.assertRaises(AbortedRetried): + with self.assertRaises(RetryAborted): _compare_checksums(original, retried) def test_more_results(self): from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.exceptions import AbortedRetried + from google.cloud.spanner_dbapi.exceptions import RetryAborted original = ResultsChecksum() original.consume_result(5) @@ -45,13 +45,13 @@ def test_more_results(self): retried.consume_result(5) retried.consume_result(2) - with self.assertRaises(AbortedRetried): + with self.assertRaises(RetryAborted): _compare_checksums(original, retried) def test_mismatch(self): from google.cloud.spanner_dbapi.checksum import _compare_checksums from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.exceptions import AbortedRetried + from google.cloud.spanner_dbapi.exceptions import RetryAborted original = ResultsChecksum() original.consume_result(5) @@ -59,5 +59,5 @@ def test_mismatch(self): retried = ResultsChecksum() retried.consume_result(2) - with self.assertRaises(AbortedRetried): + with self.assertRaises(RetryAborted): _compare_checksums(original, retried) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 1e892f980f..79415bca55 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -335,3 +335,241 @@ def test_global_pool(self): ) as pool_clear_mock: connection.close() assert not pool_clear_mock.called + + def test_run_statement_remember_statements(self): + """Check that Connection remembers executed statements.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + sql = """SELECT 23 FROM table WHERE id = @a1""" + params = {"a1": "value"} + param_types = {"a1": str} + + connection = self._make_connection() + + statement = Statement(sql, params, param_types, ResultsChecksum(),) + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" + ): + connection.run_statement(statement) + + self.assertEqual(connection._statements[0].sql, sql) + self.assertEqual(connection._statements[0].params, params) + self.assertEqual(connection._statements[0].param_types, param_types) + self.assertIsInstance( + connection._statements[0].checksum, ResultsChecksum + ) + + def test_run_statement_dont_remember_retried_statements(self): + """Check that Connection doesn't remember re-executed statements.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + sql = """SELECT 23 FROM table WHERE id = @a1""" + params = {"a1": "value"} + param_types = {"a1": str} + + connection = self._make_connection() + + statement = Statement(sql, params, param_types, ResultsChecksum(),) + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.transaction_checkout" + ): + connection.run_statement(statement, retried=True) + + self.assertEqual(len(connection._statements), 0) + + def test_clear_statements_on_commit(self): + """ + Check that all the saved statements are + cleared, when the transaction is commited. + """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch( + "google.cloud.spanner_v1.transaction.Transaction.commit" + ): + connection.commit() + + self.assertEqual(len(connection._statements), 0) + + def test_clear_statements_on_rollback(self): + """ + Check that all the saved statements are + cleared, when the transaction is roll backed. + """ + connection = self._make_connection() + connection._transaction = mock.Mock() + connection._statements = [{}, {}] + + self.assertEqual(len(connection._statements), 2) + + with mock.patch( + "google.cloud.spanner_v1.transaction.Transaction.commit" + ): + connection.rollback() + + self.assertEqual(len(connection._statements), 0) + + def test_retry_transaction(self): + """Check retrying an aborted transaction.""" + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = Statement("SELECT 1", [], {}, checksum,) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], retried_checkum), + ) as run_mock: + with mock.patch( + "google.cloud.spanner_dbapi.connection._compare_checksums" + ) as compare_mock: + connection.retry_transaction() + + compare_mock.assert_called_with(checksum, retried_checkum) + + run_mock.assert_called_with(statement, retried=True) + + def test_retry_transaction_checksum_mismatch(self): + """ + Check retrying an aborted transaction + with results checksums mismatch. + """ + from google.cloud.spanner_dbapi.exceptions import RetryAborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + retried_row = ["field3", "field4"] + connection = self._make_connection() + + checksum = ResultsChecksum() + checksum.consume_result(row) + retried_checkum = ResultsChecksum() + + statement = Statement("SELECT 1", [], {}, checksum,) + connection._statements.append(statement) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([retried_row], retried_checkum), + ): + with self.assertRaises(RetryAborted): + connection.retry_transaction() + + def test_commit_retry_aborted_statements(self): + """Check that retried transaction executing the same statements.""" + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum,) + connection._statements.append(statement) + connection._transaction = mock.Mock() + + with mock.patch.object( + connection._transaction, + "commit", + side_effect=(Aborted("Aborted"), None), + ): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.run_statement", + return_value=([row], ResultsChecksum()), + ) as run_mock: + + connection.commit() + + run_mock.assert_called_with(statement, retried=True) + + def test_retry_transaction_drop_transaction(self): + """ + Check that before retrying an aborted transaction + connection drops the original aborted transaction. + """ + connection = self._make_connection() + transaction_mock = mock.Mock() + connection._transaction = transaction_mock + + # as we didn't set any statements, the method + # will only drop the transaction object + connection.retry_transaction() + self.assertIsNone(connection._transaction) + + def test_retry_aborted_retry(self): + """ + Check that in case of a retried transaction failed, + the connection will retry it once again. + """ + from google.api_core.exceptions import Aborted + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.connection import connect + from google.cloud.spanner_dbapi.cursor import Statement + + row = ["field1", "field2"] + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor._checksum = ResultsChecksum() + cursor._checksum.consume_result(row) + + statement = Statement("SELECT 1", [], {}, cursor._checksum,) + connection._statements.append(statement) + + metadata_mock = mock.Mock() + metadata_mock.trailing_metadata.return_value = {} + + with mock.patch.object( + connection, + "run_statement", + side_effect=( + Aborted("Aborted", errors=[metadata_mock]), + ([row], ResultsChecksum()), + ), + ) as retry_mock: + + connection.retry_transaction() + + retry_mock.assert_has_calls( + ( + mock.call(statement, retried=True), + mock.call(statement, retried=True), + ) + ) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index edbecf90f3..f7dd712ddd 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -540,7 +540,7 @@ def test_fetchone_retry_aborted_statements(self): def test_fetchone_retry_aborted_statements_checksums_mismatch(self): """Check transaction retrying with underlying data being changed.""" from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.exceptions import AbortedRetried + from google.cloud.spanner_dbapi.exceptions import RetryAborted from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect from google.cloud.spanner_dbapi.cursor import Statement @@ -574,7 +574,7 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self): return_value=([row2], ResultsChecksum()), ) as run_mock: - with self.assertRaises(AbortedRetried): + with self.assertRaises(RetryAborted): cursor.fetchone() run_mock.assert_called_with(statement, retried=True) From 578eaa27b108f7e288384f8e36daeaec6c8bf891 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Thu, 5 Nov 2020 11:32:41 +0300 Subject: [PATCH 18/18] erase excess while cycle - retry_transaction() is already using its own --- google/cloud/spanner_dbapi/connection.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 6bb1574ce3..5c1be8f724 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -235,13 +235,7 @@ def commit(self): self._release_session() self._statements = [] except Aborted: - while True: - try: - self.retry_transaction() - break - except Aborted: - pass - + self.retry_transaction() self.commit() def rollback(self):