Skip to content

Commit

Permalink
feat: add 'retry'/'timeout' args to 'Client.get'/'Client.get_multi'
Browse files Browse the repository at this point in the history
Toward #3
  • Loading branch information
tseaver committed Aug 11, 2020
1 parent edaaa45 commit 8dfbb2d
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 17 deletions.
69 changes: 66 additions & 3 deletions google/cloud/datastore/client.py
Expand Up @@ -97,6 +97,8 @@ def _extended_lookup(
deferred=None,
eventual=False,
transaction_id=None,
retry=None,
timeout=None,
):
"""Repeat lookup until all keys found (unless stop requested).
Expand Down Expand Up @@ -133,6 +135,17 @@ def _extended_lookup(
the given transaction. Incompatible with
``eventual==True``.
:type retry: :class:`google.api_core.retry.Retry`
:param retry:
A retry object used to retry requests. If ``None`` is specified,
requests will be retried using a default configuration.
:type timeout: float
:param timeout:
Time, in seconds, to wait for the request to complete.
Note that if ``retry`` is specified, the timeout applies
to each individual attempt.
:rtype: list of :class:`.entity_pb2.Entity`
:returns: The requested entities.
:raises: :class:`ValueError` if missing / deferred are not null or
Expand All @@ -144,14 +157,22 @@ def _extended_lookup(
if deferred is not None and deferred != []:
raise ValueError("deferred must be None or an empty list")

kwargs = {}

if retry is not None:
kwargs["retry"] = retry

if timeout is not None:
kwargs["timeout"] = timeout

results = []

loop_num = 0
read_options = helpers.get_read_options(eventual, transaction_id)
while loop_num < _MAX_LOOPS: # loop against possible deferred.
loop_num += 1
lookup_response = datastore_api.lookup(
project, key_pbs, read_options=read_options
project, key_pbs, read_options=read_options, **kwargs
)

# Accumulate the new results.
Expand Down Expand Up @@ -338,7 +359,16 @@ def current_transaction(self):
if isinstance(transaction, Transaction):
return transaction

def get(self, key, missing=None, deferred=None, transaction=None, eventual=False):
def get(
self,
key,
missing=None,
deferred=None,
transaction=None,
eventual=False,
retry=None,
timeout=None,
):
"""Retrieve an entity from a single key (if it exists).
.. note::
Expand Down Expand Up @@ -369,6 +399,17 @@ def get(self, key, missing=None, deferred=None, transaction=None, eventual=False
Setting True will use eventual consistency, but cannot
be used inside a transaction or will raise ValueError.
:type retry: :class:`google.api_core.retry.Retry`
:param retry:
A retry object used to retry requests. If ``None`` is specified,
requests will be retried using a default configuration.
:type timeout: float
:param timeout:
Time, in seconds, to wait for the request to complete.
Note that if ``retry`` is specified, the timeout applies
to each individual attempt.
:rtype: :class:`google.cloud.datastore.entity.Entity` or ``NoneType``
:returns: The requested entity if it exists.
Expand All @@ -380,12 +421,21 @@ def get(self, key, missing=None, deferred=None, transaction=None, eventual=False
deferred=deferred,
transaction=transaction,
eventual=eventual,
retry=retry,
timeout=timeout,
)
if entities:
return entities[0]

def get_multi(
self, keys, missing=None, deferred=None, transaction=None, eventual=False
self,
keys,
missing=None,
deferred=None,
transaction=None,
eventual=False,
retry=None,
timeout=None,
):
"""Retrieve entities, along with their attributes.
Expand All @@ -412,6 +462,17 @@ def get_multi(
Setting True will use eventual consistency, but cannot
be used inside a transaction or will raise ValueError.
:type retry: :class:`google.api_core.retry.Retry`
:param retry:
A retry object used to retry requests. If ``None`` is specified,
requests will be retried using a default configuration.
:type timeout: float
:param timeout:
Time, in seconds, to wait for the request to complete.
Note that if ``retry`` is specified, the timeout applies
to each individual attempt.
:rtype: list of :class:`google.cloud.datastore.entity.Entity`
:returns: The requested entities.
:raises: :class:`ValueError` if one or more of ``keys`` has a project
Expand All @@ -437,6 +498,8 @@ def get_multi(
missing=missing,
deferred=deferred,
transaction_id=transaction and transaction.id,
retry=retry,
timeout=timeout,
)

if missing is not None:
Expand Down
33 changes: 19 additions & 14 deletions tests/unit/test_client.py
Expand Up @@ -356,25 +356,24 @@ def test__push_batch_and__pop_batch(self):
self.assertEqual(list(client._batch_stack), [])

def test_get_miss(self):
_called_with = []

def _get_multi(*args, **kw):
_called_with.append((args, kw))
return []

creds = _make_credentials()
client = self._make_one(credentials=creds)
client.get_multi = _get_multi
get_multi = client.get_multi = mock.Mock(return_value=[])

key = object()

self.assertIsNone(client.get(key))

self.assertEqual(_called_with[0][0], ())
self.assertEqual(_called_with[0][1]["keys"], [key])
self.assertIsNone(_called_with[0][1]["missing"])
self.assertIsNone(_called_with[0][1]["deferred"])
self.assertIsNone(_called_with[0][1]["transaction"])
get_multi.assert_called_once_with(
keys=[key],
missing=None,
deferred=None,
transaction=None,
eventual=False,
retry=None,
timeout=None,
)

def test_get_hit(self):
TXN_ID = "123"
Expand Down Expand Up @@ -554,13 +553,15 @@ def test_get_multi_w_deferred_from_backend_but_not_passed(self):
self.PROJECT, [key1_pb, key2_pb], read_options=read_options
)

def test_get_multi_hit(self):
def test_get_multi_hit_w_retry_w_timeout(self):
from google.cloud.datastore_v1.proto import datastore_pb2
from google.cloud.datastore.key import Key

kind = "Kind"
id_ = 1234
path = [{"kind": kind, "id": id_}]
retry = mock.Mock()
timeout = 100000

# Make a found entity pb to be returned from mock backend.
entity_pb = _make_entity_pb(self.PROJECT, kind, id_, "foo", "Foo")
Expand All @@ -573,7 +574,7 @@ def test_get_multi_hit(self):
client._datastore_api_internal = ds_api

key = Key(kind, id_, project=self.PROJECT)
(result,) = client.get_multi([key])
(result,) = client.get_multi([key], retry=retry, timeout=timeout)
new_key = result.key

# Check the returned value is as expected.
Expand All @@ -585,7 +586,11 @@ def test_get_multi_hit(self):

read_options = datastore_pb2.ReadOptions()
ds_api.lookup.assert_called_once_with(
self.PROJECT, [key.to_protobuf()], read_options=read_options
self.PROJECT,
[key.to_protobuf()],
read_options=read_options,
retry=retry,
timeout=timeout,
)

def test_get_multi_hit_w_transaction(self):
Expand Down

0 comments on commit 8dfbb2d

Please sign in to comment.