diff --git a/google/api_core/future/polling.py b/google/api_core/future/polling.py index 6b4c687d..6466838f 100644 --- a/google/api_core/future/polling.py +++ b/google/api_core/future/polling.py @@ -78,16 +78,18 @@ def done(self, retry=DEFAULT_RETRY): # pylint: disable=redundant-returns-doc, missing-raises-doc raise NotImplementedError() - def _done_or_raise(self): + def _done_or_raise(self, retry=DEFAULT_RETRY): """Check if the future is done and raise if it's not.""" - if not self.done(): + kwargs = {} if retry is DEFAULT_RETRY else {"retry": retry} + + if not self.done(**kwargs): raise _OperationNotComplete() def running(self): """True if the operation is currently running.""" return not self.done() - def _blocking_poll(self, timeout=None): + def _blocking_poll(self, timeout=None, retry=DEFAULT_RETRY): """Poll and wait for the Future to be resolved. Args: @@ -101,13 +103,14 @@ def _blocking_poll(self, timeout=None): retry_ = self._retry.with_deadline(timeout) try: - retry_(self._done_or_raise)() + kwargs = {} if retry is DEFAULT_RETRY else {"retry": retry} + retry_(self._done_or_raise)(**kwargs) except exceptions.RetryError: raise concurrent.futures.TimeoutError( "Operation did not complete within the designated " "timeout." ) - def result(self, timeout=None): + def result(self, timeout=None, retry=DEFAULT_RETRY): """Get the result of the operation, blocking if necessary. Args: @@ -122,7 +125,8 @@ def result(self, timeout=None): google.api_core.GoogleAPICallError: If the operation errors or if the timeout is reached before the operation completes. """ - self._blocking_poll(timeout=timeout) + kwargs = {} if retry is DEFAULT_RETRY else {"retry": retry} + self._blocking_poll(timeout=timeout, **kwargs) if self._exception is not None: # pylint: disable=raising-bad-type diff --git a/tests/unit/future/test_polling.py b/tests/unit/future/test_polling.py index c67de064..2381d036 100644 --- a/tests/unit/future/test_polling.py +++ b/tests/unit/future/test_polling.py @@ -19,7 +19,7 @@ import mock import pytest -from google.api_core import exceptions +from google.api_core import exceptions, retry from google.api_core.future import polling @@ -43,6 +43,8 @@ def test_polling_future_constructor(): assert not future.cancelled() assert future.running() assert future.cancel() + with mock.patch.object(future, "done", return_value=True): + future.result() def test_set_result(): @@ -87,7 +89,7 @@ def __init__(self): self.poll_count = 0 self.event = threading.Event() - def done(self): + def done(self, retry=polling.DEFAULT_RETRY): self.poll_count += 1 self.event.wait() self.set_result(42) @@ -108,7 +110,7 @@ def test_result_with_polling(): class PollingFutureImplTimeout(PollingFutureImplWithPoll): - def done(self): + def done(self, retry=polling.DEFAULT_RETRY): time.sleep(1) return False @@ -130,7 +132,7 @@ def __init__(self, errors): super(PollingFutureImplTransient, self).__init__() self._errors = errors - def done(self): + def done(self, retry=polling.DEFAULT_RETRY): if self._errors: error, self._errors = self._errors[0], self._errors[1:] raise error("testing") @@ -192,3 +194,49 @@ def test_double_callback_background_thread(): assert future.poll_count == 1 callback.assert_called_once_with(future) callback2.assert_called_once_with(future) + + +class PollingFutureImplWithoutRetry(PollingFutureImpl): + def done(self): + return True + + def result(self): + return super(PollingFutureImplWithoutRetry, self).result() + + def _blocking_poll(self, timeout): + return super(PollingFutureImplWithoutRetry, self)._blocking_poll( + timeout=timeout + ) + + +class PollingFutureImplWith_done_or_raise(PollingFutureImpl): + def done(self): + return True + + def _done_or_raise(self): + return super(PollingFutureImplWith_done_or_raise, self)._done_or_raise() + + +def test_polling_future_without_retry(): + custom_retry = retry.Retry( + predicate=retry.if_exception_type(exceptions.TooManyRequests) + ) + future = PollingFutureImplWithoutRetry() + assert future.done() + assert future.running() + assert future.result() is None + + with mock.patch.object(future, "done") as done_mock: + future._done_or_raise() + done_mock.assert_called_once_with() + + with mock.patch.object(future, "done") as done_mock: + future._done_or_raise(retry=custom_retry) + done_mock.assert_called_once_with(retry=custom_retry) + + +def test_polling_future_with__done_or_raise(): + future = PollingFutureImplWith_done_or_raise() + assert future.done() + assert future.running() + assert future.result() is None