diff --git a/google/api_core/grpc_helpers.py b/google/api_core/grpc_helpers.py index 4d63beb3..c47b09fd 100644 --- a/google/api_core/grpc_helpers.py +++ b/google/api_core/grpc_helpers.py @@ -65,6 +65,19 @@ class _StreamingResponseIterator(grpc.Call): def __init__(self, wrapped): self._wrapped = wrapped + # This iterator is used in a retry context, and returned outside after init. + # gRPC will not throw an exception until the stream is consumed, so we need + # to retrieve the first result, in order to fail, in order to trigger a retry. + try: + self._stored_first_result = six.next(self._wrapped) + except TypeError: + # It is possible the wrapped method isn't an iterable (a grpc.Call + # for instance). If this happens don't store the first result. + pass + except StopIteration: + # ignore stop iteration at this time. This should be handled outside of retry. + pass + def __iter__(self): """This iterator is also an iterable that returns itself.""" return self @@ -76,8 +89,13 @@ def next(self): protobuf.Message: A single response from the stream. """ try: + if hasattr(self, "_stored_first_result"): + result = self._stored_first_result + del self._stored_first_result + return result return six.next(self._wrapped) except grpc.RpcError as exc: + # If the stream has already returned data, we cannot recover here. six.raise_from(exceptions.from_grpc_error(exc), exc) # Alias needed for Python 2/3 support. diff --git a/tests/unit/test_grpc_helpers.py b/tests/unit/test_grpc_helpers.py index c37c3eed..1fec64f7 100644 --- a/tests/unit/test_grpc_helpers.py +++ b/tests/unit/test_grpc_helpers.py @@ -129,24 +129,55 @@ def test_wrap_stream_errors_invocation(): assert exc_info.value.response == grpc_error +def test_wrap_stream_empty_iterator(): + expected_responses = [] + callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses)) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + + got_iterator = wrapped_callable() + + responses = list(got_iterator) + + callable_.assert_called_once_with() + assert responses == expected_responses + + class RpcResponseIteratorImpl(object): - def __init__(self, exception): - self._exception = exception + def __init__(self, iterable): + self._iterable = iter(iterable) def next(self): - raise self._exception + next_item = next(self._iterable) + if isinstance(next_item, RpcErrorImpl): + raise next_item + return next_item __next__ = next -def test_wrap_stream_errors_iterator(): +def test_wrap_stream_errors_iterator_initialization(): grpc_error = RpcErrorImpl(grpc.StatusCode.UNAVAILABLE) - response_iter = RpcResponseIteratorImpl(grpc_error) + response_iter = RpcResponseIteratorImpl([grpc_error]) callable_ = mock.Mock(spec=["__call__"], return_value=response_iter) wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) + with pytest.raises(exceptions.ServiceUnavailable) as exc_info: + wrapped_callable(1, 2, three="four") + + callable_.assert_called_once_with(1, 2, three="four") + assert exc_info.value.response == grpc_error + + +def test_wrap_stream_errors_during_iteration(): + grpc_error = RpcErrorImpl(grpc.StatusCode.UNAVAILABLE) + response_iter = RpcResponseIteratorImpl([1, grpc_error]) + callable_ = mock.Mock(spec=["__call__"], return_value=response_iter) + + wrapped_callable = grpc_helpers._wrap_stream_errors(callable_) got_iterator = wrapped_callable(1, 2, three="four") + next(got_iterator) with pytest.raises(exceptions.ServiceUnavailable) as exc_info: next(got_iterator)