diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index aa2353206f..9099d48c46 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -29,6 +29,7 @@ from google.cloud.spanner_v1.snapshot import _SnapshotBase from google.cloud.spanner_v1.batch import _BatchBase from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.api_core import gapic_v1 class Transaction(_SnapshotBase, _BatchBase): @@ -185,7 +186,15 @@ def _make_params_pb(params, param_types): return {} def execute_update( - self, dml, params=None, param_types=None, query_mode=None, query_options=None + self, + dml, + params=None, + param_types=None, + query_mode=None, + query_options=None, + *, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, ): """Perform an ``ExecuteSql`` API request with DML. @@ -212,6 +221,12 @@ def execute_update( or :class:`dict` :param query_options: (Optional) Options that are provided for query plan stability. + :type retry: :class:`~google.api_core.retry.Retry` + :param retry: (Optional) The retry settings for this request. + + :type timeout: float + :param timeout: (Optional) The timeout for this request. + :rtype: int :returns: Count of rows affected by the DML statement. """ @@ -245,7 +260,9 @@ def execute_update( with trace_call( "CloudSpanner.ReadWriteTransaction", self._session, trace_attributes ): - response = api.execute_sql(request=request, metadata=metadata) + response = api.execute_sql( + request=request, metadata=metadata, retry=retry, timeout=timeout + ) return response.stats.row_count_exact def batch_update(self, statements): diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 4dc56bfa06..3302f68d2d 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -17,6 +17,7 @@ from tests._helpers import OpenTelemetryBase, StatusCanonicalCode from google.cloud.spanner_v1 import Type from google.cloud.spanner_v1 import TypeCode +from google.api_core import gapic_v1 TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"] @@ -410,7 +411,13 @@ def test_execute_update_w_params_wo_param_types(self): with self.assertRaises(ValueError): transaction.execute_update(DML_QUERY_WITH_PARAM, PARAMS) - def _execute_update_helper(self, count=0, query_options=None): + def _execute_update_helper( + self, + count=0, + query_options=None, + retry=gapic_v1.method.DEFAULT, + timeout=gapic_v1.method.DEFAULT, + ): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( ResultSet, @@ -439,6 +446,8 @@ def _execute_update_helper(self, count=0, query_options=None): PARAM_TYPES, query_mode=MODE, query_options=query_options, + retry=retry, + timeout=timeout, ) self.assertEqual(row_count, 1) @@ -466,6 +475,8 @@ def _execute_update_helper(self, count=0, query_options=None): ) api.execute_sql.assert_called_once_with( request=expected_request, + retry=retry, + timeout=timeout, metadata=[("google-cloud-resource-prefix", database.name)], ) @@ -477,6 +488,15 @@ def test_execute_update_new_transaction(self): def test_execute_update_w_count(self): self._execute_update_helper(count=1) + def test_execute_update_w_timeout_param(self): + self._execute_update_helper(timeout=2.0) + + def test_execute_update_w_retry_param(self): + self._execute_update_helper(retry=gapic_v1.method.DEFAULT) + + def test_execute_update_w_timeout_and_retry_params(self): + self._execute_update_helper(retry=gapic_v1.method.DEFAULT, timeout=2.0) + def test_execute_update_error(self): database = _Database() database.spanner_api = self._make_spanner_api()