Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for retrying aborted partitioned DML statements #66

Merged
merged 4 commits into from May 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
59 changes: 42 additions & 17 deletions google/cloud/spanner_v1/database.py
Expand Up @@ -21,8 +21,10 @@
import threading

import google.auth.credentials
from google.api_core.retry import if_exception_type
from google.protobuf.struct_pb2 import Struct
from google.cloud.exceptions import NotFound
from google.api_core.exceptions import Aborted
import six

# pylint: disable=ungrouped-imports
Expand Down Expand Up @@ -394,29 +396,36 @@ def execute_partitioned_dml(

metadata = _metadata_with_prefix(self.name)

with SessionCheckout(self._pool) as session:
def execute_pdml():
with SessionCheckout(self._pool) as session:

txn = api.begin_transaction(
session.name, txn_options, metadata=metadata
)

txn = api.begin_transaction(session.name, txn_options, metadata=metadata)
txn_selector = TransactionSelector(id=txn.id)

restart = functools.partial(
api.execute_streaming_sql,
session.name,
dml,
transaction=txn_selector,
params=params_pb,
param_types=param_types,
query_options=query_options,
metadata=metadata,
)

txn_selector = TransactionSelector(id=txn.id)
iterator = _restart_on_unavailable(restart)

restart = functools.partial(
api.execute_streaming_sql,
session.name,
dml,
transaction=txn_selector,
params=params_pb,
param_types=param_types,
query_options=query_options,
metadata=metadata,
)
result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials

iterator = _restart_on_unavailable(restart)
return result_set.stats.row_count_lower_bound

result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials
retry_config = api._method_configs["ExecuteStreamingSql"].retry

return result_set.stats.row_count_lower_bound
return _retry_on_aborted(execute_pdml, retry_config)()

def session(self, labels=None):
"""Factory to create a session for this database.
Expand Down Expand Up @@ -976,3 +985,19 @@ def __init__(self, source_type, backup_info):
@classmethod
def from_pb(cls, pb):
    """Alternate constructor: build an instance from ``pb``.

    :param pb: object exposing ``source_type`` and ``backup_info``
               attributes, which are forwarded positionally to ``cls``.

    :returns: the result of ``cls(pb.source_type, pb.backup_info)``
    """
    source_type = pb.source_type
    backup_info = pb.backup_info
    return cls(source_type, backup_info)


def _retry_on_aborted(func, retry_config):
    """Helper for :meth:`Database.execute_partitioned_dml`.

    Derive a retry object from ``retry_config`` whose predicate matches
    ``Aborted`` exceptions, and apply it to ``func``.

    :type func: callable
    :param func: the function to be retried on Aborted exceptions

    :type retry_config: Retry
    :param retry_config: retry object with the settings to be used;
                         ``with_predicate`` is called on it to obtain the
                         retry actually applied.

    :returns: the value produced by calling the derived retry object
              with ``func``.
    """
    derive_retry = retry_config.with_predicate
    aborted_retry = derive_retry(if_exception_type(Aborted))
    return aborted_retry(func)
46 changes: 41 additions & 5 deletions tests/unit/test_database.py
Expand Up @@ -53,6 +53,7 @@ class _BaseTest(unittest.TestCase):
SESSION_ID = "session_id"
SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID
TRANSACTION_ID = b"transaction_id"
RETRY_TRANSACTION_ID = b"transaction_id_retry"
BACKUP_ID = "backup_id"
BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID

Expand Down Expand Up @@ -735,8 +736,10 @@ def test_drop_success(self):
)

def _execute_partitioned_dml_helper(
self, dml, params=None, param_types=None, query_options=None
self, dml, params=None, param_types=None, query_options=None, retried=False
):
from google.api_core.exceptions import Aborted
from google.api_core.retry import Retry
from google.protobuf.struct_pb2 import Struct
from google.cloud.spanner_v1.proto.result_set_pb2 import (
PartialResultSet,
Expand All @@ -752,6 +755,10 @@ def _execute_partitioned_dml_helper(
_merge_query_options,
)

import collections

MethodConfig = collections.namedtuple("MethodConfig", ["retry"])

transaction_pb = TransactionPB(id=self.TRANSACTION_ID)

stats_pb = ResultSetStats(row_count_lower_bound=2)
Expand All @@ -765,8 +772,14 @@ def _execute_partitioned_dml_helper(
pool.put(session)
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
api = database._spanner_api = self._make_spanner_api()
api.begin_transaction.return_value = transaction_pb
api.execute_streaming_sql.return_value = iterator
api._method_configs = {"ExecuteStreamingSql": MethodConfig(retry=Retry())}
if retried:
retry_transaction_pb = TransactionPB(id=self.RETRY_TRANSACTION_ID)
api.begin_transaction.side_effect = [transaction_pb, retry_transaction_pb]
api.execute_streaming_sql.side_effect = [Aborted("test"), iterator]
else:
api.begin_transaction.return_value = transaction_pb
api.execute_streaming_sql.return_value = iterator

row_count = database.execute_partitioned_dml(
dml, params, param_types, query_options
Expand All @@ -778,11 +791,15 @@ def _execute_partitioned_dml_helper(
partitioned_dml=TransactionOptions.PartitionedDml()
)

api.begin_transaction.assert_called_once_with(
api.begin_transaction.assert_called_with(
session.name,
txn_options,
metadata=[("google-cloud-resource-prefix", database.name)],
)
if retried:
self.assertEqual(api.begin_transaction.call_count, 2)
else:
self.assertEqual(api.begin_transaction.call_count, 1)

if params:
expected_params = Struct(
Expand All @@ -798,7 +815,7 @@ def _execute_partitioned_dml_helper(
expected_query_options, query_options
)

api.execute_streaming_sql.assert_called_once_with(
api.execute_streaming_sql.assert_any_call(
self.SESSION_NAME,
dml,
transaction=expected_transaction,
Expand All @@ -807,6 +824,22 @@ def _execute_partitioned_dml_helper(
query_options=expected_query_options,
metadata=[("google-cloud-resource-prefix", database.name)],
)
if retried:
expected_retry_transaction = TransactionSelector(
id=self.RETRY_TRANSACTION_ID
)
api.execute_streaming_sql.assert_called_with(
self.SESSION_NAME,
dml,
transaction=expected_retry_transaction,
params=expected_params,
param_types=param_types,
query_options=expected_query_options,
metadata=[("google-cloud-resource-prefix", database.name)],
)
self.assertEqual(api.execute_streaming_sql.call_count, 2)
else:
self.assertEqual(api.execute_streaming_sql.call_count, 1)

def test_execute_partitioned_dml_wo_params(self):
self._execute_partitioned_dml_helper(dml=DML_WO_PARAM)
Expand All @@ -828,6 +861,9 @@ def test_execute_partitioned_dml_w_query_options(self):
query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3"),
)

def test_execute_partitioned_dml_wo_params_retry_aborted(self):
    # Run the shared helper for the parameterless DML statement with its
    # ``retried`` flag set, so the helper takes its retry branch.
    run_helper = self._execute_partitioned_dml_helper
    run_helper(dml=DML_WO_PARAM, retried=True)

def test_session_factory_defaults(self):
from google.cloud.spanner_v1.session import Session

Expand Down