Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for custom timeout and retry parameters in execute_update method in transactions #251

Merged
merged 7 commits into from Mar 11, 2021
20 changes: 18 additions & 2 deletions google/cloud/spanner_v1/transaction.py
Expand Up @@ -29,6 +29,7 @@
from google.cloud.spanner_v1.snapshot import _SnapshotBase
from google.cloud.spanner_v1.batch import _BatchBase
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.api_core import gapic_v1


class Transaction(_SnapshotBase, _BatchBase):
Expand Down Expand Up @@ -185,7 +186,14 @@ def _make_params_pb(params, param_types):
return {}

def execute_update(
self, dml, params=None, param_types=None, query_mode=None, query_options=None
self,
dml,
params=None,
param_types=None,
query_mode=None,
query_options=None,
retry=gapic_v1.method.DEFAULT,
vi3k6i5 marked this conversation as resolved.
Show resolved Hide resolved
timeout=None,
vi3k6i5 marked this conversation as resolved.
Show resolved Hide resolved
):
"""Perform an ``ExecuteSql`` API request with DML.

Expand All @@ -212,6 +220,12 @@ def execute_update(
or :class:`dict`
:param query_options: (Optional) Options that are provided for query plan stability.

:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) Designation of what errors, if any, should be retried.
vi3k6i5 marked this conversation as resolved.
Show resolved Hide resolved

:type timeout: float
:param timeout: (Optional) The timeout for this request.

:rtype: int
:returns: Count of rows affected by the DML statement.
"""
Expand Down Expand Up @@ -245,7 +259,9 @@ def execute_update(
with trace_call(
"CloudSpanner.ReadWriteTransaction", self._session, trace_attributes
):
response = api.execute_sql(request=request, metadata=metadata)
response = api.execute_sql(
request=request, metadata=metadata, retry=retry, timeout=timeout
)
return response.stats.row_count_exact

def batch_update(self, statements):
Expand Down
18 changes: 17 additions & 1 deletion tests/unit/test_transaction.py
Expand Up @@ -17,6 +17,7 @@
from tests._helpers import OpenTelemetryBase, StatusCanonicalCode
from google.cloud.spanner_v1 import Type
from google.cloud.spanner_v1 import TypeCode
from google.api_core import gapic_v1

TABLE_NAME = "citizens"
COLUMNS = ["email", "first_name", "last_name", "age"]
Expand Down Expand Up @@ -410,7 +411,9 @@ def test_execute_update_w_params_wo_param_types(self):
with self.assertRaises(ValueError):
transaction.execute_update(DML_QUERY_WITH_PARAM, PARAMS)

def _execute_update_helper(self, count=0, query_options=None):
def _execute_update_helper(
self, count=0, query_options=None, retry=gapic_v1.method.DEFAULT, timeout=None
):
from google.protobuf.struct_pb2 import Struct
from google.cloud.spanner_v1 import (
ResultSet,
Expand Down Expand Up @@ -439,6 +442,8 @@ def _execute_update_helper(self, count=0, query_options=None):
PARAM_TYPES,
query_mode=MODE,
query_options=query_options,
retry=retry,
timeout=timeout,
)

self.assertEqual(row_count, 1)
Expand Down Expand Up @@ -466,6 +471,8 @@ def _execute_update_helper(self, count=0, query_options=None):
)
api.execute_sql.assert_called_once_with(
request=expected_request,
retry=retry,
timeout=timeout,
metadata=[("google-cloud-resource-prefix", database.name)],
)

Expand All @@ -477,6 +484,15 @@ def test_execute_update_new_transaction(self):
def test_execute_update_w_count(self):
self._execute_update_helper(count=1)

def test_execute_update_w_timeout_param(self):
self._execute_update_helper(timeout=2.0)

def test_execute_update_w_retry_param(self):
self._execute_update_helper(retry=gapic_v1.method.DEFAULT)

def test_execute_update_w_timeout_and_retry_params(self):
self._execute_update_helper(retry=gapic_v1.method.DEFAULT, timeout=2.0)

def test_execute_update_error(self):
database = _Database()
database.spanner_api = self._make_spanner_api()
Expand Down