diff --git a/google/cloud/spanner_v1/__init__.py b/google/cloud/spanner_v1/__init__.py index 7c9e9d70fe..4ece165503 100644 --- a/google/cloud/spanner_v1/__init__.py +++ b/google/cloud/spanner_v1/__init__.py @@ -28,6 +28,7 @@ from .types.query_plan import PlanNode from .types.query_plan import QueryPlan from .types.result_set import PartialResultSet +from .types import RequestOptions from .types.result_set import ResultSet from .types.result_set import ResultSetMetadata from .types.result_set import ResultSetStats @@ -119,6 +120,7 @@ "PlanNode", "QueryPlan", "ReadRequest", + "RequestOptions", "ResultSet", "ResultSetMetadata", "ResultSetStats", diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 9a79507886..d1774ed36d 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -23,6 +23,7 @@ from google.cloud.spanner_v1._helpers import _make_list_value_pbs from google.cloud.spanner_v1._helpers import _metadata_with_prefix from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1 import RequestOptions # pylint: enable=ungrouped-imports @@ -138,13 +139,20 @@ def _check_state(self): if self.committed is not None: raise ValueError("Batch already committed") - def commit(self, return_commit_stats=False): + def commit(self, return_commit_stats=False, request_options=None): """Commit mutations to the database. :type return_commit_stats: bool :param return_commit_stats: If true, the response will return commit stats which can be accessed though commit_stats. + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for this request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :rtype: datetime :returns: timestamp of the committed changes. """ @@ -154,11 +162,16 @@ def commit(self, return_commit_stats=False): metadata = _metadata_with_prefix(database.name) txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) trace_attributes = {"num_mutations": len(self._mutations)} + + if type(request_options) == dict: + request_options = RequestOptions(request_options) + request = CommitRequest( session=self._session.name, mutations=self._mutations, single_use_transaction=txn_options, return_commit_stats=return_commit_stats, + request_options=request_options, ) with trace_call("CloudSpanner.Commit", self._session, trace_attributes): response = api.commit(request=request, metadata=metadata,) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 5eb688d9c6..fae983f334 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -58,10 +58,10 @@ TransactionOptions, ) from google.cloud.spanner_v1.table import Table +from google.cloud.spanner_v1 import RequestOptions # pylint: enable=ungrouped-imports - SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data" @@ -454,7 +454,12 @@ def drop(self): api.drop_database(database=self.name, metadata=metadata) def execute_partitioned_dml( - self, dml, params=None, param_types=None, query_options=None + self, + dml, + params=None, + param_types=None, + query_options=None, + request_options=None, ): """Execute a partitionable DML statement. @@ -478,12 +483,22 @@ def execute_partitioned_dml( If a dict is provided, it must be of the same form as the protobuf message :class:`~google.cloud.spanner_v1.types.QueryOptions` + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for this request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :rtype: int :returns: Count of rows affected by the DML statement. """ query_options = _merge_query_options( self._instance._client._query_options, query_options ) + if type(request_options) == dict: + request_options = RequestOptions(request_options) + if params is not None: from google.cloud.spanner_v1.transaction import Transaction @@ -517,6 +532,7 @@ def execute_pdml(): params=params_pb, param_types=param_types, query_options=query_options, + request_options=request_options, ) method = functools.partial( api.execute_streaming_sql, metadata=metadata, @@ -561,16 +577,23 @@ def snapshot(self, **kw): """ return SnapshotCheckout(self, **kw) - def batch(self): + def batch(self, request_options=None): """Return an object which wraps a batch. The wrapper *must* be used as a context manager, with the batch as the value returned by the wrapper. + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for the commit request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :rtype: :class:`~google.cloud.spanner_v1.database.BatchCheckout` :returns: new wrapper """ - return BatchCheckout(self) + return BatchCheckout(self, request_options) def batch_snapshot(self, read_timestamp=None, exact_staleness=None): """Return an object which wraps a batch read / query. @@ -756,11 +779,19 @@ class BatchCheckout(object): :type database: :class:`~google.cloud.spanner_v1.database.Database` :param database: database to use + + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for the commit request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. """ - def __init__(self, database): + def __init__(self, database, request_options=None): self._database = database self._session = self._batch = None + self._request_options = request_options def __enter__(self): """Begin ``with`` block.""" @@ -772,7 +803,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" try: if exc_type is None: - self._batch.commit(return_commit_stats=self._database.log_commit_stats) + self._batch.commit( + return_commit_stats=self._database.log_commit_stats, + request_options=self._request_options, + ) finally: if self._database.log_commit_stats and self._batch.commit_stats: self._database.logger.info( diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 1321308ace..84b65429d6 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -230,6 +230,7 @@ def execute_sql( param_types=None, query_mode=None, query_options=None, + request_options=None, retry=google.api_core.gapic_v1.method.DEFAULT, timeout=google.api_core.gapic_v1.method.DEFAULT, ): @@ -258,6 +259,13 @@ def execute_sql( or :class:`dict` :param query_options: (Optional) Options that are provided for query plan stability. + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for this request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :type retry: :class:`~google.api_core.retry.Retry` :param retry: (Optional) The retry settings for this request. @@ -273,6 +281,7 @@ def execute_sql( param_types, query_mode, query_options=query_options, + request_options=request_options, retry=retry, timeout=timeout, ) @@ -319,9 +328,12 @@ def run_in_transaction(self, func, *args, **kw): :type kw: dict :param kw: (Optional) keyword arguments to be passed to ``func``. - If passed, "timeout_secs" will be removed and used to + If passed: + "timeout_secs" will be removed and used to override the default retry timeout which defines maximum timestamp to continue retrying the transaction. + "commit_request_options" will be removed and used to set the + request options for the commit request. :rtype: Any :returns: The return value of ``func``. @@ -330,6 +342,7 @@ def run_in_transaction(self, func, *args, **kw): reraises any non-ABORT exceptions raised by ``func``. """ deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS) + commit_request_options = kw.pop("commit_request_options", None) attempts = 0 while True: @@ -355,7 +368,10 @@ def run_in_transaction(self, func, *args, **kw): raise try: - txn.commit(return_commit_stats=self._database.log_commit_stats) + txn.commit( + return_commit_stats=self._database.log_commit_stats, + request_options=commit_request_options, + ) except Aborted as exc: del self._transaction _delay_until_retry(exc, deadline, attempts) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index f926d7836d..eccd8720e1 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -34,6 +34,7 @@ from google.cloud.spanner_v1._helpers import _SessionWrapper from google.cloud.spanner_v1._opentelemetry_tracing import trace_call from google.cloud.spanner_v1.streamed import StreamedResultSet +from google.cloud.spanner_v1 import RequestOptions _STREAM_RESUMPTION_INTERNAL_ERROR_MESSAGES = ( "RST_STREAM", @@ -124,6 +125,7 @@ def read( index="", limit=0, partition=None, + request_options=None, *, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, @@ -152,6 +154,13 @@ def read( from :meth:`partition_read`. Incompatible with ``limit``. + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for this request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :type retry: :class:`~google.api_core.retry.Retry` :param retry: (Optional) The retry settings for this request. @@ -176,6 +185,9 @@ def read( metadata = _metadata_with_prefix(database.name) transaction = self._make_txn_selector() + if type(request_options) == dict: + request_options = RequestOptions(request_options) + request = ReadRequest( session=self._session.name, table=table, @@ -185,6 +197,7 @@ def read( index=index, limit=limit, partition_token=partition, + request_options=request_options, ) restart = functools.partial( api.streaming_read, @@ -217,6 +230,7 @@ def execute_sql( param_types=None, query_mode=None, query_options=None, + request_options=None, partition=None, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, @@ -249,6 +263,13 @@ def execute_sql( If a dict is provided, it must be of the same form as the protobuf message :class:`~google.cloud.spanner_v1.types.QueryOptions` + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for this request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :type partition: bytes :param partition: (Optional) one of the partition tokens returned from :meth:`partition_query`. @@ -291,6 +312,9 @@ def execute_sql( default_query_options = database._instance._client._query_options query_options = _merge_query_options(default_query_options, query_options) + if type(request_options) == dict: + request_options = RequestOptions(request_options) + request = ExecuteSqlRequest( session=self._session.name, sql=sql, @@ -301,6 +325,7 @@ def execute_sql( partition_token=partition, seqno=self._execute_sql_count, query_options=query_options, + request_options=request_options, ) restart = functools.partial( api.execute_streaming_sql, diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 4c99b26a09..fce14eb60d 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -29,6 +29,7 @@ from google.cloud.spanner_v1.snapshot import _SnapshotBase from google.cloud.spanner_v1.batch import _BatchBase from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1 import RequestOptions from google.api_core import gapic_v1 @@ -122,13 +123,20 @@ def rollback(self): self.rolled_back = True del self._session._transaction - def commit(self, return_commit_stats=False): + def commit(self, return_commit_stats=False, request_options=None): """Commit mutations to the database. :type return_commit_stats: bool :param return_commit_stats: If true, the response will return commit stats which can be accessed though commit_stats. + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for this request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :rtype: datetime :returns: timestamp of the committed changes. :raises ValueError: if there are no mutations to commit. @@ -139,11 +147,16 @@ def commit(self, return_commit_stats=False): api = database.spanner_api metadata = _metadata_with_prefix(database.name) trace_attributes = {"num_mutations": len(self._mutations)} + + if type(request_options) == dict: + request_options = RequestOptions(request_options) + request = CommitRequest( session=self._session.name, mutations=self._mutations, transaction_id=self._transaction_id, return_commit_stats=return_commit_stats, + request_options=request_options, ) with trace_call("CloudSpanner.Commit", self._session, trace_attributes): response = api.commit(request=request, metadata=metadata,) @@ -192,6 +205,7 @@ def execute_update( param_types=None, query_mode=None, query_options=None, + request_options=None, *, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, @@ -221,6 +235,13 @@ def execute_update( or :class:`dict` :param query_options: (Optional) Options that are provided for query plan stability. + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for this request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :type retry: :class:`~google.api_core.retry.Retry` :param retry: (Optional) The retry settings for this request. @@ -246,7 +267,11 @@ def execute_update( default_query_options = database._instance._client._query_options query_options = _merge_query_options(default_query_options, query_options) + if type(request_options) == dict: + request_options = RequestOptions(request_options) + trace_attributes = {"db.statement": dml} + request = ExecuteSqlRequest( session=self._session.name, sql=dml, @@ -256,6 +281,7 @@ def execute_update( query_mode=query_mode, query_options=query_options, seqno=seqno, + request_options=request_options, ) with trace_call( "CloudSpanner.ReadWriteTransaction", self._session, trace_attributes @@ -265,7 +291,7 @@ def execute_update( ) return response.stats.row_count_exact - def batch_update(self, statements): + def batch_update(self, statements, request_options=None): """Perform a batch of DML statements via an ``ExecuteBatchDml`` request. :type statements: @@ -279,6 +305,13 @@ def batch_update(self, statements): must also be passed, as a dict mapping names to the type of value passed in 'params'. + :type request_options: + :class:`google.cloud.spanner_v1.types.RequestOptions` + :param request_options: + (Optional) Common options for this request. + If a dict is provided, it must be of the same form as the protobuf + message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + :rtype: Tuple(status, Sequence[int]) :returns: @@ -310,6 +343,9 @@ def batch_update(self, statements): self._execute_sql_count + 1, ) + if type(request_options) == dict: + request_options = RequestOptions(request_options) + trace_attributes = { # Get just the queries from the DML statement batch "db.statement": ";".join([statement.sql for statement in parsed]) @@ -319,6 +355,7 @@ def batch_update(self, statements): transaction=transaction, statements=parsed, seqno=seqno, + request_options=request_options, ) with trace_call("CloudSpanner.DMLTransaction", self._session, trace_attributes): response = api.execute_batch_dml(request=request, metadata=metadata) diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 7c1c0d6f64..8471cfc4c2 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -43,11 +43,13 @@ from google.cloud.spanner_v1.instance import Backup from google.cloud.spanner_v1.instance import Instance from google.cloud.spanner_v1.table import Table +from google.cloud.spanner_v1 import RequestOptions from test_utils.retry import RetryErrors from test_utils.retry import RetryInstanceState from test_utils.retry import RetryResult from test_utils.system import unique_resource_id + from tests._fixtures import DDL_STATEMENTS from tests._fixtures import EMULATOR_DDL_STATEMENTS from tests._helpers import OpenTelemetryBase, HAS_OPENTELEMETRY_INSTALLED @@ -1821,6 +1823,9 @@ def _setup_table(txn): update_statement, params={"email": nonesuch, "target": target}, param_types={"email": param_types.STRING, "target": param_types.STRING}, + request_options=RequestOptions( + priority=RequestOptions.Priority.PRIORITY_MEDIUM + ), ) self.assertEqual(row_count, 1) diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 3112f17ecf..f7915814a3 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -232,12 +232,13 @@ def test_commit_ok(self): self.assertEqual(committed, now) self.assertEqual(batch.committed, committed) - (session, mutations, single_use_txn, metadata) = api._committed + (session, mutations, single_use_txn, metadata, request_options) = api._committed self.assertEqual(session, self.SESSION_NAME) self.assertEqual(mutations, batch._mutations) self.assertIsInstance(single_use_txn, TransactionOptions) self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write")) self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)]) + self.assertEqual(request_options, None) self.assertSpanAttributes( "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) @@ -280,12 +281,13 @@ def test_context_mgr_success(self): self.assertEqual(batch.committed, now) - (session, mutations, single_use_txn, metadata) = api._committed + (session, mutations, single_use_txn, metadata, request_options) = api._committed self.assertEqual(session, self.SESSION_NAME) self.assertEqual(mutations, batch._mutations) self.assertIsInstance(single_use_txn, TransactionOptions) self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write")) self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)]) + self.assertEqual(request_options, None) self.assertSpanAttributes( "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) @@ -339,7 +341,7 @@ def __init__(self, **kwargs): self.__dict__.update(**kwargs) def commit( - self, request=None, metadata=None, + self, request=None, metadata=None, request_options=None, ): from google.api_core.exceptions import Unknown @@ -349,6 +351,7 @@ def commit( request.mutations, request.single_use_transaction, metadata, + request_options, ) if self._rpc_error: raise Unknown("error") diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index c71bab2581..05e6f2b422 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -21,6 +21,8 @@ from google.cloud.spanner_v1.param_types import INT64 from google.api_core.retry import Retry +from google.cloud.spanner_v1 import RequestOptions + DML_WO_PARAM = """ DELETE FROM citizens """ @@ -902,7 +904,13 @@ def test_drop_success(self): ) def _execute_partitioned_dml_helper( - self, dml, params=None, param_types=None, query_options=None, retried=False + self, + dml, + params=None, + param_types=None, + query_options=None, + request_options=None, + retried=False, ): from google.api_core.exceptions import Aborted from google.api_core.retry import Retry @@ -949,7 +957,7 @@ def _execute_partitioned_dml_helper( api.execute_streaming_sql.return_value = iterator row_count = database.execute_partitioned_dml( - dml, params, param_types, query_options + dml, params, param_types, query_options, request_options ) self.assertEqual(row_count, 2) @@ -989,6 +997,7 @@ def _execute_partitioned_dml_helper( params=expected_params, param_types=param_types, query_options=expected_query_options, + request_options=request_options, ) api.execute_streaming_sql.assert_any_call( @@ -1006,6 +1015,7 @@ def _execute_partitioned_dml_helper( params=expected_params, param_types=param_types, query_options=expected_query_options, + request_options=request_options, ) api.execute_streaming_sql.assert_called_with( request=expected_request, @@ -1035,6 +1045,14 @@ def test_execute_partitioned_dml_w_query_options(self): query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3"), ) + def test_execute_partitioned_dml_w_request_options(self): + self._execute_partitioned_dml_helper( + dml=DML_W_PARAM, + request_options=RequestOptions( + priority=RequestOptions.Priority.PRIORITY_MEDIUM + ), + ) + def test_execute_partitioned_dml_wo_params_retry_aborted(self): self._execute_partitioned_dml_helper(dml=DML_WO_PARAM, retried=True) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 9c2e9dce3c..4daabdf952 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -550,6 +550,7 @@ def test_execute_sql_defaults(self): None, None, query_options=None, + request_options=None, timeout=google.api_core.gapic_v1.method.DEFAULT, retry=google.api_core.gapic_v1.method.DEFAULT, ) @@ -579,6 +580,7 @@ def test_execute_sql_non_default_retry(self): param_types, "PLAN", query_options=None, + request_options=None, timeout=None, retry=None, ) @@ -606,6 +608,7 @@ def test_execute_sql_explicit(self): param_types, "PLAN", query_options=None, + request_options=None, timeout=google.api_core.gapic_v1.method.DEFAULT, retry=google.api_core.gapic_v1.method.DEFAULT, ) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index bbc1753474..627b18d910 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -15,6 +15,8 @@ from google.api_core import gapic_v1 import mock + +from google.cloud.spanner_v1 import RequestOptions from tests._helpers import ( OpenTelemetryBase, StatusCode, @@ -590,6 +592,7 @@ def _execute_sql_helper( partition=None, sql_count=0, query_options=None, + request_options=None, timeout=gapic_v1.method.DEFAULT, retry=gapic_v1.method.DEFAULT, ): @@ -649,6 +652,7 @@ def _execute_sql_helper( PARAM_TYPES, query_mode=MODE, query_options=query_options, + request_options=request_options, partition=partition, retry=retry, timeout=timeout, @@ -695,6 +699,7 @@ def _execute_sql_helper( param_types=PARAM_TYPES, query_mode=MODE, query_options=expected_query_options, + request_options=request_options, partition_token=partition, seqno=sql_count, ) @@ -747,6 +752,14 @@ def test_execute_sql_w_query_options(self): query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3"), ) + def test_execute_sql_w_request_options(self): + self._execute_sql_helper( + multi_use=False, + request_options=RequestOptions( + priority=RequestOptions.Priority.PRIORITY_MEDIUM + ), + ) + def _partition_read_helper( self, multi_use, diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 99f986d99e..d87821fa4a 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -14,12 +14,15 @@ import mock -from tests._helpers import OpenTelemetryBase, StatusCode + +from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1 import Type from google.cloud.spanner_v1 import TypeCode from google.api_core.retry import Retry from google.api_core import gapic_v1 +from tests._helpers import OpenTelemetryBase, StatusCode + TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"] VALUES = [ @@ -416,6 +419,7 @@ def _execute_update_helper( self, count=0, query_options=None, + request_options=None, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, ): @@ -447,6 +451,7 @@ def _execute_update_helper( PARAM_TYPES, query_mode=MODE, query_options=query_options, + request_options=request_options, retry=retry, timeout=timeout, ) @@ -472,6 +477,7 @@ def _execute_update_helper( param_types=PARAM_TYPES, query_mode=MODE, query_options=expected_query_options, + request_options=request_options, seqno=count, ) api.execute_sql.assert_called_once_with( @@ -518,6 +524,13 @@ def test_execute_update_w_query_options(self): query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3") ) + def test_execute_update_w_request_options(self): + self._execute_update_helper( + request_options=RequestOptions( + priority=RequestOptions.Priority.PRIORITY_MEDIUM + ) + ) + def test_batch_update_other_error(self): database = _Database() database.spanner_api = self._make_spanner_api() @@ -529,7 +542,7 @@ def test_batch_update_other_error(self): with self.assertRaises(RuntimeError): transaction.batch_update(statements=[DML_QUERY]) - def _batch_update_helper(self, error_after=None, count=0): + def _batch_update_helper(self, error_after=None, count=0, request_options=None): from google.rpc.status_pb2 import Status from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import param_types @@ -576,7 +589,9 @@ def _batch_update_helper(self, error_after=None, count=0): transaction._transaction_id = self.TRANSACTION_ID transaction._execute_sql_count = count - status, row_counts = transaction.batch_update(dml_statements) + status, row_counts = transaction.batch_update( + dml_statements, request_options=request_options + ) self.assertEqual(status, expected_status) self.assertEqual(row_counts, expected_row_counts) @@ -602,6 +617,7 @@ def _batch_update_helper(self, error_after=None, count=0): transaction=expected_transaction, statements=expected_statements, seqno=count, + request_options=request_options, ) api.execute_batch_dml.assert_called_once_with( request=expected_request, @@ -611,7 +627,11 @@ def _batch_update_helper(self, error_after=None, count=0): self.assertEqual(transaction._execute_sql_count, count + 1) def test_batch_update_wo_errors(self): - self._batch_update_helper() + self._batch_update_helper( + request_options=RequestOptions( + priority=RequestOptions.Priority.PRIORITY_MEDIUM + ), + ) def test_batch_update_w_errors(self): self._batch_update_helper(error_after=2, count=1)