From 8a3d700134a6380c033a879cff0616a648df709b Mon Sep 17 00:00:00 2001 From: larkee <31196561+larkee@users.noreply.github.com> Date: Tue, 5 May 2020 10:57:52 +1200 Subject: [PATCH] feat: add support for retrying aborted partitioned DML statements (#66) * feat: add support for retrying aborted partitioned dml statements * run blacken * use retry settings from config * fix imports from rebase Co-authored-by: larkee --- google/cloud/spanner_v1/database.py | 59 ++++++++++++++++++++--------- tests/unit/test_database.py | 46 +++++++++++++++++++--- 2 files changed, 83 insertions(+), 22 deletions(-) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index a3aa3390c4..e7f6de3724 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -21,8 +21,10 @@ import threading import google.auth.credentials +from google.api_core.retry import if_exception_type from google.protobuf.struct_pb2 import Struct from google.cloud.exceptions import NotFound +from google.api_core.exceptions import Aborted import six # pylint: disable=ungrouped-imports @@ -394,29 +396,36 @@ def execute_partitioned_dml( metadata = _metadata_with_prefix(self.name) - with SessionCheckout(self._pool) as session: + def execute_pdml(): + with SessionCheckout(self._pool) as session: + + txn = api.begin_transaction( + session.name, txn_options, metadata=metadata + ) - txn = api.begin_transaction(session.name, txn_options, metadata=metadata) + txn_selector = TransactionSelector(id=txn.id) + + restart = functools.partial( + api.execute_streaming_sql, + session.name, + dml, + transaction=txn_selector, + params=params_pb, + param_types=param_types, + query_options=query_options, + metadata=metadata, + ) - txn_selector = TransactionSelector(id=txn.id) + iterator = _restart_on_unavailable(restart) - restart = functools.partial( - api.execute_streaming_sql, - session.name, - dml, - transaction=txn_selector, - params=params_pb, - param_types=param_types, - query_options=query_options, - metadata=metadata, - ) + result_set = StreamedResultSet(iterator) + list(result_set) # consume all partials - iterator = _restart_on_unavailable(restart) + return result_set.stats.row_count_lower_bound - result_set = StreamedResultSet(iterator) - list(result_set) # consume all partials + retry_config = api._method_configs["ExecuteStreamingSql"].retry - return result_set.stats.row_count_lower_bound + return _retry_on_aborted(execute_pdml, retry_config)() def session(self, labels=None): """Factory to create a session for this database. @@ -976,3 +985,19 @@ def __init__(self, source_type, backup_info): @classmethod def from_pb(cls, pb): return cls(pb.source_type, pb.backup_info) + + +def _retry_on_aborted(func, retry_config): + """Helper for :meth:`Database.execute_partitioned_dml`. + + Wrap function in a Retry that will retry on Aborted exceptions + with the retry config specified. + + :type func: callable + :param func: the function to be retried on Aborted exceptions + + :type retry_config: Retry + :param retry_config: retry object with the settings to be used + """ + retry = retry_config.with_predicate(if_exception_type(Aborted)) + return retry(func) diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 5b71b08325..d8a581f87b 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -53,6 +53,7 @@ class _BaseTest(unittest.TestCase): SESSION_ID = "session_id" SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID TRANSACTION_ID = b"transaction_id" + RETRY_TRANSACTION_ID = b"transaction_id_retry" BACKUP_ID = "backup_id" BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID @@ -735,8 +736,10 @@ def test_drop_success(self): ) def _execute_partitioned_dml_helper( - self, dml, params=None, param_types=None, query_options=None + self, dml, params=None, param_types=None, query_options=None, retried=False ): + from google.api_core.exceptions import Aborted + from google.api_core.retry import Retry from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1.proto.result_set_pb2 import ( PartialResultSet, @@ -752,6 +755,10 @@ def _execute_partitioned_dml_helper( _merge_query_options, ) + import collections + + MethodConfig = collections.namedtuple("MethodConfig", ["retry"]) + transaction_pb = TransactionPB(id=self.TRANSACTION_ID) stats_pb = ResultSetStats(row_count_lower_bound=2) @@ -765,8 +772,14 @@ def _execute_partitioned_dml_helper( pool.put(session) database = self._make_one(self.DATABASE_ID, instance, pool=pool) api = database._spanner_api = self._make_spanner_api() - api.begin_transaction.return_value = transaction_pb - api.execute_streaming_sql.return_value = iterator + api._method_configs = {"ExecuteStreamingSql": MethodConfig(retry=Retry())} + if retried: + retry_transaction_pb = TransactionPB(id=self.RETRY_TRANSACTION_ID) + api.begin_transaction.side_effect = [transaction_pb, retry_transaction_pb] + api.execute_streaming_sql.side_effect = [Aborted("test"), iterator] + else: + api.begin_transaction.return_value = transaction_pb + api.execute_streaming_sql.return_value = iterator row_count = database.execute_partitioned_dml( dml, params, param_types, query_options @@ -778,11 +791,15 @@ def _execute_partitioned_dml_helper( partitioned_dml=TransactionOptions.PartitionedDml() ) - api.begin_transaction.assert_called_once_with( + api.begin_transaction.assert_called_with( session.name, txn_options, metadata=[("google-cloud-resource-prefix", database.name)], ) + if retried: + self.assertEqual(api.begin_transaction.call_count, 2) + else: + self.assertEqual(api.begin_transaction.call_count, 1) if params: expected_params = Struct( @@ -798,7 +815,7 @@ def _execute_partitioned_dml_helper( expected_query_options, query_options ) - api.execute_streaming_sql.assert_called_once_with( + api.execute_streaming_sql.assert_any_call( self.SESSION_NAME, dml, transaction=expected_transaction, @@ -807,6 +824,22 @@ def _execute_partitioned_dml_helper( query_options=expected_query_options, metadata=[("google-cloud-resource-prefix", database.name)], ) + if retried: + expected_retry_transaction = TransactionSelector( + id=self.RETRY_TRANSACTION_ID + ) + api.execute_streaming_sql.assert_called_with( + self.SESSION_NAME, + dml, + transaction=expected_retry_transaction, + params=expected_params, + param_types=param_types, + query_options=expected_query_options, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + self.assertEqual(api.execute_streaming_sql.call_count, 2) + else: + self.assertEqual(api.execute_streaming_sql.call_count, 1) def test_execute_partitioned_dml_wo_params(self): self._execute_partitioned_dml_helper(dml=DML_WO_PARAM) @@ -828,6 +861,9 @@ def test_execute_partitioned_dml_w_query_options(self): query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3"), ) + def test_execute_partitioned_dml_wo_params_retry_aborted(self): + self._execute_partitioned_dml_helper(dml=DML_WO_PARAM, retried=True) + def test_session_factory_defaults(self): from google.cloud.spanner_v1.session import Session