Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: pass transaction's options to API in 'begin' #143

Merged
merged 4 commits into from May 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion google/cloud/datastore/transaction.py
Expand Up @@ -176,10 +176,12 @@ def Entity(*args, **kwargs):
def __init__(self, client, read_only=False):
super(Transaction, self).__init__(client)
self._id = None

if read_only:
options = TransactionOptions(read_only=TransactionOptions.ReadOnly())
else:
options = TransactionOptions()

self._options = options

@property
Expand Down Expand Up @@ -231,9 +233,13 @@ def begin(self, retry=None, timeout=None):

kwargs = _make_retry_timeout_kwargs(retry, timeout)

request = {
"project_id": self.project,
"transaction_options": self._options,
}
try:
response_pb = self._client._datastore_api.begin_transaction(
request={"project_id": self.project}, **kwargs
request=request, **kwargs
)
self._id = response_pb.transaction
except: # noqa: E722 do not use bare except, specify exception instead
Expand Down
135 changes: 89 additions & 46 deletions tests/unit/test_transaction.py
Expand Up @@ -24,28 +24,50 @@ def _get_target_class():

return Transaction

def _get_options_class(self, **kw):
def _make_one(self, client, **kw):
return self._get_target_class()(client, **kw)

def _make_options(self, read_only=False, previous_transaction=None):
from google.cloud.datastore_v1.types import TransactionOptions

return TransactionOptions
kw = {}

def _make_one(self, client, **kw):
return self._get_target_class()(client, **kw)
if read_only:
kw["read_only"] = TransactionOptions.ReadOnly()

def _make_options(self, **kw):
return self._get_options_class()(**kw)
return TransactionOptions(**kw)

def test_ctor_defaults(self):
project = "PROJECT"
client = _Client(project)

xact = self._make_one(client)

self.assertEqual(xact.project, project)
self.assertIs(xact._client, client)
self.assertIsNone(xact.id)
self.assertEqual(xact._status, self._get_target_class()._INITIAL)
self.assertEqual(xact._mutations, [])
self.assertEqual(len(xact._partial_key_entities), 0)

def test_constructor_read_only(self):
project = "PROJECT"
id_ = 850302
ds_api = _make_datastore_api(xact=id_)
client = _Client(project, datastore_api=ds_api)
options = self._make_options(read_only=True)

xact = self._make_one(client, read_only=True)

self.assertEqual(xact._options, options)

def _make_begin_request(self, project, read_only=False):
expected_options = self._make_options(read_only=read_only)
return {
"project_id": project,
"transaction_options": expected_options,
}

def test_current(self):
from google.cloud.datastore_v1.types import datastore as datastore_pb2

Expand All @@ -57,24 +79,34 @@ def test_current(self):
xact2 = self._make_one(client)
self.assertIsNone(xact1.current())
self.assertIsNone(xact2.current())

with xact1:
self.assertIs(xact1.current(), xact1)
self.assertIs(xact2.current(), xact1)

with _NoCommitBatch(client):
self.assertIsNone(xact1.current())
self.assertIsNone(xact2.current())

with xact2:
self.assertIs(xact1.current(), xact2)
self.assertIs(xact2.current(), xact2)

with _NoCommitBatch(client):
self.assertIsNone(xact1.current())
self.assertIsNone(xact2.current())

self.assertIs(xact1.current(), xact1)
self.assertIs(xact2.current(), xact1)

self.assertIsNone(xact1.current())
self.assertIsNone(xact2.current())

ds_api.rollback.assert_not_called()
begin_txn = ds_api.begin_transaction
self.assertEqual(begin_txn.call_count, 2)
expected_request = self._make_begin_request(project)
begin_txn.assert_called_with(request=expected_request)

commit_method = ds_api.commit
self.assertEqual(commit_method.call_count, 2)
mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL
Expand All @@ -87,21 +119,35 @@ def test_current(self):
}
)

begin_txn = ds_api.begin_transaction
self.assertEqual(begin_txn.call_count, 2)
begin_txn.assert_called_with(request={"project_id": project})
ds_api.rollback.assert_not_called()

def test_begin(self):
project = "PROJECT"
id_ = 889
ds_api = _make_datastore_api(xact_id=id_)
client = _Client(project, datastore_api=ds_api)
xact = self._make_one(client)

xact.begin()

self.assertEqual(xact.id, id_)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}
)

expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

def test_begin_w_readonly(self):
project = "PROJECT"
id_ = 889
ds_api = _make_datastore_api(xact_id=id_)
client = _Client(project, datastore_api=ds_api)
xact = self._make_one(client, read_only=True)

xact.begin()

self.assertEqual(xact.id, id_)

expected_request = self._make_begin_request(project, read_only=True)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

def test_begin_w_retry_w_timeout(self):
project = "PROJECT"
Expand All @@ -116,8 +162,10 @@ def test_begin_w_retry_w_timeout(self):
xact.begin(retry=retry, timeout=timeout)

self.assertEqual(xact.id, id_)

expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}, retry=retry, timeout=timeout
request=expected_request, retry=retry, timeout=timeout,
)

def test_begin_tombstoned(self):
Expand All @@ -126,19 +174,23 @@ def test_begin_tombstoned(self):
ds_api = _make_datastore_api(xact_id=id_)
client = _Client(project, datastore_api=ds_api)
xact = self._make_one(client)

xact.begin()

self.assertEqual(xact.id, id_)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}
)

expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

xact.rollback()

client._datastore_api.rollback.assert_called_once_with(
request={"project_id": project, "transaction": id_}
)
self.assertIsNone(xact.id)

self.assertRaises(ValueError, xact.begin)
with self.assertRaises(ValueError):
xact.begin()

def test_begin_w_begin_transaction_failure(self):
project = "PROJECT"
Expand All @@ -152,9 +204,9 @@ def test_begin_w_begin_transaction_failure(self):
xact.begin()

self.assertIsNone(xact.id)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}
)

expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

def test_rollback(self):
project = "PROJECT"
Expand Down Expand Up @@ -256,11 +308,14 @@ def test_context_manager_no_raise(self):
ds_api = _make_datastore_api(xact_id=id_)
client = _Client(project, datastore_api=ds_api)
xact = self._make_one(client)

with xact:
self.assertEqual(xact.id, id_)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}
)
self.assertEqual(xact.id, id_) # only set between begin / commit

self.assertIsNone(xact.id)

expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

mode = datastore_pb2.CommitRequest.Mode.TRANSACTIONAL
client._datastore_api.commit.assert_called_once_with(
Expand All @@ -272,9 +327,6 @@ def test_context_manager_no_raise(self):
},
)

self.assertIsNone(xact.id)
self.assertEqual(ds_api.begin_transaction.call_count, 1)

def test_context_manager_w_raise(self):
class Foo(Exception):
pass
Expand All @@ -288,29 +340,20 @@ class Foo(Exception):
try:
with xact:
self.assertEqual(xact.id, id_)
ds_api.begin_transaction.assert_called_once_with(
request={"project_id": project}
)
raise Foo()
except Foo:
self.assertIsNone(xact.id)
client._datastore_api.rollback.assert_called_once_with(
request={"project_id": project, "transaction": id_}
)
pass

client._datastore_api.commit.assert_not_called()
self.assertIsNone(xact.id)
self.assertEqual(ds_api.begin_transaction.call_count, 1)

def test_constructor_read_only(self):
project = "PROJECT"
id_ = 850302
ds_api = _make_datastore_api(xact=id_)
client = _Client(project, datastore_api=ds_api)
read_only = self._get_options_class().ReadOnly()
options = self._make_options(read_only=read_only)
xact = self._make_one(client, read_only=True)
self.assertEqual(xact._options, options)
expected_request = self._make_begin_request(project)
ds_api.begin_transaction.assert_called_once_with(request=expected_request)

client._datastore_api.commit.assert_not_called()

client._datastore_api.rollback.assert_called_once_with(
request={"project_id": project, "transaction": id_}
)

def test_put_read_only(self):
project = "PROJECT"
Expand Down