diff --git a/docs/futures.rst b/docs/futures.rst index 7a43da9d..d0dadac5 100644 --- a/docs/futures.rst +++ b/docs/futures.rst @@ -7,4 +7,8 @@ Futures .. automodule:: google.api_core.future.polling :members: - :show-inheritance: \ No newline at end of file + :show-inheritance: + +.. automodule:: google.api_core.future.async_future + :members: + :show-inheritance: diff --git a/docs/operation.rst b/docs/operation.rst index c5e67662..492cf67e 100644 --- a/docs/operation.rst +++ b/docs/operation.rst @@ -4,3 +4,10 @@ Long-Running Operations .. automodule:: google.api_core.operation :members: :show-inheritance: + +Long-Running Operations in AsyncIO +------------------------------------- + +.. automodule:: google.api_core.operation_async + :members: + :show-inheritance: diff --git a/docs/page_iterator.rst b/docs/page_iterator.rst index 28842da2..3652e6d5 100644 --- a/docs/page_iterator.rst +++ b/docs/page_iterator.rst @@ -4,3 +4,10 @@ Page Iterators .. automodule:: google.api_core.page_iterator :members: :show-inheritance: + +Page Iterators in AsyncIO +------------------------- + +.. automodule:: google.api_core.page_iterator_async + :members: + :show-inheritance: diff --git a/google/api_core/future/async_future.py b/google/api_core/future/async_future.py new file mode 100644 index 00000000..e1d158d0 --- /dev/null +++ b/google/api_core/future/async_future.py @@ -0,0 +1,157 @@ +# Copyright 2020, Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AsyncIO implementation of the abstract base Future class.""" + +import asyncio + +from google.api_core import exceptions +from google.api_core import retry +from google.api_core import retry_async +from google.api_core.future import base + + +class _OperationNotComplete(Exception): + """Private exception used for polling via retry.""" + pass + + +RETRY_PREDICATE = retry.if_exception_type( + _OperationNotComplete, + exceptions.TooManyRequests, + exceptions.InternalServerError, + exceptions.BadGateway, +) +DEFAULT_RETRY = retry_async.AsyncRetry(predicate=RETRY_PREDICATE) + + +class AsyncFuture(base.Future): + """A Future that polls peer service to self-update. + + The :meth:`done` method should be implemented by subclasses. The polling + behavior will repeatedly call ``done`` until it returns True. + + .. note: Privacy here is intended to prevent the final class from + overexposing, not to prevent subclasses from accessing methods. + + Args: + retry (google.api_core.retry.Retry): The retry configuration used + when polling. This can be used to control how often :meth:`done` + is polled. Regardless of the retry's ``deadline``, it will be + overridden by the ``timeout`` argument to :meth:`result`. + """ + + def __init__(self, retry=DEFAULT_RETRY): + super().__init__() + self._retry = retry + self._future = asyncio.get_event_loop().create_future() + self._background_task = None + + async def done(self, retry=DEFAULT_RETRY): + """Checks to see if the operation is complete. 
+ + Args: + retry (google.api_core.retry.Retry): (Optional) How to retry the RPC. + + Returns: + bool: True if the operation is complete, False otherwise. + """ + # pylint: disable=redundant-returns-doc, missing-raises-doc + raise NotImplementedError() + + async def _done_or_raise(self): + """Check if the future is done and raise if it's not.""" + result = await self.done() + if not result: + raise _OperationNotComplete() + + async def running(self): + """True if the operation is currently running.""" + result = await self.done() + return not result + + async def _blocking_poll(self, timeout=None): + """Poll and await for the Future to be resolved. + + Args: + timeout (int): + How long (in seconds) to wait for the operation to complete. + If None, wait indefinitely. + """ + if self._future.done(): + return + + retry_ = self._retry.with_deadline(timeout) + + try: + await retry_(self._done_or_raise)() + except exceptions.RetryError: + raise asyncio.TimeoutError( + "Operation did not complete within the designated " "timeout." + ) + + async def result(self, timeout=None): + """Get the result of the operation. + + Args: + timeout (int): + How long (in seconds) to wait for the operation to complete. + If None, wait indefinitely. + + Returns: + google.protobuf.Message: The Operation's result. + + Raises: + google.api_core.GoogleAPICallError: If the operation errors or if + the timeout is reached before the operation completes. + """ + await self._blocking_poll(timeout=timeout) + return self._future.result() + + async def exception(self, timeout=None): + """Get the exception from the operation. + + Args: + timeout (int): How long to wait for the operation to complete. + If None, wait indefinitely. + + Returns: + Optional[google.api_core.GoogleAPICallError]: The operation's + error. + """ + await self._blocking_poll(timeout=timeout) + return self._future.exception() + + def add_done_callback(self, fn): + """Add a callback to be executed when the operation is complete. + + If the operation is completed, the callback will be scheduled onto the + event loop. Otherwise, the callback will be stored and invoked when the + future is done. + + Args: + fn (Callable[Future]): The callback to execute when the operation + is complete. + """ + if self._background_task is None: + self._background_task = asyncio.get_event_loop().create_task(self._blocking_poll()) + self._future.add_done_callback(fn) + + def set_result(self, result): + """Set the Future's result.""" + self._future.set_result(result) + + def set_exception(self, exception): + """Set the Future's exception.""" + self._future.set_exception(exception) diff --git a/google/api_core/operation_async.py b/google/api_core/operation_async.py new file mode 100644 index 00000000..89500af1 --- /dev/null +++ b/google/api_core/operation_async.py @@ -0,0 +1,215 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AsyncIO futures for long-running operations returned from Google Cloud APIs. 
+
+These futures can be used to await the result of a long-running operation
+using :meth:`AsyncOperation.result`:
+
+
+.. code-block:: python
+
+    operation = my_api_client.long_running_method()
+    result = await operation.result()
+
+Or asynchronously using callbacks and :meth:`AsyncOperation.add_done_callback`:
+
+.. code-block:: python
+
+    operation = my_api_client.long_running_method()
+
+    def my_callback(future):
+        result = future.result()
+
+    operation.add_done_callback(my_callback)
+
+"""
+
+import functools
+import threading
+
+from google.api_core import exceptions
+from google.api_core import protobuf_helpers
+from google.api_core.future import async_future
+from google.longrunning import operations_pb2
+from google.rpc import code_pb2
+
+
+class AsyncOperation(async_future.AsyncFuture):
+    """A Future for interacting with a Google API Long-Running Operation.
+
+    Args:
+        operation (google.longrunning.operations_pb2.Operation): The
+            initial operation.
+        refresh (Callable[[], ~.api_core.operation.Operation]): A callable that
+            returns the latest state of the operation.
+        cancel (Callable[[], None]): A callable that tries to cancel
+            the operation.
+        result_type (:func:`type`): The protobuf type for the operation's
+            result.
+        metadata_type (:func:`type`): The protobuf type for the operation's
+            metadata.
+        retry (google.api_core.retry.Retry): The retry configuration used
+            when polling. This can be used to control how often :meth:`done`
+            is polled. Regardless of the retry's ``deadline``, it will be
+            overridden by the ``timeout`` argument to :meth:`result`.
+    """
+
+    def __init__(
+        self,
+        operation,
+        refresh,
+        cancel,
+        result_type,
+        metadata_type=None,
+        retry=async_future.DEFAULT_RETRY,
+    ):
+        super().__init__(retry=retry)
+        self._operation = operation
+        self._refresh = refresh
+        self._cancel = cancel
+        self._result_type = result_type
+        self._metadata_type = metadata_type
+        self._completion_lock = threading.Lock()
+        # Invoke this in case the operation came back already complete.
+        self._set_result_from_operation()
+
+    @property
+    def operation(self):
+        """google.longrunning.Operation: The current long-running operation."""
+        return self._operation
+
+    @property
+    def metadata(self):
+        """google.protobuf.Message: The current operation metadata."""
+        if not self._operation.HasField("metadata"):
+            return None
+
+        return protobuf_helpers.from_any_pb(
+            self._metadata_type, self._operation.metadata
+        )
+
+    @classmethod
+    def deserialize(cls, payload):
+        """Deserialize a ``google.longrunning.Operation`` protocol buffer.
+
+        Args:
+            payload (bytes): A serialized operation protocol buffer.
+
+        Returns:
+            ~.operations_pb2.Operation: An Operation protobuf object.
+        """
+        return operations_pb2.Operation.FromString(payload)
+
+    def _set_result_from_operation(self):
+        """Set the result or exception from the operation if it is complete."""
+        # This must be done in a lock to prevent the polling task and the
+        # caller's task from both executing the completion logic
+        # at the same time.
+        with self._completion_lock:
+            # If the operation isn't complete or if the result has already been
+            # set, do not call set_result/set_exception again.
+ if not self._operation.done or self._future.done(): + return + + if self._operation.HasField("response"): + response = protobuf_helpers.from_any_pb( + self._result_type, self._operation.response + ) + self.set_result(response) + elif self._operation.HasField("error"): + exception = exceptions.GoogleAPICallError( + self._operation.error.message, + errors=(self._operation.error,), + response=self._operation, + ) + self.set_exception(exception) + else: + exception = exceptions.GoogleAPICallError( + "Unexpected state: Long-running operation had neither " + "response nor error set." + ) + self.set_exception(exception) + + async def _refresh_and_update(self, retry=async_future.DEFAULT_RETRY): + """Refresh the operation and update the result if needed. + + Args: + retry (google.api_core.retry.Retry): (Optional) How to retry the RPC. + """ + # If the currently cached operation is done, no need to make another + # RPC as it will not change once done. + if not self._operation.done: + self._operation = await self._refresh(retry=retry) + self._set_result_from_operation() + + async def done(self, retry=async_future.DEFAULT_RETRY): + """Checks to see if the operation is complete. + + Args: + retry (google.api_core.retry.Retry): (Optional) How to retry the RPC. + + Returns: + bool: True if the operation is complete, False otherwise. + """ + await self._refresh_and_update(retry) + return self._operation.done + + async def cancel(self): + """Attempt to cancel the operation. + + Returns: + bool: True if the cancel RPC was made, False if the operation is + already complete. + """ + result = await self.done() + if result: + return False + else: + await self._cancel() + return True + + async def cancelled(self): + """True if the operation was cancelled.""" + await self._refresh_and_update() + return ( + self._operation.HasField("error") + and self._operation.error.code == code_pb2.CANCELLED + ) + + +def from_gapic(operation, operations_client, result_type, **kwargs): + """Create an operation future from a gapic client. + + This interacts with the long-running operations `service`_ (specific + to a given API) via a gapic client. + + .. _service: https://github.com/googleapis/googleapis/blob/\ + 050400df0fdb16f63b63e9dee53819044bffc857/\ + google/longrunning/operations.proto#L38 + + Args: + operation (google.longrunning.operations_pb2.Operation): The operation. + operations_client (google.api_core.operations_v1.OperationsClient): + The operations client. + result_type (:func:`type`): The protobuf result type. + kwargs: Keyword args passed into the :class:`Operation` constructor. + + Returns: + ~.api_core.operation.Operation: The operation future to track the given + operation. + """ + refresh = functools.partial(operations_client.get_operation, operation.name) + cancel = functools.partial(operations_client.cancel_operation, operation.name) + return AsyncOperation(operation, refresh, cancel, result_type, **kwargs) diff --git a/google/api_core/page_iterator_async.py b/google/api_core/page_iterator_async.py new file mode 100644 index 00000000..a0aa41a7 --- /dev/null +++ b/google/api_core/page_iterator_async.py @@ -0,0 +1,278 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AsyncIO iterators for paging through paged API methods. + +These iterators simplify the process of paging through API responses +where the request takes a page token and the response is a list of results with +a token for the next page. See `list pagination`_ in the Google API Style Guide +for more details. + +.. _list pagination: + https://cloud.google.com/apis/design/design_patterns#list_pagination + +API clients that have methods that follow the list pagination pattern can +return an :class:`.AsyncIterator`: + + >>> results_iterator = await client.list_resources() + +Or you can walk your way through items and call off the search early if +you find what you're looking for (resulting in possibly fewer requests):: + + >>> async for resource in results_iterator: + ... print(resource.name) + ... if not resource.is_valid: + ... break + +At any point, you may check the number of items consumed by referencing the +``num_results`` property of the iterator:: + + >>> async for my_item in results_iterator: + ... if results_iterator.num_results >= 10: + ... break + +When iterating, not every new item will send a request to the server. +To iterate based on each page of items (where a page corresponds to +a request):: + + >>> async for page in results_iterator.pages: + ... print('=' * 20) + ... print(' Page number: {:d}'.format(iterator.page_number)) + ... print(' Items in page: {:d}'.format(page.num_items)) + ... print(' First item: {!r}'.format(next(page))) + ... print('Items remaining: {:d}'.format(page.remaining)) + ... print('Next page token: {}'.format(iterator.next_page_token)) + ==================== + Page number: 1 + Items in page: 1 + First item: + Items remaining: 0 + Next page token: eav1OzQB0OM8rLdGXOEsyQWSG + ==================== + Page number: 2 + Items in page: 19 + First item: + Items remaining: 18 + Next page token: None +""" + +import abc + +from google.api_core.page_iterator import Page + + +def _item_to_value_identity(iterator, item): + """An item to value transformer that returns the item un-changed.""" + # pylint: disable=unused-argument + # We are conforming to the interface defined by Iterator. + return item + + +class AsyncIterator(abc.ABC): + """A generic class for iterating through API list responses. + + Args: + client(google.cloud.client.Client): The API client. + item_to_value (Callable[google.api_core.page_iterator_async.AsyncIterator, Any]): + Callable to convert an item from the type in the raw API response + into the native object. Will be called with the iterator and a + single item. + page_token (str): A token identifying a page in a result set to start + fetching results from. + max_results (int): The maximum number of results to fetch. + """ + + def __init__( + self, + client, + item_to_value=_item_to_value_identity, + page_token=None, + max_results=None, + ): + self._started = False + self.client = client + """Optional[Any]: The client that created this iterator.""" + self.item_to_value = item_to_value + """Callable[Iterator, Any]: Callable to convert an item from the type + in the raw API response into the native object. 
Will be called with
+        the iterator and a single item.
+        """
+        self.max_results = max_results
+        """int: The maximum number of results to fetch."""
+
+        # The attributes below will change over the life of the iterator.
+        self.page_number = 0
+        """int: The current page of results."""
+        self.next_page_token = page_token
+        """str: The token for the next page of results. If this is set before
+        the iterator starts, it effectively offsets the iterator to a
+        specific starting point."""
+        self.num_results = 0
+        """int: The total number of results fetched so far."""
+
+    @property
+    def pages(self):
+        """Iterator of pages in the response.
+
+        Returns:
+            types.AsyncGeneratorType[google.api_core.page_iterator.Page]: A
+                generator of page instances.
+
+        Raises:
+            ValueError: If the iterator has already been started.
+        """
+        if self._started:
+            raise ValueError("Iterator has already started", self)
+        self._started = True
+        return self._page_aiter(increment=True)
+
+    async def _items_aiter(self):
+        """Iterator for each item returned."""
+        async for page in self._page_aiter(increment=False):
+            for item in page:
+                self.num_results += 1
+                yield item
+
+    def __aiter__(self):
+        """Iterator for each item returned.
+
+        Returns:
+            types.AsyncGeneratorType[Any]: A generator of items from the API.
+
+        Raises:
+            ValueError: If the iterator has already been started.
+        """
+        if self._started:
+            raise ValueError("Iterator has already started", self)
+        self._started = True
+        return self._items_aiter()
+
+    async def _page_aiter(self, increment):
+        """Generator of pages of API responses.
+
+        Args:
+            increment (bool): Flag indicating if the total number of results
+                should be incremented on each page. This is useful since a page
+                iterator will want to increment by results per page while an
+                items iterator will want to increment per item.
+
+        Yields:
+            Page: Each page of items from the API.
+        """
+        page = await self._next_page()
+        while page is not None:
+            self.page_number += 1
+            if increment:
+                self.num_results += page.num_items
+            yield page
+            page = await self._next_page()
+
+    @abc.abstractmethod
+    async def _next_page(self):
+        """Get the next page in the iterator.
+
+        This is intended to be overridden by subclasses to return the next
+        :class:`Page`.
+
+        Raises:
+            NotImplementedError: Always, this method is abstract.
+        """
+        raise NotImplementedError
+
+
+class AsyncGRPCIterator(AsyncIterator):
+    """A generic class for iterating through gRPC list responses.
+
+    .. note:: The class does not take a ``page_token`` argument because it can
+        just be specified in the ``request``.
+
+    Args:
+        client (google.cloud.client.Client): The API client. This is unused by
+            this class, but kept to satisfy the :class:`AsyncIterator` interface.
+        method (Callable[protobuf.Message]): A bound gRPC method that should
+            take a single message for the request.
+        request (protobuf.Message): The request message.
+        items_field (str): The field in the response message that has the
+            items for the page.
+        item_to_value (Callable[AsyncGRPCIterator, Any]): Callable to convert an
+            item from the type in the raw API response into a native object. Will
+            be called with the iterator and a single item.
+        request_token_field (str): The field in the request message used to
+            specify the page token.
+        response_token_field (str): The field in the response message that has
+            the token for the next page.
+        max_results (int): The maximum number of results to fetch.
+
+    ..
autoattribute:: pages + """ + + _DEFAULT_REQUEST_TOKEN_FIELD = "page_token" + _DEFAULT_RESPONSE_TOKEN_FIELD = "next_page_token" + + def __init__( + self, + client, + method, + request, + items_field, + item_to_value=_item_to_value_identity, + request_token_field=_DEFAULT_REQUEST_TOKEN_FIELD, + response_token_field=_DEFAULT_RESPONSE_TOKEN_FIELD, + max_results=None, + ): + super().__init__(client, item_to_value, max_results=max_results) + self._method = method + self._request = request + self._items_field = items_field + self._request_token_field = request_token_field + self._response_token_field = response_token_field + + async def _next_page(self): + """Get the next page in the iterator. + + Returns: + Page: The next page in the iterator or :data:`None` if + there are no pages left. + """ + if not self._has_next_page(): + return None + + if self.next_page_token is not None: + setattr(self._request, self._request_token_field, self.next_page_token) + + response = await self._method(self._request) + + self.next_page_token = getattr(response, self._response_token_field) + items = getattr(response, self._items_field) + page = Page(self, items, self.item_to_value, raw_page=response) + + return page + + def _has_next_page(self): + """Determines whether or not there are more pages with results. + + Returns: + bool: Whether the iterator has more pages. + """ + if self.page_number == 0: + return True + + # Note: intentionally a falsy check instead of a None check. The RPC + # can return an empty string indicating no more pages. + if self.max_results is not None: + if self.num_results >= self.max_results: + return False + + return True if self.next_page_token else False diff --git a/tests/asyncio/future/__init__.py b/tests/asyncio/future/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/asyncio/future/test_async_future.py b/tests/asyncio/future/test_async_future.py new file mode 100644 index 00000000..3322cb05 --- /dev/null +++ b/tests/asyncio/future/test_async_future.py @@ -0,0 +1,229 @@ +# Copyright 2017, Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio + +import mock +import pytest + +from google.api_core import exceptions +from google.api_core.future import async_future + + +class AsyncFuture(async_future.AsyncFuture): + async def done(self): + return False + + async def cancel(self): + return True + + async def cancelled(self): + return False + + async def running(self): + return True + + +@pytest.mark.asyncio +async def test_polling_future_constructor(): + future = AsyncFuture() + assert not await future.done() + assert not await future.cancelled() + assert await future.running() + assert await future.cancel() + + +@pytest.mark.asyncio +async def test_set_result(): + future = AsyncFuture() + callback = mock.Mock() + + future.set_result(1) + + assert await future.result() == 1 + callback_called = asyncio.Event() + + def callback(unused_future): + callback_called.set() + + future.add_done_callback(callback) + await callback_called.wait() + + +@pytest.mark.asyncio +async def test_set_exception(): + future = AsyncFuture() + exception = ValueError("meep") + + future.set_exception(exception) + + assert await future.exception() == exception + with pytest.raises(ValueError): + await future.result() + + callback_called = asyncio.Event() + + def callback(unused_future): + callback_called.set() + + future.add_done_callback(callback) + await callback_called.wait() + + +@pytest.mark.asyncio +async def test_invoke_callback_exception(): + future = AsyncFuture() + future.set_result(42) + + # This should not raise, despite the callback causing an exception. + callback_called = asyncio.Event() + + def callback(unused_future): + callback_called.set() + raise ValueError() + + future.add_done_callback(callback) + await callback_called.wait() + + +class AsyncFutureWithPoll(AsyncFuture): + def __init__(self): + super().__init__() + self.poll_count = 0 + self.event = asyncio.Event() + + async def done(self): + self.poll_count += 1 + await self.event.wait() + self.set_result(42) + return True + + +@pytest.mark.asyncio +async def test_result_with_polling(): + future = AsyncFutureWithPoll() + + future.event.set() + result = await future.result() + + assert result == 42 + assert future.poll_count == 1 + # Repeated calls should not cause additional polling + assert await future.result() == result + assert future.poll_count == 1 + + +class AsyncFutureTimeout(AsyncFutureWithPoll): + + async def done(self): + await asyncio.sleep(0.2) + return False + + +@pytest.mark.asyncio +async def test_result_timeout(): + future = AsyncFutureTimeout() + with pytest.raises(asyncio.TimeoutError): + await future.result(timeout=0.2) + + +@pytest.mark.asyncio +async def test_exception_timeout(): + future = AsyncFutureTimeout() + with pytest.raises(asyncio.TimeoutError): + await future.exception(timeout=0.2) + + +@pytest.mark.asyncio +async def test_result_timeout_with_retry(): + future = AsyncFutureTimeout() + with pytest.raises(asyncio.TimeoutError): + await future.exception(timeout=0.4) + + +class AsyncFutureTransient(AsyncFutureWithPoll): + def __init__(self, errors): + super().__init__() + self._errors = errors + + async def done(self): + if self._errors: + error, self._errors = self._errors[0], self._errors[1:] + raise error("testing") + self.poll_count += 1 + self.set_result(42) + return True + + +@mock.patch("asyncio.sleep", autospec=True) +@pytest.mark.asyncio +async def test_result_transient_error(unused_sleep): + future = AsyncFutureTransient( + ( + exceptions.TooManyRequests, + exceptions.InternalServerError, + exceptions.BadGateway, + ) + ) + result = 
await future.result() + assert result == 42 + assert future.poll_count == 1 + # Repeated calls should not cause additional polling + assert await future.result() == result + assert future.poll_count == 1 + + +@pytest.mark.asyncio +async def test_callback_concurrency(): + future = AsyncFutureWithPoll() + + callback_called = asyncio.Event() + + def callback(unused_future): + callback_called.set() + + future.add_done_callback(callback) + + # Give the thread a second to poll + await asyncio.sleep(1) + assert future.poll_count == 1 + + future.event.set() + await callback_called.wait() + + +@pytest.mark.asyncio +async def test_double_callback_concurrency(): + future = AsyncFutureWithPoll() + + callback_called = asyncio.Event() + + def callback(unused_future): + callback_called.set() + + callback_called2 = asyncio.Event() + + def callback2(unused_future): + callback_called2.set() + + future.add_done_callback(callback) + future.add_done_callback(callback2) + + # Give the thread a second to poll + await asyncio.sleep(1) + future.event.set() + + assert future.poll_count == 1 + await callback_called.wait() + await callback_called2.wait() diff --git a/tests/asyncio/test_operation_async.py b/tests/asyncio/test_operation_async.py new file mode 100644 index 00000000..419749f3 --- /dev/null +++ b/tests/asyncio/test_operation_async.py @@ -0,0 +1,193 @@ +# Copyright 2017, Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import mock +import pytest + +from google.api_core import exceptions +from google.api_core import operation_async +from google.api_core import operations_v1 +from google.api_core import retry_async +from google.longrunning import operations_pb2 +from google.protobuf import struct_pb2 +from google.rpc import code_pb2 +from google.rpc import status_pb2 + +TEST_OPERATION_NAME = "test/operation" + + +def make_operation_proto( + name=TEST_OPERATION_NAME, metadata=None, response=None, error=None, **kwargs +): + operation_proto = operations_pb2.Operation(name=name, **kwargs) + + if metadata is not None: + operation_proto.metadata.Pack(metadata) + + if response is not None: + operation_proto.response.Pack(response) + + if error is not None: + operation_proto.error.CopyFrom(error) + + return operation_proto + + +def make_operation_future(client_operations_responses=None): + if client_operations_responses is None: + client_operations_responses = [make_operation_proto()] + + refresh = mock.AsyncMock(spec=["__call__"], side_effect=client_operations_responses) + refresh.responses = client_operations_responses + cancel = mock.AsyncMock(spec=["__call__"]) + operation_future = operation_async.AsyncOperation( + client_operations_responses[0], + refresh, + cancel, + result_type=struct_pb2.Struct, + metadata_type=struct_pb2.Struct, + ) + + return operation_future, refresh, cancel + + +@pytest.mark.asyncio +async def test_constructor(): + future, refresh, _ = make_operation_future() + + assert future.operation == refresh.responses[0] + assert future.operation.done is False + assert future.operation.name == TEST_OPERATION_NAME + assert future.metadata is None + assert await future.running() + + +def test_metadata(): + expected_metadata = struct_pb2.Struct() + future, _, _ = make_operation_future( + [make_operation_proto(metadata=expected_metadata)] + ) + + assert future.metadata == expected_metadata + + +@pytest.mark.asyncio +async def test_cancellation(): + responses = [ + make_operation_proto(), + # Second response indicates that the operation was cancelled. + make_operation_proto( + done=True, error=status_pb2.Status(code=code_pb2.CANCELLED) + ), + ] + future, _, cancel = make_operation_future(responses) + + assert await future.cancel() + assert await future.cancelled() + cancel.assert_called_once_with() + + # Cancelling twice should have no effect. + assert not await future.cancel() + cancel.assert_called_once_with() + + +@pytest.mark.asyncio +async def test_result(): + expected_result = struct_pb2.Struct() + responses = [ + make_operation_proto(), + # Second operation response includes the result. + make_operation_proto(done=True, response=expected_result), + ] + future, _, _ = make_operation_future(responses) + + result = await future.result() + + assert result == expected_result + assert await future.done() + + +@pytest.mark.asyncio +async def test_done_w_retry(): + RETRY_PREDICATE = retry_async.if_exception_type(exceptions.TooManyRequests) + test_retry = retry_async.AsyncRetry(predicate=RETRY_PREDICATE) + + expected_result = struct_pb2.Struct() + responses = [ + make_operation_proto(), + # Second operation response includes the result. 
+        make_operation_proto(done=True, response=expected_result),
+    ]
+    future, refresh, _ = make_operation_future(responses)
+
+    await future.done(retry=test_retry)
+    refresh.assert_called_once_with(retry=test_retry)
+
+
+@pytest.mark.asyncio
+async def test_exception():
+    expected_exception = status_pb2.Status(message="meep")
+    responses = [
+        make_operation_proto(),
+        # Second operation response includes the error.
+        make_operation_proto(done=True, error=expected_exception),
+    ]
+    future, _, _ = make_operation_future(responses)
+
+    exception = await future.exception()
+
+    assert expected_exception.message in "{!r}".format(exception)
+
+
+@mock.patch("asyncio.sleep", autospec=True)
+@pytest.mark.asyncio
+async def test_unexpected_result(unused_sleep):
+    responses = [
+        make_operation_proto(),
+        # Second operation response is done, but has neither error nor response set.
+        make_operation_proto(done=True),
+    ]
+    future, _, _ = make_operation_future(responses)
+
+    exception = await future.exception()
+
+    assert "Unexpected state" in "{!r}".format(exception)
+
+
+def test_from_gapic():
+    operation_proto = make_operation_proto(done=True)
+    operations_client = mock.create_autospec(
+        operations_v1.OperationsClient, instance=True
+    )
+
+    future = operation_async.from_gapic(
+        operation_proto,
+        operations_client,
+        struct_pb2.Struct,
+        metadata_type=struct_pb2.Struct,
+    )
+
+    assert future._result_type == struct_pb2.Struct
+    assert future._metadata_type == struct_pb2.Struct
+    assert future.operation.name == TEST_OPERATION_NAME
+    assert future.operation.done
+
+
+def test_deserialize():
+    op = make_operation_proto(name="foobarbaz")
+    serialized = op.SerializeToString()
+    deserialized_op = operation_async.AsyncOperation.deserialize(serialized)
+    assert op.name == deserialized_op.name
+    assert type(op) is type(deserialized_op)
diff --git a/tests/asyncio/test_page_iterator_async.py b/tests/asyncio/test_page_iterator_async.py
new file mode 100644
index 00000000..42fac2a2
--- /dev/null
+++ b/tests/asyncio/test_page_iterator_async.py
@@ -0,0 +1,261 @@
+# Copyright 2015 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+
+import mock
+import pytest
+
+from google.api_core import page_iterator_async
+
+
+class PageAsyncIteratorImpl(page_iterator_async.AsyncIterator):
+
+    async def _next_page(self):
+        return mock.create_autospec(page_iterator_async.Page, instance=True)
+
+
+class TestAsyncIterator:
+
+    def test_constructor(self):
+        client = mock.sentinel.client
+        item_to_value = mock.sentinel.item_to_value
+        token = "ab13nceor03"
+        max_results = 1337
+
+        iterator = PageAsyncIteratorImpl(
+            client, item_to_value, page_token=token, max_results=max_results
+        )
+
+        assert not iterator._started
+        assert iterator.client is client
+        assert iterator.item_to_value == item_to_value
+        assert iterator.max_results == max_results
+        # Changing attributes.
+ assert iterator.page_number == 0 + assert iterator.next_page_token == token + assert iterator.num_results == 0 + + def test_pages_property_starts(self): + iterator = PageAsyncIteratorImpl(None, None) + + assert not iterator._started + + assert inspect.isasyncgen(iterator.pages) + + assert iterator._started + + def test_pages_property_restart(self): + iterator = PageAsyncIteratorImpl(None, None) + + assert iterator.pages + + # Make sure we cannot restart. + with pytest.raises(ValueError): + assert iterator.pages + + @pytest.mark.asyncio + async def test__page_aiter_increment(self): + iterator = PageAsyncIteratorImpl(None, None) + page = page_iterator_async.Page( + iterator, ("item",), page_iterator_async._item_to_value_identity) + iterator._next_page = mock.AsyncMock(side_effect=[page, None]) + + assert iterator.num_results == 0 + + page_aiter = iterator._page_aiter(increment=True) + await page_aiter.__anext__() + + assert iterator.num_results == 1 + + @pytest.mark.asyncio + async def test__page_aiter_no_increment(self): + iterator = PageAsyncIteratorImpl(None, None) + + assert iterator.num_results == 0 + + page_aiter = iterator._page_aiter(increment=False) + await page_aiter.__anext__() + + # results should still be 0 after fetching a page. + assert iterator.num_results == 0 + + @pytest.mark.asyncio + async def test__items_aiter(self): + # Items to be returned. + item1 = 17 + item2 = 100 + item3 = 211 + + # Make pages from mock responses + parent = mock.sentinel.parent + page1 = page_iterator_async.Page( + parent, (item1, item2), page_iterator_async._item_to_value_identity) + page2 = page_iterator_async.Page( + parent, (item3,), page_iterator_async._item_to_value_identity) + + iterator = PageAsyncIteratorImpl(None, None) + iterator._next_page = mock.AsyncMock(side_effect=[page1, page2, None]) + + items_aiter = iterator._items_aiter() + + assert inspect.isasyncgen(items_aiter) + + # Consume items and check the state of the iterator. + assert iterator.num_results == 0 + assert await items_aiter.__anext__() == item1 + assert iterator.num_results == 1 + + assert await items_aiter.__anext__() == item2 + assert iterator.num_results == 2 + + assert await items_aiter.__anext__() == item3 + assert iterator.num_results == 3 + + with pytest.raises(StopAsyncIteration): + await items_aiter.__anext__() + + @pytest.mark.asyncio + async def test___aiter__(self): + async_iterator = PageAsyncIteratorImpl(None, None) + async_iterator._next_page = mock.AsyncMock(side_effect=[(1, 2), (3,), None]) + + assert not async_iterator._started + + result = [] + async for item in async_iterator: + result.append(item) + + assert result == [1, 2, 3] + assert async_iterator._started + + def test___aiter__restart(self): + iterator = PageAsyncIteratorImpl(None, None) + + iterator.__aiter__() + + # Make sure we cannot restart. 
+ with pytest.raises(ValueError): + iterator.__aiter__() + + def test___aiter___restart_after_page(self): + iterator = PageAsyncIteratorImpl(None, None) + + assert iterator.pages + + # Make sure we cannot restart after starting the page iterator + with pytest.raises(ValueError): + iterator.__aiter__() + + +class TestAsyncGRPCIterator(object): + + def test_constructor(self): + client = mock.sentinel.client + items_field = "items" + iterator = page_iterator_async.AsyncGRPCIterator( + client, mock.sentinel.method, mock.sentinel.request, items_field + ) + + assert not iterator._started + assert iterator.client is client + assert iterator.max_results is None + assert iterator.item_to_value is page_iterator_async._item_to_value_identity + assert iterator._method == mock.sentinel.method + assert iterator._request == mock.sentinel.request + assert iterator._items_field == items_field + assert ( + iterator._request_token_field + == page_iterator_async.AsyncGRPCIterator._DEFAULT_REQUEST_TOKEN_FIELD + ) + assert ( + iterator._response_token_field + == page_iterator_async.AsyncGRPCIterator._DEFAULT_RESPONSE_TOKEN_FIELD + ) + # Changing attributes. + assert iterator.page_number == 0 + assert iterator.next_page_token is None + assert iterator.num_results == 0 + + def test_constructor_options(self): + client = mock.sentinel.client + items_field = "items" + request_field = "request" + response_field = "response" + iterator = page_iterator_async.AsyncGRPCIterator( + client, + mock.sentinel.method, + mock.sentinel.request, + items_field, + item_to_value=mock.sentinel.item_to_value, + request_token_field=request_field, + response_token_field=response_field, + max_results=42, + ) + + assert iterator.client is client + assert iterator.max_results == 42 + assert iterator.item_to_value is mock.sentinel.item_to_value + assert iterator._method == mock.sentinel.method + assert iterator._request == mock.sentinel.request + assert iterator._items_field == items_field + assert iterator._request_token_field == request_field + assert iterator._response_token_field == response_field + + @pytest.mark.asyncio + async def test_iterate(self): + request = mock.Mock(spec=["page_token"], page_token=None) + response1 = mock.Mock(items=["a", "b"], next_page_token="1") + response2 = mock.Mock(items=["c"], next_page_token="2") + response3 = mock.Mock(items=["d"], next_page_token="") + method = mock.AsyncMock(side_effect=[response1, response2, response3]) + iterator = page_iterator_async.AsyncGRPCIterator( + mock.sentinel.client, method, request, "items" + ) + + assert iterator.num_results == 0 + + items = [] + async for item in iterator: + items.append(item) + + assert items == ["a", "b", "c", "d"] + + method.assert_called_with(request) + assert method.call_count == 3 + assert request.page_token == "2" + + @pytest.mark.asyncio + async def test_iterate_with_max_results(self): + request = mock.Mock(spec=["page_token"], page_token=None) + response1 = mock.Mock(items=["a", "b"], next_page_token="1") + response2 = mock.Mock(items=["c"], next_page_token="2") + response3 = mock.Mock(items=["d"], next_page_token="") + method = mock.AsyncMock(side_effect=[response1, response2, response3]) + iterator = page_iterator_async.AsyncGRPCIterator( + mock.sentinel.client, method, request, "items", max_results=3 + ) + + assert iterator.num_results == 0 + + items = [] + async for item in iterator: + items.append(item) + + assert items == ["a", "b", "c"] + assert iterator.num_results == 3 + + method.assert_called_with(request) + assert method.call_count 
== 2 + assert request.page_token == "1"
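
The tests above drive the new classes through mocks. As a quick end-to-end sanity check outside the test suite, the sketch below shows one way ``page_iterator_async.AsyncGRPCIterator`` (added in this change) could be wired to a stand-in paged method; ``fake_list_method``, ``_PAGES``, and the ``SimpleNamespace`` request/response objects are illustrative stand-ins, not part of the library or of this diff.

.. code-block:: python

    import asyncio
    import types

    from google.api_core import page_iterator_async

    # Two pages of fake results, keyed by the incoming page token.
    _PAGES = {
        "": (["a", "b"], "token-1"),
        "token-1": (["c"], ""),  # an empty next_page_token ends iteration
    }

    async def fake_list_method(request):
        # Stand-in for a bound async gRPC list method.
        items, next_token = _PAGES[request.page_token or ""]
        return types.SimpleNamespace(items=items, next_page_token=next_token)

    async def main():
        request = types.SimpleNamespace(page_token=None)
        iterator = page_iterator_async.AsyncGRPCIterator(
            None,               # client: unused by the iterator itself
            fake_list_method,   # method: called once per page
            request,            # request: its page_token field is updated in place
            "items",            # items_field: response attribute holding each page's items
        )
        results = [item async for item in iterator]
        assert results == ["a", "b", "c"]

    asyncio.run(main())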