Skip to content

Commit

Permalink
feat: add support for retrying aborted partitioned DML statements (#66)
Browse files Browse the repository at this point in the history
* feat: add support for retrying aborted partitioned dml statements

* run blacken

* use retry settings from config

* fix imports from rebase

Co-authored-by: larkee <larkee@users.noreply.github.com>
  • Loading branch information
larkee and larkee committed May 4, 2020
1 parent df4be7f commit 8a3d700
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 22 deletions.
59 changes: 42 additions & 17 deletions google/cloud/spanner_v1/database.py
Expand Up @@ -21,8 +21,10 @@
import threading

import google.auth.credentials
from google.api_core.retry import if_exception_type
from google.protobuf.struct_pb2 import Struct
from google.cloud.exceptions import NotFound
from google.api_core.exceptions import Aborted
import six

# pylint: disable=ungrouped-imports
Expand Down Expand Up @@ -394,29 +396,36 @@ def execute_partitioned_dml(

metadata = _metadata_with_prefix(self.name)

with SessionCheckout(self._pool) as session:
def execute_pdml():
with SessionCheckout(self._pool) as session:

txn = api.begin_transaction(
session.name, txn_options, metadata=metadata
)

txn = api.begin_transaction(session.name, txn_options, metadata=metadata)
txn_selector = TransactionSelector(id=txn.id)

restart = functools.partial(
api.execute_streaming_sql,
session.name,
dml,
transaction=txn_selector,
params=params_pb,
param_types=param_types,
query_options=query_options,
metadata=metadata,
)

txn_selector = TransactionSelector(id=txn.id)
iterator = _restart_on_unavailable(restart)

restart = functools.partial(
api.execute_streaming_sql,
session.name,
dml,
transaction=txn_selector,
params=params_pb,
param_types=param_types,
query_options=query_options,
metadata=metadata,
)
result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials

iterator = _restart_on_unavailable(restart)
return result_set.stats.row_count_lower_bound

result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials
retry_config = api._method_configs["ExecuteStreamingSql"].retry

return result_set.stats.row_count_lower_bound
return _retry_on_aborted(execute_pdml, retry_config)()

def session(self, labels=None):
"""Factory to create a session for this database.
Expand Down Expand Up @@ -976,3 +985,19 @@ def __init__(self, source_type, backup_info):
@classmethod
def from_pb(cls, pb):
return cls(pb.source_type, pb.backup_info)


def _retry_on_aborted(func, retry_config):
"""Helper for :meth:`Database.execute_partitioned_dml`.
Wrap function in a Retry that will retry on Aborted exceptions
with the retry config specified.
:type func: callable
:param func: the function to be retried on Aborted exceptions
:type retry_config: Retry
:param retry_config: retry object with the settings to be used
"""
retry = retry_config.with_predicate(if_exception_type(Aborted))
return retry(func)
46 changes: 41 additions & 5 deletions tests/unit/test_database.py
Expand Up @@ -53,6 +53,7 @@ class _BaseTest(unittest.TestCase):
SESSION_ID = "session_id"
SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID
TRANSACTION_ID = b"transaction_id"
RETRY_TRANSACTION_ID = b"transaction_id_retry"
BACKUP_ID = "backup_id"
BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID

Expand Down Expand Up @@ -735,8 +736,10 @@ def test_drop_success(self):
)

def _execute_partitioned_dml_helper(
self, dml, params=None, param_types=None, query_options=None
self, dml, params=None, param_types=None, query_options=None, retried=False
):
from google.api_core.exceptions import Aborted
from google.api_core.retry import Retry
from google.protobuf.struct_pb2 import Struct
from google.cloud.spanner_v1.proto.result_set_pb2 import (
PartialResultSet,
Expand All @@ -752,6 +755,10 @@ def _execute_partitioned_dml_helper(
_merge_query_options,
)

import collections

MethodConfig = collections.namedtuple("MethodConfig", ["retry"])

transaction_pb = TransactionPB(id=self.TRANSACTION_ID)

stats_pb = ResultSetStats(row_count_lower_bound=2)
Expand All @@ -765,8 +772,14 @@ def _execute_partitioned_dml_helper(
pool.put(session)
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
api = database._spanner_api = self._make_spanner_api()
api.begin_transaction.return_value = transaction_pb
api.execute_streaming_sql.return_value = iterator
api._method_configs = {"ExecuteStreamingSql": MethodConfig(retry=Retry())}
if retried:
retry_transaction_pb = TransactionPB(id=self.RETRY_TRANSACTION_ID)
api.begin_transaction.side_effect = [transaction_pb, retry_transaction_pb]
api.execute_streaming_sql.side_effect = [Aborted("test"), iterator]
else:
api.begin_transaction.return_value = transaction_pb
api.execute_streaming_sql.return_value = iterator

row_count = database.execute_partitioned_dml(
dml, params, param_types, query_options
Expand All @@ -778,11 +791,15 @@ def _execute_partitioned_dml_helper(
partitioned_dml=TransactionOptions.PartitionedDml()
)

api.begin_transaction.assert_called_once_with(
api.begin_transaction.assert_called_with(
session.name,
txn_options,
metadata=[("google-cloud-resource-prefix", database.name)],
)
if retried:
self.assertEqual(api.begin_transaction.call_count, 2)
else:
self.assertEqual(api.begin_transaction.call_count, 1)

if params:
expected_params = Struct(
Expand All @@ -798,7 +815,7 @@ def _execute_partitioned_dml_helper(
expected_query_options, query_options
)

api.execute_streaming_sql.assert_called_once_with(
api.execute_streaming_sql.assert_any_call(
self.SESSION_NAME,
dml,
transaction=expected_transaction,
Expand All @@ -807,6 +824,22 @@ def _execute_partitioned_dml_helper(
query_options=expected_query_options,
metadata=[("google-cloud-resource-prefix", database.name)],
)
if retried:
expected_retry_transaction = TransactionSelector(
id=self.RETRY_TRANSACTION_ID
)
api.execute_streaming_sql.assert_called_with(
self.SESSION_NAME,
dml,
transaction=expected_retry_transaction,
params=expected_params,
param_types=param_types,
query_options=expected_query_options,
metadata=[("google-cloud-resource-prefix", database.name)],
)
self.assertEqual(api.execute_streaming_sql.call_count, 2)
else:
self.assertEqual(api.execute_streaming_sql.call_count, 1)

def test_execute_partitioned_dml_wo_params(self):
self._execute_partitioned_dml_helper(dml=DML_WO_PARAM)
Expand All @@ -828,6 +861,9 @@ def test_execute_partitioned_dml_w_query_options(self):
query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3"),
)

def test_execute_partitioned_dml_wo_params_retry_aborted(self):
self._execute_partitioned_dml_helper(dml=DML_WO_PARAM, retried=True)

def test_session_factory_defaults(self):
from google.cloud.spanner_v1.session import Session

Expand Down

0 comments on commit 8a3d700

Please sign in to comment.