diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 2172d9d051..4d8364df1f 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -32,6 +32,9 @@ class _BatchBase(_SessionWrapper): :param session: the session used to perform the commit """ + transaction_tag = None + _read_only = False + def __init__(self, session): super(_BatchBase, self).__init__(session) self._mutations = [] @@ -118,8 +121,7 @@ def delete(self, table, keyset): class Batch(_BatchBase): - """Accumulate mutations for transmission during :meth:`commit`. - """ + """Accumulate mutations for transmission during :meth:`commit`.""" committed = None commit_stats = None @@ -160,8 +162,14 @@ def commit(self, return_commit_stats=False, request_options=None): txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) trace_attributes = {"num_mutations": len(self._mutations)} - if type(request_options) == dict: + if request_options is None: + request_options = RequestOptions() + elif type(request_options) == dict: request_options = RequestOptions(request_options) + request_options.transaction_tag = self.transaction_tag + + # Request tags are not supported for commit requests. + request_options.request_tag = None request = CommitRequest( session=self._session.name, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index bcd446ee96..0ba657cba0 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -494,6 +494,8 @@ def execute_partitioned_dml( (Optional) Common options for this request. If a dict is provided, it must be of the same form as the protobuf message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + Please note, the `transactionTag` setting will be ignored as it is + not supported for partitioned DML. :rtype: int :returns: Count of rows affected by the DML statement. @@ -501,8 +503,11 @@ def execute_partitioned_dml( query_options = _merge_query_options( self._instance._client._query_options, query_options ) - if type(request_options) == dict: + if request_options is None: + request_options = RequestOptions() + elif type(request_options) == dict: request_options = RequestOptions(request_options) + request_options.transaction_tag = None if params is not None: from google.cloud.spanner_v1.transaction import Transaction @@ -796,12 +801,19 @@ class BatchCheckout(object): def __init__(self, database, request_options=None): self._database = database self._session = self._batch = None - self._request_options = request_options + if request_options is None: + self._request_options = RequestOptions() + elif type(request_options) == dict: + self._request_options = RequestOptions(request_options) + else: + self._request_options = request_options def __enter__(self): """Begin ``with`` block.""" session = self._session = self._database._pool.get() batch = self._batch = Batch(session) + if self._request_options.transaction_tag: + batch.transaction_tag = self._request_options.transaction_tag return batch def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 4222ca0d5e..5eca0a8d2f 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -340,11 +340,13 @@ def run_in_transaction(self, func, *args, **kw): """ deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS) commit_request_options = kw.pop("commit_request_options", None) + transaction_tag = kw.pop("transaction_tag", None) attempts = 0 while True: if self._transaction is None: txn = self.transaction() + txn.transaction_tag = transaction_tag else: txn = self._transaction if txn._transaction_id is None: diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index c973326496..aaf9caa2fc 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -102,6 +102,7 @@ class _SnapshotBase(_SessionWrapper): """ _multi_use = False + _read_only = True _transaction_id = None _read_request_count = 0 _execute_sql_count = 0 @@ -160,6 +161,8 @@ def read( (Optional) Common options for this request. If a dict is provided, it must be of the same form as the protobuf message :class:`~google.cloud.spanner_v1.types.RequestOptions`. + Please note, the `transactionTag` setting will be ignored for + snapshot as it's not supported for read-only transactions. :type retry: :class:`~google.api_core.retry.Retry` :param retry: (Optional) The retry settings for this request. @@ -185,9 +188,17 @@ def read( metadata = _metadata_with_prefix(database.name) transaction = self._make_txn_selector() - if type(request_options) == dict: + if request_options is None: + request_options = RequestOptions() + elif type(request_options) == dict: request_options = RequestOptions(request_options) + if self._read_only: + # Transaction tags are not supported for read only transactions. + request_options.transaction_tag = None + else: + request_options.transaction_tag = self.transaction_tag + request = ReadRequest( session=self._session.name, table=table, @@ -312,8 +323,15 @@ def execute_sql( default_query_options = database._instance._client._query_options query_options = _merge_query_options(default_query_options, query_options) - if type(request_options) == dict: + if request_options is None: + request_options = RequestOptions() + elif type(request_options) == dict: request_options = RequestOptions(request_options) + if self._read_only: + # Transaction tags are not supported for read only transactions. + request_options.transaction_tag = None + else: + request_options.transaction_tag = self.transaction_tag request = ExecuteSqlRequest( session=self._session.name, diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index fce14eb60d..b960761147 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -148,8 +148,15 @@ def commit(self, return_commit_stats=False, request_options=None): metadata = _metadata_with_prefix(database.name) trace_attributes = {"num_mutations": len(self._mutations)} - if type(request_options) == dict: + if request_options is None: + request_options = RequestOptions() + elif type(request_options) == dict: request_options = RequestOptions(request_options) + if self.transaction_tag is not None: + request_options.transaction_tag = self.transaction_tag + + # Request tags are not supported for commit requests. + request_options.request_tag = None request = CommitRequest( session=self._session.name, @@ -267,8 +274,11 @@ def execute_update( default_query_options = database._instance._client._query_options query_options = _merge_query_options(default_query_options, query_options) - if type(request_options) == dict: + if request_options is None: + request_options = RequestOptions() + elif type(request_options) == dict: request_options = RequestOptions(request_options) + request_options.transaction_tag = self.transaction_tag trace_attributes = {"db.statement": dml} @@ -343,8 +353,11 @@ def batch_update(self, statements, request_options=None): self._execute_sql_count + 1, ) - if type(request_options) == dict: + if request_options is None: + request_options = RequestOptions() + elif type(request_options) == dict: request_options = RequestOptions(request_options) + request_options.transaction_tag = self.transaction_tag trace_attributes = { # Get just the queries from the DML statement batch diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index f7915814a3..d6af07ce7e 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -15,6 +15,7 @@ import unittest from tests._helpers import OpenTelemetryBase, StatusCode +from google.cloud.spanner_v1 import RequestOptions TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"] @@ -39,6 +40,7 @@ class _BaseTest(unittest.TestCase): DATABASE_NAME = INSTANCE_NAME + "/databases/" + DATABASE_ID SESSION_ID = "session-id" SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID + TRANSACTION_TAG = "transaction-tag" def _make_one(self, *args, **kwargs): return self._getTargetClass()(*args, **kwargs) @@ -232,18 +234,87 @@ def test_commit_ok(self): self.assertEqual(committed, now) self.assertEqual(batch.committed, committed) - (session, mutations, single_use_txn, metadata, request_options) = api._committed + (session, mutations, single_use_txn, request_options, metadata) = api._committed self.assertEqual(session, self.SESSION_NAME) self.assertEqual(mutations, batch._mutations) self.assertIsInstance(single_use_txn, TransactionOptions) self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write")) self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)]) - self.assertEqual(request_options, None) + self.assertEqual(request_options, RequestOptions()) self.assertSpanAttributes( "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) ) + def _test_commit_with_request_options(self, request_options=None): + import datetime + from google.cloud.spanner_v1 import CommitResponse + from google.cloud.spanner_v1 import TransactionOptions + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + response = CommitResponse(commit_timestamp=now_pb) + database = _Database() + api = database.spanner_api = _FauxSpannerAPI(_commit_response=response) + session = _Session(database) + batch = self._make_one(session) + batch.transaction_tag = self.TRANSACTION_TAG + batch.insert(TABLE_NAME, COLUMNS, VALUES) + committed = batch.commit(request_options=request_options) + + self.assertEqual(committed, now) + self.assertEqual(batch.committed, committed) + + if type(request_options) == dict: + expected_request_options = RequestOptions(request_options) + else: + expected_request_options = request_options + expected_request_options.transaction_tag = self.TRANSACTION_TAG + expected_request_options.request_tag = None + + ( + session, + mutations, + single_use_txn, + actual_request_options, + metadata, + ) = api._committed + self.assertEqual(session, self.SESSION_NAME) + self.assertEqual(mutations, batch._mutations) + self.assertIsInstance(single_use_txn, TransactionOptions) + self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write")) + self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)]) + self.assertEqual(actual_request_options, expected_request_options) + + self.assertSpanAttributes( + "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) + ) + + def test_commit_w_request_tag_success(self): + request_options = RequestOptions(request_tag="tag-1",) + self._test_commit_with_request_options(request_options=request_options) + + def test_commit_w_transaction_tag_success(self): + request_options = RequestOptions(transaction_tag="tag-1-1",) + self._test_commit_with_request_options(request_options=request_options) + + def test_commit_w_request_and_transaction_tag_success(self): + request_options = RequestOptions( + request_tag="tag-1", transaction_tag="tag-1-1", + ) + self._test_commit_with_request_options(request_options=request_options) + + def test_commit_w_request_and_transaction_tag_dictionary_success(self): + request_options = {"request_tag": "tag-1", "transaction_tag": "tag-1-1"} + self._test_commit_with_request_options(request_options=request_options) + + def test_commit_w_incorrect_tag_dictionary_error(self): + request_options = {"incorrect_tag": "tag-1-1"} + with self.assertRaises(ValueError): + self._test_commit_with_request_options(request_options=request_options) + def test_context_mgr_already_committed(self): import datetime from google.cloud._helpers import UTC @@ -281,13 +352,13 @@ def test_context_mgr_success(self): self.assertEqual(batch.committed, now) - (session, mutations, single_use_txn, metadata, request_options) = api._committed + (session, mutations, single_use_txn, request_options, metadata) = api._committed self.assertEqual(session, self.SESSION_NAME) self.assertEqual(mutations, batch._mutations) self.assertIsInstance(single_use_txn, TransactionOptions) self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write")) self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)]) - self.assertEqual(request_options, None) + self.assertEqual(request_options, RequestOptions()) self.assertSpanAttributes( "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) @@ -341,7 +412,7 @@ def __init__(self, **kwargs): self.__dict__.update(**kwargs) def commit( - self, request=None, metadata=None, request_options=None, + self, request=None, metadata=None, ): from google.api_core.exceptions import Unknown @@ -350,8 +421,8 @@ def commit( request.session, request.mutations, request.single_use_transaction, + request.request_options, metadata, - request_options, ) if self._rpc_error: raise Unknown("error") diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index a4b7aa2425..df5554d153 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -61,6 +61,7 @@ class _BaseTest(unittest.TestCase): RETRY_TRANSACTION_ID = b"transaction_id_retry" BACKUP_ID = "backup_id" BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID + TRANSACTION_TAG = "transaction-tag" def _make_one(self, *args, **kwargs): return self._get_target_class()(*args, **kwargs) @@ -1000,6 +1001,11 @@ def _execute_partitioned_dml_helper( expected_query_options, query_options ) + if not request_options: + expected_request_options = RequestOptions() + else: + expected_request_options = RequestOptions(request_options) + expected_request_options.transaction_tag = None expected_request = ExecuteSqlRequest( session=self.SESSION_NAME, sql=dml, @@ -1007,7 +1013,7 @@ def _execute_partitioned_dml_helper( params=expected_params, param_types=param_types, query_options=expected_query_options, - request_options=request_options, + request_options=expected_request_options, ) api.execute_streaming_sql.assert_any_call( @@ -1025,7 +1031,7 @@ def _execute_partitioned_dml_helper( params=expected_params, param_types=param_types, query_options=expected_query_options, - request_options=request_options, + request_options=expected_request_options, ) api.execute_streaming_sql.assert_called_with( request=expected_request, @@ -1063,6 +1069,16 @@ def test_execute_partitioned_dml_w_request_options(self): ), ) + def test_execute_partitioned_dml_w_trx_tag_ignored(self): + self._execute_partitioned_dml_helper( + dml=DML_W_PARAM, request_options=RequestOptions(transaction_tag="trx-tag"), + ) + + def test_execute_partitioned_dml_w_req_tag_used(self): + self._execute_partitioned_dml_helper( + dml=DML_W_PARAM, request_options=RequestOptions(request_tag="req-tag"), + ) + def test_execute_partitioned_dml_wo_params_retry_aborted(self): self._execute_partitioned_dml_helper(dml=DML_WO_PARAM, retried=True) @@ -1560,7 +1576,9 @@ def test_context_mgr_success(self): pool = database._pool = _Pool() session = _Session(database) pool.put(session) - checkout = self._make_one(database) + checkout = self._make_one( + database, request_options={"transaction_tag": self.TRANSACTION_TAG} + ) with checkout as batch: self.assertIsNone(pool._session) @@ -1569,6 +1587,7 @@ def test_context_mgr_success(self): self.assertIs(pool._session, session) self.assertEqual(batch.committed, now) + self.assertEqual(batch.transaction_tag, self.TRANSACTION_TAG) expected_txn_options = TransactionOptions(read_write={}) @@ -1576,6 +1595,7 @@ def test_context_mgr_success(self): session=self.SESSION_NAME, mutations=[], single_use_transaction=expected_txn_options, + request_options=RequestOptions(transaction_tag=self.TRANSACTION_TAG), ) api.commit.assert_called_once_with( request=request, metadata=[("google-cloud-resource-prefix", database.name)], @@ -1618,6 +1638,7 @@ def test_context_mgr_w_commit_stats_success(self): mutations=[], single_use_transaction=expected_txn_options, return_commit_stats=True, + request_options=RequestOptions(), ) api.commit.assert_called_once_with( request=request, metadata=[("google-cloud-resource-prefix", database.name)], @@ -1657,6 +1678,7 @@ def test_context_mgr_w_commit_stats_error(self): mutations=[], single_use_transaction=expected_txn_options, return_commit_stats=True, + request_options=RequestOptions(), ) api.commit.assert_called_once_with( request=request, metadata=[("google-cloud-resource-prefix", database.name)], diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 4daabdf952..fe78567f6b 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -14,6 +14,7 @@ import google.api_core.gapic_v1.method +from google.cloud.spanner_v1 import RequestOptions import mock from tests._helpers import ( OpenTelemetryBase, @@ -829,6 +830,7 @@ def unit_of_work(txn, *args, **kw): session=self.SESSION_NAME, mutations=txn._mutations, transaction_id=TRANSACTION_ID, + request_options=RequestOptions(), ) gax_api.commit.assert_called_once_with( request=request, metadata=[("google-cloud-resource-prefix", database.name)], @@ -879,6 +881,7 @@ def unit_of_work(txn, *args, **kw): session=self.SESSION_NAME, mutations=txn._mutations, transaction_id=TRANSACTION_ID, + request_options=RequestOptions(), ) gax_api.commit.assert_called_once_with( request=request, metadata=[("google-cloud-resource-prefix", database.name)], @@ -949,6 +952,7 @@ def unit_of_work(txn, *args, **kw): session=self.SESSION_NAME, mutations=txn._mutations, transaction_id=TRANSACTION_ID, + request_options=RequestOptions(), ) self.assertEqual( gax_api.commit.call_args_list, @@ -1041,6 +1045,7 @@ def unit_of_work(txn, *args, **kw): session=self.SESSION_NAME, mutations=txn._mutations, transaction_id=TRANSACTION_ID, + request_options=RequestOptions(), ) self.assertEqual( gax_api.commit.call_args_list, @@ -1133,6 +1138,7 @@ def unit_of_work(txn, *args, **kw): session=self.SESSION_NAME, mutations=txn._mutations, transaction_id=TRANSACTION_ID, + request_options=RequestOptions(), ) gax_api.commit.assert_called_once_with( request=request, metadata=[("google-cloud-resource-prefix", database.name)], @@ -1223,6 +1229,7 @@ def _time(_results=[1, 1.5]): session=self.SESSION_NAME, mutations=txn._mutations, transaction_id=TRANSACTION_ID, + request_options=RequestOptions(), ) gax_api.commit.assert_called_once_with( request=request, metadata=[("google-cloud-resource-prefix", database.name)], @@ -1304,6 +1311,7 @@ def _time(_results=[1, 2, 4, 8]): session=self.SESSION_NAME, mutations=txn._mutations, transaction_id=TRANSACTION_ID, + request_options=RequestOptions(), ) self.assertEqual( gax_api.commit.call_args_list, @@ -1377,6 +1385,7 @@ def unit_of_work(txn, *args, **kw): mutations=txn._mutations, transaction_id=TRANSACTION_ID, return_commit_stats=True, + request_options=RequestOptions(), ) gax_api.commit.assert_called_once_with( request=request, metadata=[("google-cloud-resource-prefix", database.name)], @@ -1439,12 +1448,81 @@ def unit_of_work(txn, *args, **kw): mutations=txn._mutations, transaction_id=TRANSACTION_ID, return_commit_stats=True, + request_options=RequestOptions(), ) gax_api.commit.assert_called_once_with( request=request, metadata=[("google-cloud-resource-prefix", database.name)], ) database.logger.info.assert_not_called() + def test_run_in_transaction_w_transaction_tag(self): + import datetime + from google.cloud.spanner_v1 import CommitRequest + from google.cloud.spanner_v1 import CommitResponse + from google.cloud.spanner_v1 import ( + Transaction as TransactionPB, + TransactionOptions, + ) + from google.cloud._helpers import UTC + from google.cloud._helpers import _datetime_to_pb_timestamp + from google.cloud.spanner_v1.transaction import Transaction + + TABLE_NAME = "citizens" + COLUMNS = ["email", "first_name", "last_name", "age"] + VALUES = [ + ["phred@exammple.com", "Phred", "Phlyntstone", 32], + ["bharney@example.com", "Bharney", "Rhubble", 31], + ] + TRANSACTION_ID = b"FACEDACE" + transaction_pb = TransactionPB(id=TRANSACTION_ID) + now = datetime.datetime.utcnow().replace(tzinfo=UTC) + now_pb = _datetime_to_pb_timestamp(now) + commit_stats = CommitResponse.CommitStats(mutation_count=4) + response = CommitResponse(commit_timestamp=now_pb, commit_stats=commit_stats) + gax_api = self._make_spanner_api() + gax_api.begin_transaction.return_value = transaction_pb + gax_api.commit.return_value = response + database = self._make_database() + database.spanner_api = gax_api + session = self._make_one(database) + session._session_id = self.SESSION_ID + + called_with = [] + + def unit_of_work(txn, *args, **kw): + called_with.append((txn, args, kw)) + txn.insert(TABLE_NAME, COLUMNS, VALUES) + return 42 + + transaction_tag = "transaction_tag" + return_value = session.run_in_transaction( + unit_of_work, "abc", some_arg="def", transaction_tag=transaction_tag + ) + + self.assertIsNone(session._transaction) + self.assertEqual(len(called_with), 1) + txn, args, kw = called_with[0] + self.assertIsInstance(txn, Transaction) + self.assertEqual(return_value, 42) + self.assertEqual(args, ("abc",)) + self.assertEqual(kw, {"some_arg": "def"}) + + expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) + gax_api.begin_transaction.assert_called_once_with( + session=self.SESSION_NAME, + options=expected_options, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + request = CommitRequest( + session=self.SESSION_NAME, + mutations=txn._mutations, + transaction_id=TRANSACTION_ID, + request_options=RequestOptions(transaction_tag=transaction_tag), + ) + gax_api.commit.assert_called_once_with( + request=request, metadata=[("google-cloud-resource-prefix", database.name)], + ) + def test_delay_helper_w_no_delay(self): from google.cloud.spanner_v1.session import _delay_until_retry diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 627b18d910..ef162fd29d 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -402,6 +402,7 @@ def _read_helper( partition=None, timeout=gapic_v1.method.DEFAULT, retry=gapic_v1.method.DEFAULT, + request_options=None, ): from google.protobuf.struct_pb2 import Struct from google.cloud.spanner_v1 import ( @@ -451,6 +452,11 @@ def _read_helper( if not first: derived._transaction_id = TXN_ID + if request_options is None: + request_options = RequestOptions() + elif type(request_options) == dict: + request_options = RequestOptions(request_options) + if partition is not None: # 'limit' and 'partition' incompatible result_set = derived.read( TABLE_NAME, @@ -460,6 +466,7 @@ def _read_helper( partition=partition, retry=retry, timeout=timeout, + request_options=request_options, ) else: result_set = derived.read( @@ -470,6 +477,7 @@ def _read_helper( limit=LIMIT, retry=retry, timeout=timeout, + request_options=request_options, ) self.assertEqual(derived._read_request_count, count + 1) @@ -500,6 +508,10 @@ def _read_helper( else: expected_limit = LIMIT + # Transaction tag is ignored for read request. + expected_request_options = request_options + expected_request_options.transaction_tag = None + expected_request = ReadRequest( session=self.SESSION_NAME, table=TABLE_NAME, @@ -509,6 +521,7 @@ def _read_helper( index=INDEX, limit=expected_limit, partition_token=partition, + request_options=expected_request_options, ) api.streaming_read.assert_called_once_with( request=expected_request, @@ -527,6 +540,29 @@ def _read_helper( def test_read_wo_multi_use(self): self._read_helper(multi_use=False) + def test_read_w_request_tag_success(self): + request_options = RequestOptions(request_tag="tag-1",) + self._read_helper(multi_use=False, request_options=request_options) + + def test_read_w_transaction_tag_success(self): + request_options = RequestOptions(transaction_tag="tag-1-1",) + self._read_helper(multi_use=False, request_options=request_options) + + def test_read_w_request_and_transaction_tag_success(self): + request_options = RequestOptions( + request_tag="tag-1", transaction_tag="tag-1-1", + ) + self._read_helper(multi_use=False, request_options=request_options) + + def test_read_w_request_and_transaction_tag_dictionary_success(self): + request_options = {"request_tag": "tag-1", "transaction_tag": "tag-1-1"} + self._read_helper(multi_use=False, request_options=request_options) + + def test_read_w_incorrect_tag_dictionary_error(self): + request_options = {"incorrect_tag": "tag-1-1"} + with self.assertRaises(ValueError): + self._read_helper(multi_use=False, request_options=request_options) + def test_read_wo_multi_use_w_read_request_count_gt_0(self): with self.assertRaises(ValueError): self._read_helper(multi_use=False, count=1) @@ -646,6 +682,11 @@ def _execute_sql_helper( if not first: derived._transaction_id = TXN_ID + if request_options is None: + request_options = RequestOptions() + elif type(request_options) == dict: + request_options = RequestOptions(request_options) + result_set = derived.execute_sql( SQL_QUERY_WITH_PARAM, PARAMS, @@ -691,6 +732,11 @@ def _execute_sql_helper( expected_query_options, query_options ) + if derived._read_only: + # Transaction tag is ignored for read only requests. + expected_request_options = request_options + expected_request_options.transaction_tag = None + expected_request = ExecuteSqlRequest( session=self.SESSION_NAME, sql=SQL_QUERY_WITH_PARAM, @@ -699,7 +745,7 @@ def _execute_sql_helper( param_types=PARAM_TYPES, query_mode=MODE, query_options=expected_query_options, - request_options=request_options, + request_options=expected_request_options, partition_token=partition, seqno=sql_count, ) @@ -760,6 +806,29 @@ def test_execute_sql_w_request_options(self): ), ) + def test_execute_sql_w_request_tag_success(self): + request_options = RequestOptions(request_tag="tag-1",) + self._execute_sql_helper(multi_use=False, request_options=request_options) + + def test_execute_sql_w_transaction_tag_success(self): + request_options = RequestOptions(transaction_tag="tag-1-1",) + self._execute_sql_helper(multi_use=False, request_options=request_options) + + def test_execute_sql_w_request_and_transaction_tag_success(self): + request_options = RequestOptions( + request_tag="tag-1", transaction_tag="tag-1-1", + ) + self._execute_sql_helper(multi_use=False, request_options=request_options) + + def test_execute_sql_w_request_and_transaction_tag_dictionary_success(self): + request_options = {"request_tag": "tag-1", "transaction_tag": "tag-1-1"} + self._execute_sql_helper(multi_use=False, request_options=request_options) + + def test_execute_sql_w_incorrect_tag_dictionary_error(self): + request_options = {"incorrect_tag": "tag-1-1"} + with self.assertRaises(ValueError): + self._execute_sql_helper(multi_use=False, request_options=request_options) + def _partition_read_helper( self, multi_use, diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index d87821fa4a..d11a3495fe 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -51,6 +51,7 @@ class TestTransaction(OpenTelemetryBase): SESSION_ID = "session-id" SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID TRANSACTION_ID = b"DEADBEEF" + TRANSACTION_TAG = "transaction-tag" BASE_ATTRIBUTES = { "db.type": "spanner", @@ -314,7 +315,9 @@ def test_commit_w_other_error(self): attributes=dict(TestTransaction.BASE_ATTRIBUTES, num_mutations=1), ) - def _commit_helper(self, mutate=True, return_commit_stats=False): + def _commit_helper( + self, mutate=True, return_commit_stats=False, request_options=None + ): import datetime from google.cloud.spanner_v1 import CommitResponse from google.cloud.spanner_v1.keyset import KeySet @@ -331,20 +334,38 @@ def _commit_helper(self, mutate=True, return_commit_stats=False): session = _Session(database) transaction = self._make_one(session) transaction._transaction_id = self.TRANSACTION_ID + transaction.transaction_tag = self.TRANSACTION_TAG if mutate: transaction.delete(TABLE_NAME, keyset) - transaction.commit(return_commit_stats=return_commit_stats) + transaction.commit( + return_commit_stats=return_commit_stats, request_options=request_options + ) self.assertEqual(transaction.committed, now) self.assertIsNone(session._transaction) - session_id, mutations, txn_id, metadata = api._committed + session_id, mutations, txn_id, actual_request_options, metadata = api._committed + + if request_options is None: + expected_request_options = RequestOptions( + transaction_tag=self.TRANSACTION_TAG + ) + elif type(request_options) == dict: + expected_request_options = RequestOptions(request_options) + expected_request_options.transaction_tag = self.TRANSACTION_TAG + expected_request_options.request_tag = None + else: + expected_request_options = request_options + expected_request_options.transaction_tag = self.TRANSACTION_TAG + expected_request_options.request_tag = None + self.assertEqual(session_id, session.name) self.assertEqual(txn_id, self.TRANSACTION_ID) self.assertEqual(mutations, transaction._mutations) self.assertEqual(metadata, [("google-cloud-resource-prefix", database.name)]) + self.assertEqual(actual_request_options, expected_request_options) if return_commit_stats: self.assertEqual(transaction.commit_stats.mutation_count, 4) @@ -366,6 +387,29 @@ def test_commit_w_mutations(self): def test_commit_w_return_commit_stats(self): self._commit_helper(return_commit_stats=True) + def test_commit_w_request_tag_success(self): + request_options = RequestOptions(request_tag="tag-1",) + self._commit_helper(request_options=request_options) + + def test_commit_w_transaction_tag_ignored_success(self): + request_options = RequestOptions(transaction_tag="tag-1-1",) + self._commit_helper(request_options=request_options) + + def test_commit_w_request_and_transaction_tag_success(self): + request_options = RequestOptions( + request_tag="tag-1", transaction_tag="tag-1-1", + ) + self._commit_helper(request_options=request_options) + + def test_commit_w_request_and_transaction_tag_dictionary_success(self): + request_options = {"request_tag": "tag-1", "transaction_tag": "tag-1-1"} + self._commit_helper(request_options=request_options) + + def test_commit_w_incorrect_tag_dictionary_error(self): + request_options = {"incorrect_tag": "tag-1-1"} + with self.assertRaises(ValueError): + self._commit_helper(request_options=request_options) + def test__make_params_pb_w_params_wo_param_types(self): session = _Session() transaction = self._make_one(session) @@ -443,8 +487,14 @@ def _execute_update_helper( session = _Session(database) transaction = self._make_one(session) transaction._transaction_id = self.TRANSACTION_ID + transaction.transaction_tag = self.TRANSACTION_TAG transaction._execute_sql_count = count + if request_options is None: + request_options = RequestOptions() + elif type(request_options) == dict: + request_options = RequestOptions(request_options) + row_count = transaction.execute_update( DML_QUERY_WITH_PARAM, PARAMS, @@ -468,6 +518,8 @@ def _execute_update_helper( expected_query_options = _merge_query_options( expected_query_options, query_options ) + expected_request_options = request_options + expected_request_options.transaction_tag = self.TRANSACTION_TAG expected_request = ExecuteSqlRequest( session=self.SESSION_NAME, @@ -492,6 +544,29 @@ def _execute_update_helper( def test_execute_update_new_transaction(self): self._execute_update_helper() + def test_execute_update_w_request_tag_success(self): + request_options = RequestOptions(request_tag="tag-1",) + self._execute_update_helper(request_options=request_options) + + def test_execute_update_w_transaction_tag_success(self): + request_options = RequestOptions(transaction_tag="tag-1-1",) + self._execute_update_helper(request_options=request_options) + + def test_execute_update_w_request_and_transaction_tag_success(self): + request_options = RequestOptions( + request_tag="tag-1", transaction_tag="tag-1-1", + ) + self._execute_update_helper(request_options=request_options) + + def test_execute_update_w_request_and_transaction_tag_dictionary_success(self): + request_options = {"request_tag": "tag-1", "transaction_tag": "tag-1-1"} + self._execute_update_helper(request_options=request_options) + + def test_execute_update_w_incorrect_tag_dictionary_error(self): + request_options = {"incorrect_tag": "tag-1-1"} + with self.assertRaises(ValueError): + self._execute_update_helper(request_options=request_options) + def test_execute_update_w_count(self): self._execute_update_helper(count=1) @@ -587,8 +662,14 @@ def _batch_update_helper(self, error_after=None, count=0, request_options=None): session = _Session(database) transaction = self._make_one(session) transaction._transaction_id = self.TRANSACTION_ID + transaction.transaction_tag = self.TRANSACTION_TAG transaction._execute_sql_count = count + if request_options is None: + request_options = RequestOptions() + elif type(request_options) == dict: + request_options = RequestOptions(request_options) + status, row_counts = transaction.batch_update( dml_statements, request_options=request_options ) @@ -611,13 +692,15 @@ def _batch_update_helper(self, error_after=None, count=0, request_options=None): ExecuteBatchDmlRequest.Statement(sql=update_dml), ExecuteBatchDmlRequest.Statement(sql=delete_dml), ] + expected_request_options = request_options + expected_request_options.transaction_tag = self.TRANSACTION_TAG expected_request = ExecuteBatchDmlRequest( session=self.SESSION_NAME, transaction=expected_transaction, statements=expected_statements, seqno=count, - request_options=request_options, + request_options=expected_request_options, ) api.execute_batch_dml.assert_called_once_with( request=expected_request, @@ -633,6 +716,29 @@ def test_batch_update_wo_errors(self): ), ) + def test_batch_update_w_request_tag_success(self): + request_options = RequestOptions(request_tag="tag-1",) + self._batch_update_helper(request_options=request_options) + + def test_batch_update_w_transaction_tag_success(self): + request_options = RequestOptions(transaction_tag="tag-1-1",) + self._batch_update_helper(request_options=request_options) + + def test_batch_update_w_request_and_transaction_tag_success(self): + request_options = RequestOptions( + request_tag="tag-1", transaction_tag="tag-1-1", + ) + self._batch_update_helper(request_options=request_options) + + def test_batch_update_w_request_and_transaction_tag_dictionary_success(self): + request_options = {"request_tag": "tag-1", "transaction_tag": "tag-1-1"} + self._batch_update_helper(request_options=request_options) + + def test_batch_update_w_incorrect_tag_dictionary_error(self): + request_options = {"incorrect_tag": "tag-1-1"} + with self.assertRaises(ValueError): + self._batch_update_helper(request_options=request_options) + def test_batch_update_w_errors(self): self._batch_update_helper(error_after=2, count=1) @@ -688,7 +794,7 @@ def test_context_mgr_success(self): self.assertEqual(transaction.committed, now) - session_id, mutations, txn_id, metadata = api._committed + session_id, mutations, txn_id, _, metadata = api._committed self.assertEqual(session_id, self.SESSION_NAME) self.assertEqual(txn_id, self.TRANSACTION_ID) self.assertEqual(mutations, transaction._mutations) @@ -775,6 +881,7 @@ def commit( request.session, request.mutations, request.transaction_id, + request.request_options, metadata, ) return self._commit_response