diff --git a/google/cloud/datastore/transaction.py b/google/cloud/datastore/transaction.py index a1eabed5..21cac1a7 100644 --- a/google/cloud/datastore/transaction.py +++ b/google/cloud/datastore/transaction.py @@ -176,10 +176,12 @@ def Entity(*args, **kwargs): def __init__(self, client, read_only=False): super(Transaction, self).__init__(client) self._id = None + if read_only: options = TransactionOptions(read_only=TransactionOptions.ReadOnly()) else: options = TransactionOptions() + self._options = options @property @@ -231,9 +233,13 @@ def begin(self, retry=None, timeout=None): kwargs = _make_retry_timeout_kwargs(retry, timeout) + request = { + "project_id": self.project, + "transaction_options": self._options, + } try: response_pb = self._client._datastore_api.begin_transaction( - request={"project_id": self.project}, **kwargs + request=request, **kwargs ) self._id = response_pb.transaction except: # noqa: E722 do not use bare except, specify exception instead diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index 1bc355cc..bae419df 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -24,21 +24,25 @@ def _get_target_class(): return Transaction - def _get_options_class(self, **kw): + def _make_one(self, client, **kw): + return self._get_target_class()(client, **kw) + + def _make_options(self, read_only=False, previous_transaction=None): from google.cloud.datastore_v1.types import TransactionOptions - return TransactionOptions + kw = {} - def _make_one(self, client, **kw): - return self._get_target_class()(client, **kw) + if read_only: + kw["read_only"] = TransactionOptions.ReadOnly() - def _make_options(self, **kw): - return self._get_options_class()(**kw) + return TransactionOptions(**kw) def test_ctor_defaults(self): project = "PROJECT" client = _Client(project) + xact = self._make_one(client) + self.assertEqual(xact.project, project) self.assertIs(xact._client, client) self.assertIsNone(xact.id) @@ -46,6 +50,24 @@ def test_ctor_defaults(self): self.assertEqual(xact._mutations, []) self.assertEqual(len(xact._partial_key_entities), 0) + def test_constructor_read_only(self): + project = "PROJECT" + id_ = 850302 + ds_api = _make_datastore_api(xact=id_) + client = _Client(project, datastore_api=ds_api) + options = self._make_options(read_only=True) + + xact = self._make_one(client, read_only=True) + + self.assertEqual(xact._options, options) + + def _make_begin_request(self, project, read_only=False): + expected_options = self._make_options(read_only=read_only) + return { + "project_id": project, + "transaction_options": expected_options, + } + def test_current(self): from google.cloud.datastore_v1.types import datastore as datastore_pb2 @@ -57,24 +79,34 @@ def test_current(self): xact2 = self._make_one(client) self.assertIsNone(xact1.current()) self.assertIsNone(xact2.current()) + with xact1: self.assertIs(xact1.current(), xact1) self.assertIs(xact2.current(), xact1) + with _NoCommitBatch(client): self.assertIsNone(xact1.current()) self.assertIsNone(xact2.current()) + with xact2: self.assertIs(xact1.current(), xact2) self.assertIs(xact2.current(), xact2) + with _NoCommitBatch(client): self.assertIsNone(xact1.current()) self.assertIsNone(xact2.current()) + self.assertIs(xact1.current(), xact1) self.assertIs(xact2.current(), xact1) + self.assertIsNone(xact1.current()) self.assertIsNone(xact2.current()) - ds_api.rollback.assert_not_called() + begin_txn = ds_api.begin_transaction + self.assertEqual(begin_txn.call_count, 2) + expected_request = self._make_begin_request(project) + begin_txn.assert_called_with(request=expected_request) + commit_method = ds_api.commit self.assertEqual(commit_method.call_count, 2) mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL @@ -87,9 +119,7 @@ def test_current(self): } ) - begin_txn = ds_api.begin_transaction - self.assertEqual(begin_txn.call_count, 2) - begin_txn.assert_called_with(request={"project_id": project}) + ds_api.rollback.assert_not_called() def test_begin(self): project = "PROJECT" @@ -97,11 +127,27 @@ def test_begin(self): ds_api = _make_datastore_api(xact_id=id_) client = _Client(project, datastore_api=ds_api) xact = self._make_one(client) + xact.begin() + self.assertEqual(xact.id, id_) - ds_api.begin_transaction.assert_called_once_with( - request={"project_id": project} - ) + + expected_request = self._make_begin_request(project) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) + + def test_begin_w_readonly(self): + project = "PROJECT" + id_ = 889 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) + xact = self._make_one(client, read_only=True) + + xact.begin() + + self.assertEqual(xact.id, id_) + + expected_request = self._make_begin_request(project, read_only=True) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) def test_begin_w_retry_w_timeout(self): project = "PROJECT" @@ -116,8 +162,10 @@ def test_begin_w_retry_w_timeout(self): xact.begin(retry=retry, timeout=timeout) self.assertEqual(xact.id, id_) + + expected_request = self._make_begin_request(project) ds_api.begin_transaction.assert_called_once_with( - request={"project_id": project}, retry=retry, timeout=timeout + request=expected_request, retry=retry, timeout=timeout, ) def test_begin_tombstoned(self): @@ -126,19 +174,23 @@ def test_begin_tombstoned(self): ds_api = _make_datastore_api(xact_id=id_) client = _Client(project, datastore_api=ds_api) xact = self._make_one(client) + xact.begin() + self.assertEqual(xact.id, id_) - ds_api.begin_transaction.assert_called_once_with( - request={"project_id": project} - ) + + expected_request = self._make_begin_request(project) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) xact.rollback() + client._datastore_api.rollback.assert_called_once_with( request={"project_id": project, "transaction": id_} ) self.assertIsNone(xact.id) - self.assertRaises(ValueError, xact.begin) + with self.assertRaises(ValueError): + xact.begin() def test_begin_w_begin_transaction_failure(self): project = "PROJECT" @@ -152,9 +204,9 @@ def test_begin_w_begin_transaction_failure(self): xact.begin() self.assertIsNone(xact.id) - ds_api.begin_transaction.assert_called_once_with( - request={"project_id": project} - ) + + expected_request = self._make_begin_request(project) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) def test_rollback(self): project = "PROJECT" @@ -256,11 +308,14 @@ def test_context_manager_no_raise(self): ds_api = _make_datastore_api(xact_id=id_) client = _Client(project, datastore_api=ds_api) xact = self._make_one(client) + with xact: - self.assertEqual(xact.id, id_) - ds_api.begin_transaction.assert_called_once_with( - request={"project_id": project} - ) + self.assertEqual(xact.id, id_) # only set between begin / commit + + self.assertIsNone(xact.id) + + expected_request = self._make_begin_request(project) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL client._datastore_api.commit.assert_called_once_with( @@ -272,9 +327,6 @@ def test_context_manager_no_raise(self): }, ) - self.assertIsNone(xact.id) - self.assertEqual(ds_api.begin_transaction.call_count, 1) - def test_context_manager_w_raise(self): class Foo(Exception): pass @@ -288,29 +340,20 @@ class Foo(Exception): try: with xact: self.assertEqual(xact.id, id_) - ds_api.begin_transaction.assert_called_once_with( - request={"project_id": project} - ) raise Foo() except Foo: - self.assertIsNone(xact.id) - client._datastore_api.rollback.assert_called_once_with( - request={"project_id": project, "transaction": id_} - ) + pass - client._datastore_api.commit.assert_not_called() self.assertIsNone(xact.id) - self.assertEqual(ds_api.begin_transaction.call_count, 1) - def test_constructor_read_only(self): - project = "PROJECT" - id_ = 850302 - ds_api = _make_datastore_api(xact=id_) - client = _Client(project, datastore_api=ds_api) - read_only = self._get_options_class().ReadOnly() - options = self._make_options(read_only=read_only) - xact = self._make_one(client, read_only=True) - self.assertEqual(xact._options, options) + expected_request = self._make_begin_request(project) + ds_api.begin_transaction.assert_called_once_with(request=expected_request) + + client._datastore_api.commit.assert_not_called() + + client._datastore_api.rollback.assert_called_once_with( + request={"project_id": project, "transaction": id_} + ) def test_put_read_only(self): project = "PROJECT"