From 72e3e8b955690b5f180af89a0a15a8870fd556a8 Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Fri, 15 Oct 2021 18:08:11 +0000 Subject: [PATCH] feat: add TPU v2alpha1 (#55) - [ ] Regenerate this pull request now. Committer: @rosbo PiperOrigin-RevId: 403400668 Source-Link: https://github.com/googleapis/googleapis/commit/8f48b9778f9f875ac1931acccb1ff7bada71a372 Source-Link: https://github.com/googleapis/googleapis-gen/commit/f966fd02bd0d248e9adeffe3bb2daa53bf44a252 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiZjk2NmZkMDJiZDBkMjQ4ZTlhZGVmZmUzYmIyZGFhNTNiZjQ0YTI1MiJ9 --- docs/index.rst | 11 + docs/tpu_v2alpha1/services.rst | 6 + docs/tpu_v2alpha1/tpu.rst | 10 + docs/tpu_v2alpha1/types.rst | 7 + google/cloud/tpu_v2alpha1/__init__.py | 90 + google/cloud/tpu_v2alpha1/gapic_metadata.json | 153 + google/cloud/tpu_v2alpha1/py.typed | 2 + .../cloud/tpu_v2alpha1/services/__init__.py | 15 + .../tpu_v2alpha1/services/tpu/__init__.py | 22 + .../tpu_v2alpha1/services/tpu/async_client.py | 1108 +++++ .../cloud/tpu_v2alpha1/services/tpu/client.py | 1345 ++++++ .../cloud/tpu_v2alpha1/services/tpu/pagers.py | 411 ++ .../services/tpu/transports/__init__.py | 33 + .../services/tpu/transports/base.py | 351 ++ .../services/tpu/transports/grpc.py | 597 +++ .../services/tpu/transports/grpc_asyncio.py | 611 +++ google/cloud/tpu_v2alpha1/types/__init__.py | 86 + google/cloud/tpu_v2alpha1/types/cloud_tpu.py | 766 ++++ scripts/fixup_tpu_v2alpha1_keywords.py | 188 + tests/unit/gapic/tpu_v2alpha1/__init__.py | 15 + tests/unit/gapic/tpu_v2alpha1/test_tpu.py | 4026 +++++++++++++++++ 21 files changed, 9853 insertions(+) create mode 100644 docs/tpu_v2alpha1/services.rst create mode 100644 docs/tpu_v2alpha1/tpu.rst create mode 100644 docs/tpu_v2alpha1/types.rst create mode 100644 google/cloud/tpu_v2alpha1/__init__.py create mode 100644 google/cloud/tpu_v2alpha1/gapic_metadata.json create mode 100644 google/cloud/tpu_v2alpha1/py.typed create mode 100644 google/cloud/tpu_v2alpha1/services/__init__.py create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/__init__.py create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/async_client.py create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/client.py create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/pagers.py create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/transports/__init__.py create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/transports/base.py create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/transports/grpc.py create mode 100644 google/cloud/tpu_v2alpha1/services/tpu/transports/grpc_asyncio.py create mode 100644 google/cloud/tpu_v2alpha1/types/__init__.py create mode 100644 google/cloud/tpu_v2alpha1/types/cloud_tpu.py create mode 100644 scripts/fixup_tpu_v2alpha1_keywords.py create mode 100644 tests/unit/gapic/tpu_v2alpha1/__init__.py create mode 100644 tests/unit/gapic/tpu_v2alpha1/test_tpu.py diff --git a/docs/index.rst b/docs/index.rst index 85bc9d5..aff8054 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -2,6 +2,9 @@ .. include:: multiprocessing.rst +This package includes clients for multiple versions of Cloud TPU. +By default, you will get version ``tpu_v1``. + API Reference ------------- @@ -11,6 +14,14 @@ API Reference tpu_v1/services tpu_v1/types +API Reference +------------- +.. toctree:: + :maxdepth: 2 + + tpu_v2alpha1/services + tpu_v2alpha1/types + Changelog --------- diff --git a/docs/tpu_v2alpha1/services.rst b/docs/tpu_v2alpha1/services.rst new file mode 100644 index 0000000..74c3c78 --- /dev/null +++ b/docs/tpu_v2alpha1/services.rst @@ -0,0 +1,6 @@ +Services for Google Cloud Tpu v2alpha1 API +========================================== +.. toctree:: + :maxdepth: 2 + + tpu diff --git a/docs/tpu_v2alpha1/tpu.rst b/docs/tpu_v2alpha1/tpu.rst new file mode 100644 index 0000000..9b3906b --- /dev/null +++ b/docs/tpu_v2alpha1/tpu.rst @@ -0,0 +1,10 @@ +Tpu +--------------------- + +.. automodule:: google.cloud.tpu_v2alpha1.services.tpu + :members: + :inherited-members: + +.. automodule:: google.cloud.tpu_v2alpha1.services.tpu.pagers + :members: + :inherited-members: diff --git a/docs/tpu_v2alpha1/types.rst b/docs/tpu_v2alpha1/types.rst new file mode 100644 index 0000000..6c1d0f3 --- /dev/null +++ b/docs/tpu_v2alpha1/types.rst @@ -0,0 +1,7 @@ +Types for Google Cloud Tpu v2alpha1 API +======================================= + +.. automodule:: google.cloud.tpu_v2alpha1.types + :members: + :undoc-members: + :show-inheritance: diff --git a/google/cloud/tpu_v2alpha1/__init__.py b/google/cloud/tpu_v2alpha1/__init__.py new file mode 100644 index 0000000..9ffc130 --- /dev/null +++ b/google/cloud/tpu_v2alpha1/__init__.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .services.tpu import TpuClient +from .services.tpu import TpuAsyncClient + +from .types.cloud_tpu import AcceleratorType +from .types.cloud_tpu import AccessConfig +from .types.cloud_tpu import AttachedDisk +from .types.cloud_tpu import CreateNodeRequest +from .types.cloud_tpu import DeleteNodeRequest +from .types.cloud_tpu import GenerateServiceIdentityRequest +from .types.cloud_tpu import GenerateServiceIdentityResponse +from .types.cloud_tpu import GetAcceleratorTypeRequest +from .types.cloud_tpu import GetGuestAttributesRequest +from .types.cloud_tpu import GetGuestAttributesResponse +from .types.cloud_tpu import GetNodeRequest +from .types.cloud_tpu import GetRuntimeVersionRequest +from .types.cloud_tpu import GuestAttributes +from .types.cloud_tpu import GuestAttributesEntry +from .types.cloud_tpu import GuestAttributesValue +from .types.cloud_tpu import ListAcceleratorTypesRequest +from .types.cloud_tpu import ListAcceleratorTypesResponse +from .types.cloud_tpu import ListNodesRequest +from .types.cloud_tpu import ListNodesResponse +from .types.cloud_tpu import ListRuntimeVersionsRequest +from .types.cloud_tpu import ListRuntimeVersionsResponse +from .types.cloud_tpu import NetworkConfig +from .types.cloud_tpu import NetworkEndpoint +from .types.cloud_tpu import Node +from .types.cloud_tpu import OperationMetadata +from .types.cloud_tpu import RuntimeVersion +from .types.cloud_tpu import SchedulingConfig +from .types.cloud_tpu import ServiceAccount +from .types.cloud_tpu import ServiceIdentity +from .types.cloud_tpu import StartNodeRequest +from .types.cloud_tpu import StopNodeRequest +from .types.cloud_tpu import Symptom +from .types.cloud_tpu import UpdateNodeRequest + +__all__ = ( + "TpuAsyncClient", + "AcceleratorType", + "AccessConfig", + "AttachedDisk", + "CreateNodeRequest", + "DeleteNodeRequest", + "GenerateServiceIdentityRequest", + "GenerateServiceIdentityResponse", + "GetAcceleratorTypeRequest", + "GetGuestAttributesRequest", + "GetGuestAttributesResponse", + "GetNodeRequest", + "GetRuntimeVersionRequest", + "GuestAttributes", + "GuestAttributesEntry", + "GuestAttributesValue", + "ListAcceleratorTypesRequest", + "ListAcceleratorTypesResponse", + "ListNodesRequest", + "ListNodesResponse", + "ListRuntimeVersionsRequest", + "ListRuntimeVersionsResponse", + "NetworkConfig", + "NetworkEndpoint", + "Node", + "OperationMetadata", + "RuntimeVersion", + "SchedulingConfig", + "ServiceAccount", + "ServiceIdentity", + "StartNodeRequest", + "StopNodeRequest", + "Symptom", + "TpuClient", + "UpdateNodeRequest", +) diff --git a/google/cloud/tpu_v2alpha1/gapic_metadata.json b/google/cloud/tpu_v2alpha1/gapic_metadata.json new file mode 100644 index 0000000..0d306ce --- /dev/null +++ b/google/cloud/tpu_v2alpha1/gapic_metadata.json @@ -0,0 +1,153 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.cloud.tpu_v2alpha1", + "protoPackage": "google.cloud.tpu.v2alpha1", + "schema": "1.0", + "services": { + "Tpu": { + "clients": { + "grpc": { + "libraryClient": "TpuClient", + "rpcs": { + "CreateNode": { + "methods": [ + "create_node" + ] + }, + "DeleteNode": { + "methods": [ + "delete_node" + ] + }, + "GenerateServiceIdentity": { + "methods": [ + "generate_service_identity" + ] + }, + "GetAcceleratorType": { + "methods": [ + "get_accelerator_type" + ] + }, + "GetGuestAttributes": { + "methods": [ + "get_guest_attributes" + ] + }, + "GetNode": { + "methods": [ + "get_node" + ] + }, + "GetRuntimeVersion": { + "methods": [ + "get_runtime_version" + ] + }, + "ListAcceleratorTypes": { + "methods": [ + "list_accelerator_types" + ] + }, + "ListNodes": { + "methods": [ + "list_nodes" + ] + }, + "ListRuntimeVersions": { + "methods": [ + "list_runtime_versions" + ] + }, + "StartNode": { + "methods": [ + "start_node" + ] + }, + "StopNode": { + "methods": [ + "stop_node" + ] + }, + "UpdateNode": { + "methods": [ + "update_node" + ] + } + } + }, + "grpc-async": { + "libraryClient": "TpuAsyncClient", + "rpcs": { + "CreateNode": { + "methods": [ + "create_node" + ] + }, + "DeleteNode": { + "methods": [ + "delete_node" + ] + }, + "GenerateServiceIdentity": { + "methods": [ + "generate_service_identity" + ] + }, + "GetAcceleratorType": { + "methods": [ + "get_accelerator_type" + ] + }, + "GetGuestAttributes": { + "methods": [ + "get_guest_attributes" + ] + }, + "GetNode": { + "methods": [ + "get_node" + ] + }, + "GetRuntimeVersion": { + "methods": [ + "get_runtime_version" + ] + }, + "ListAcceleratorTypes": { + "methods": [ + "list_accelerator_types" + ] + }, + "ListNodes": { + "methods": [ + "list_nodes" + ] + }, + "ListRuntimeVersions": { + "methods": [ + "list_runtime_versions" + ] + }, + "StartNode": { + "methods": [ + "start_node" + ] + }, + "StopNode": { + "methods": [ + "stop_node" + ] + }, + "UpdateNode": { + "methods": [ + "update_node" + ] + } + } + } + } + } + } +} diff --git a/google/cloud/tpu_v2alpha1/py.typed b/google/cloud/tpu_v2alpha1/py.typed new file mode 100644 index 0000000..e122051 --- /dev/null +++ b/google/cloud/tpu_v2alpha1/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-cloud-tpu package uses inline types. diff --git a/google/cloud/tpu_v2alpha1/services/__init__.py b/google/cloud/tpu_v2alpha1/services/__init__.py new file mode 100644 index 0000000..4de6597 --- /dev/null +++ b/google/cloud/tpu_v2alpha1/services/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/google/cloud/tpu_v2alpha1/services/tpu/__init__.py b/google/cloud/tpu_v2alpha1/services/tpu/__init__.py new file mode 100644 index 0000000..d9a7a94 --- /dev/null +++ b/google/cloud/tpu_v2alpha1/services/tpu/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .client import TpuClient +from .async_client import TpuAsyncClient + +__all__ = ( + "TpuClient", + "TpuAsyncClient", +) diff --git a/google/cloud/tpu_v2alpha1/services/tpu/async_client.py b/google/cloud/tpu_v2alpha1/services/tpu/async_client.py new file mode 100644 index 0000000..fdd1076 --- /dev/null +++ b/google/cloud/tpu_v2alpha1/services/tpu/async_client.py @@ -0,0 +1,1108 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions as core_exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.tpu_v2alpha1.services.tpu import pagers +from google.cloud.tpu_v2alpha1.types import cloud_tpu +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from .transports.base import TpuTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import TpuGrpcAsyncIOTransport +from .client import TpuClient + + +class TpuAsyncClient: + """Manages TPU nodes and other resources + TPU API v2alpha1 + """ + + _client: TpuClient + + DEFAULT_ENDPOINT = TpuClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = TpuClient.DEFAULT_MTLS_ENDPOINT + + accelerator_type_path = staticmethod(TpuClient.accelerator_type_path) + parse_accelerator_type_path = staticmethod(TpuClient.parse_accelerator_type_path) + node_path = staticmethod(TpuClient.node_path) + parse_node_path = staticmethod(TpuClient.parse_node_path) + runtime_version_path = staticmethod(TpuClient.runtime_version_path) + parse_runtime_version_path = staticmethod(TpuClient.parse_runtime_version_path) + common_billing_account_path = staticmethod(TpuClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod( + TpuClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(TpuClient.common_folder_path) + parse_common_folder_path = staticmethod(TpuClient.parse_common_folder_path) + common_organization_path = staticmethod(TpuClient.common_organization_path) + parse_common_organization_path = staticmethod( + TpuClient.parse_common_organization_path + ) + common_project_path = staticmethod(TpuClient.common_project_path) + parse_common_project_path = staticmethod(TpuClient.parse_common_project_path) + common_location_path = staticmethod(TpuClient.common_location_path) + parse_common_location_path = staticmethod(TpuClient.parse_common_location_path) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TpuAsyncClient: The constructed client. + """ + return TpuClient.from_service_account_info.__func__(TpuAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TpuAsyncClient: The constructed client. + """ + return TpuClient.from_service_account_file.__func__(TpuAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> TpuTransport: + """Returns the transport used by the client instance. + + Returns: + TpuTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(TpuClient).get_transport_class, type(TpuClient) + ) + + def __init__( + self, + *, + credentials: ga_credentials.Credentials = None, + transport: Union[str, TpuTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the tpu client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.TpuTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = TpuClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def list_nodes( + self, + request: cloud_tpu.ListNodesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListNodesAsyncPager: + r"""Lists nodes. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.ListNodesRequest`): + The request object. Request for + [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes]. + parent (:class:`str`): + Required. The parent resource name. + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.services.tpu.pagers.ListNodesAsyncPager: + Response for + [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes]. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = cloud_tpu.ListNodesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_nodes, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListNodesAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_node( + self, + request: cloud_tpu.GetNodeRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> cloud_tpu.Node: + r"""Gets the details of a node. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.GetNodeRequest`): + The request object. Request for + [GetNode][google.cloud.tpu.v2alpha1.Tpu.GetNode]. + name (:class:`str`): + Required. The resource name. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.types.Node: + A TPU instance. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = cloud_tpu.GetNodeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_node, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def create_node( + self, + request: cloud_tpu.CreateNodeRequest = None, + *, + parent: str = None, + node: cloud_tpu.Node = None, + node_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a node. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.CreateNodeRequest`): + The request object. Request for + [CreateNode][google.cloud.tpu.v2alpha1.Tpu.CreateNode]. + parent (:class:`str`): + Required. The parent resource name. + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + node (:class:`google.cloud.tpu_v2alpha1.types.Node`): + Required. The node. + This corresponds to the ``node`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + node_id (:class:`str`): + The unqualified resource name. + This corresponds to the ``node_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU + instance. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, node, node_id]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = cloud_tpu.CreateNodeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if node is not None: + request.node = node + if node_id is not None: + request.node_id = node_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_node, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + cloud_tpu.Node, + metadata_type=cloud_tpu.OperationMetadata, + ) + + # Done; return the response. + return response + + async def delete_node( + self, + request: cloud_tpu.DeleteNodeRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a node. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.DeleteNodeRequest`): + The request object. Request for + [DeleteNode][google.cloud.tpu.v2alpha1.Tpu.DeleteNode]. + name (:class:`str`): + Required. The resource name. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU + instance. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = cloud_tpu.DeleteNodeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_node, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + cloud_tpu.Node, + metadata_type=cloud_tpu.OperationMetadata, + ) + + # Done; return the response. + return response + + async def stop_node( + self, + request: cloud_tpu.StopNodeRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Stops a node. This operation is only available with + single TPU nodes. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.StopNodeRequest`): + The request object. Request for + [StopNode][google.cloud.tpu.v2alpha1.Tpu.StopNode]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU + instance. + + """ + # Create or coerce a protobuf request object. + request = cloud_tpu.StopNodeRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.stop_node, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + cloud_tpu.Node, + metadata_type=cloud_tpu.OperationMetadata, + ) + + # Done; return the response. + return response + + async def start_node( + self, + request: cloud_tpu.StartNodeRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Starts a node. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.StartNodeRequest`): + The request object. Request for + [StartNode][google.cloud.tpu.v2alpha1.Tpu.StartNode]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU + instance. + + """ + # Create or coerce a protobuf request object. + request = cloud_tpu.StartNodeRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.start_node, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + cloud_tpu.Node, + metadata_type=cloud_tpu.OperationMetadata, + ) + + # Done; return the response. + return response + + async def update_node( + self, + request: cloud_tpu.UpdateNodeRequest = None, + *, + node: cloud_tpu.Node = None, + update_mask: field_mask_pb2.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates the configurations of a node. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.UpdateNodeRequest`): + The request object. Request for + [UpdateNode][google.cloud.tpu.v2alpha1.Tpu.UpdateNode]. + node (:class:`google.cloud.tpu_v2alpha1.types.Node`): + Required. The node. Only fields specified in update_mask + are updated. + + This corresponds to the ``node`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. Mask of fields from [Node][Tpu.Node] to + update. Supported fields: None. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU + instance. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([node, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = cloud_tpu.UpdateNodeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if node is not None: + request.node = node + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_node, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("node.name", request.node.name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + cloud_tpu.Node, + metadata_type=cloud_tpu.OperationMetadata, + ) + + # Done; return the response. + return response + + async def generate_service_identity( + self, + request: cloud_tpu.GenerateServiceIdentityRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> cloud_tpu.GenerateServiceIdentityResponse: + r"""Generates the Cloud TPU service identity for the + project. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.GenerateServiceIdentityRequest`): + The request object. Request for + [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.types.GenerateServiceIdentityResponse: + Response for + [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity]. + + """ + # Create or coerce a protobuf request object. + request = cloud_tpu.GenerateServiceIdentityRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.generate_service_identity, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_accelerator_types( + self, + request: cloud_tpu.ListAcceleratorTypesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAcceleratorTypesAsyncPager: + r"""Lists accelerator types supported by this API. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesRequest`): + The request object. Request for + [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes]. + parent (:class:`str`): + Required. The parent resource name. + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.services.tpu.pagers.ListAcceleratorTypesAsyncPager: + Response for + [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes]. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = cloud_tpu.ListAcceleratorTypesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_accelerator_types, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListAcceleratorTypesAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_accelerator_type( + self, + request: cloud_tpu.GetAcceleratorTypeRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> cloud_tpu.AcceleratorType: + r"""Gets AcceleratorType. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.GetAcceleratorTypeRequest`): + The request object. Request for + [GetAcceleratorType][google.cloud.tpu.v2alpha1.Tpu.GetAcceleratorType]. + name (:class:`str`): + Required. The resource name. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.types.AcceleratorType: + A accelerator type that a Node can be + configured with. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = cloud_tpu.GetAcceleratorTypeRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_accelerator_type, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_runtime_versions( + self, + request: cloud_tpu.ListRuntimeVersionsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListRuntimeVersionsAsyncPager: + r"""Lists runtime versions supported by this API. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsRequest`): + The request object. Request for + [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions]. + parent (:class:`str`): + Required. The parent resource name. + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.services.tpu.pagers.ListRuntimeVersionsAsyncPager: + Response for + [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions]. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = cloud_tpu.ListRuntimeVersionsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_runtime_versions, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListRuntimeVersionsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_runtime_version( + self, + request: cloud_tpu.GetRuntimeVersionRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> cloud_tpu.RuntimeVersion: + r"""Gets a runtime version. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.GetRuntimeVersionRequest`): + The request object. Request for + [GetRuntimeVersion][google.cloud.tpu.v2alpha1.Tpu.GetRuntimeVersion]. + name (:class:`str`): + Required. The resource name. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.types.RuntimeVersion: + A runtime version that a Node can be + configured with. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = cloud_tpu.GetRuntimeVersionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_runtime_version, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_guest_attributes( + self, + request: cloud_tpu.GetGuestAttributesRequest = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> cloud_tpu.GetGuestAttributesResponse: + r"""Retrieves the guest attributes for the node. + + Args: + request (:class:`google.cloud.tpu_v2alpha1.types.GetGuestAttributesRequest`): + The request object. Request for + [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.types.GetGuestAttributesResponse: + Response for + [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes]. + + """ + # Create or coerce a protobuf request object. + request = cloud_tpu.GetGuestAttributesRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_guest_attributes, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution("google-cloud-tpu",).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("TpuAsyncClient",) diff --git a/google/cloud/tpu_v2alpha1/services/tpu/client.py b/google/cloud/tpu_v2alpha1/services/tpu/client.py new file mode 100644 index 0000000..1c76efc --- /dev/null +++ b/google/cloud/tpu_v2alpha1/services/tpu/client.py @@ -0,0 +1,1345 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions as core_exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.tpu_v2alpha1.services.tpu import pagers +from google.cloud.tpu_v2alpha1.types import cloud_tpu +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from .transports.base import TpuTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import TpuGrpcTransport +from .transports.grpc_asyncio import TpuGrpcAsyncIOTransport + + +class TpuClientMeta(type): + """Metaclass for the Tpu client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = OrderedDict() # type: Dict[str, Type[TpuTransport]] + _transport_registry["grpc"] = TpuGrpcTransport + _transport_registry["grpc_asyncio"] = TpuGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[TpuTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class TpuClient(metaclass=TpuClientMeta): + """Manages TPU nodes and other resources + TPU API v2alpha1 + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "tpu.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TpuClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TpuClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> TpuTransport: + """Returns the transport used by the client instance. + + Returns: + TpuTransport: The transport used by the client + instance. + """ + return self._transport + + @staticmethod + def accelerator_type_path( + project: str, location: str, accelerator_type: str, + ) -> str: + """Returns a fully-qualified accelerator_type string.""" + return "projects/{project}/locations/{location}/acceleratorTypes/{accelerator_type}".format( + project=project, location=location, accelerator_type=accelerator_type, + ) + + @staticmethod + def parse_accelerator_type_path(path: str) -> Dict[str, str]: + """Parses a accelerator_type path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/acceleratorTypes/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def node_path(project: str, location: str, node: str,) -> str: + """Returns a fully-qualified node string.""" + return "projects/{project}/locations/{location}/nodes/{node}".format( + project=project, location=location, node=node, + ) + + @staticmethod + def parse_node_path(path: str) -> Dict[str, str]: + """Parses a node path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/nodes/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def runtime_version_path(project: str, location: str, runtime_version: str,) -> str: + """Returns a fully-qualified runtime_version string.""" + return "projects/{project}/locations/{location}/runtimeVersions/{runtime_version}".format( + project=project, location=location, runtime_version=runtime_version, + ) + + @staticmethod + def parse_runtime_version_path(path: str) -> Dict[str, str]: + """Parses a runtime_version path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/runtimeVersions/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Union[str, TpuTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the tpu client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, TpuTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + if is_mtls: + client_cert_source_func = mtls.default_client_cert_source() + else: + client_cert_source_func = None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + if is_mtls: + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = self.DEFAULT_ENDPOINT + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted " + "values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, TpuTransport): + # transport is a TpuTransport instance. + if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + ) + + def list_nodes( + self, + request: Union[cloud_tpu.ListNodesRequest, dict] = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListNodesPager: + r"""Lists nodes. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.ListNodesRequest, dict]): + The request object. Request for + [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes]. + parent (str): + Required. The parent resource name. + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.services.tpu.pagers.ListNodesPager: + Response for + [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes]. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.ListNodesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.ListNodesRequest): + request = cloud_tpu.ListNodesRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_nodes] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListNodesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def get_node( + self, + request: Union[cloud_tpu.GetNodeRequest, dict] = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> cloud_tpu.Node: + r"""Gets the details of a node. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.GetNodeRequest, dict]): + The request object. Request for + [GetNode][google.cloud.tpu.v2alpha1.Tpu.GetNode]. + name (str): + Required. The resource name. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.types.Node: + A TPU instance. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.GetNodeRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.GetNodeRequest): + request = cloud_tpu.GetNodeRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_node] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def create_node( + self, + request: Union[cloud_tpu.CreateNodeRequest, dict] = None, + *, + parent: str = None, + node: cloud_tpu.Node = None, + node_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: + r"""Creates a node. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.CreateNodeRequest, dict]): + The request object. Request for + [CreateNode][google.cloud.tpu.v2alpha1.Tpu.CreateNode]. + parent (str): + Required. The parent resource name. + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + node (google.cloud.tpu_v2alpha1.types.Node): + Required. The node. + This corresponds to the ``node`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + node_id (str): + The unqualified resource name. + This corresponds to the ``node_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU + instance. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, node, node_id]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.CreateNodeRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.CreateNodeRequest): + request = cloud_tpu.CreateNodeRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if node is not None: + request.node = node + if node_id is not None: + request.node_id = node_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_node] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + cloud_tpu.Node, + metadata_type=cloud_tpu.OperationMetadata, + ) + + # Done; return the response. + return response + + def delete_node( + self, + request: Union[cloud_tpu.DeleteNodeRequest, dict] = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: + r"""Deletes a node. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.DeleteNodeRequest, dict]): + The request object. Request for + [DeleteNode][google.cloud.tpu.v2alpha1.Tpu.DeleteNode]. + name (str): + Required. The resource name. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU + instance. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.DeleteNodeRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.DeleteNodeRequest): + request = cloud_tpu.DeleteNodeRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_node] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + cloud_tpu.Node, + metadata_type=cloud_tpu.OperationMetadata, + ) + + # Done; return the response. + return response + + def stop_node( + self, + request: Union[cloud_tpu.StopNodeRequest, dict] = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: + r"""Stops a node. This operation is only available with + single TPU nodes. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.StopNodeRequest, dict]): + The request object. Request for + [StopNode][google.cloud.tpu.v2alpha1.Tpu.StopNode]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU + instance. + + """ + # Create or coerce a protobuf request object. + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.StopNodeRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.StopNodeRequest): + request = cloud_tpu.StopNodeRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.stop_node] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + cloud_tpu.Node, + metadata_type=cloud_tpu.OperationMetadata, + ) + + # Done; return the response. + return response + + def start_node( + self, + request: Union[cloud_tpu.StartNodeRequest, dict] = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: + r"""Starts a node. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.StartNodeRequest, dict]): + The request object. Request for + [StartNode][google.cloud.tpu.v2alpha1.Tpu.StartNode]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU + instance. + + """ + # Create or coerce a protobuf request object. + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.StartNodeRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.StartNodeRequest): + request = cloud_tpu.StartNodeRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.start_node] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + cloud_tpu.Node, + metadata_type=cloud_tpu.OperationMetadata, + ) + + # Done; return the response. + return response + + def update_node( + self, + request: Union[cloud_tpu.UpdateNodeRequest, dict] = None, + *, + node: cloud_tpu.Node = None, + update_mask: field_mask_pb2.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: + r"""Updates the configurations of a node. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.UpdateNodeRequest, dict]): + The request object. Request for + [UpdateNode][google.cloud.tpu.v2alpha1.Tpu.UpdateNode]. + node (google.cloud.tpu_v2alpha1.types.Node): + Required. The node. Only fields specified in update_mask + are updated. + + This corresponds to the ``node`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. Mask of fields from [Node][Tpu.Node] to + update. Supported fields: None. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.tpu_v2alpha1.types.Node` A TPU + instance. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([node, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.UpdateNodeRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.UpdateNodeRequest): + request = cloud_tpu.UpdateNodeRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if node is not None: + request.node = node + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_node] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("node.name", request.node.name),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + cloud_tpu.Node, + metadata_type=cloud_tpu.OperationMetadata, + ) + + # Done; return the response. + return response + + def generate_service_identity( + self, + request: Union[cloud_tpu.GenerateServiceIdentityRequest, dict] = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> cloud_tpu.GenerateServiceIdentityResponse: + r"""Generates the Cloud TPU service identity for the + project. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.GenerateServiceIdentityRequest, dict]): + The request object. Request for + [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.types.GenerateServiceIdentityResponse: + Response for + [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity]. + + """ + # Create or coerce a protobuf request object. + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.GenerateServiceIdentityRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.GenerateServiceIdentityRequest): + request = cloud_tpu.GenerateServiceIdentityRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.generate_service_identity + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_accelerator_types( + self, + request: Union[cloud_tpu.ListAcceleratorTypesRequest, dict] = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAcceleratorTypesPager: + r"""Lists accelerator types supported by this API. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesRequest, dict]): + The request object. Request for + [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes]. + parent (str): + Required. The parent resource name. + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.services.tpu.pagers.ListAcceleratorTypesPager: + Response for + [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes]. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.ListAcceleratorTypesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.ListAcceleratorTypesRequest): + request = cloud_tpu.ListAcceleratorTypesRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_accelerator_types] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListAcceleratorTypesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def get_accelerator_type( + self, + request: Union[cloud_tpu.GetAcceleratorTypeRequest, dict] = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> cloud_tpu.AcceleratorType: + r"""Gets AcceleratorType. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.GetAcceleratorTypeRequest, dict]): + The request object. Request for + [GetAcceleratorType][google.cloud.tpu.v2alpha1.Tpu.GetAcceleratorType]. + name (str): + Required. The resource name. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.types.AcceleratorType: + A accelerator type that a Node can be + configured with. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.GetAcceleratorTypeRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.GetAcceleratorTypeRequest): + request = cloud_tpu.GetAcceleratorTypeRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_accelerator_type] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_runtime_versions( + self, + request: Union[cloud_tpu.ListRuntimeVersionsRequest, dict] = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListRuntimeVersionsPager: + r"""Lists runtime versions supported by this API. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsRequest, dict]): + The request object. Request for + [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions]. + parent (str): + Required. The parent resource name. + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.services.tpu.pagers.ListRuntimeVersionsPager: + Response for + [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions]. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.ListRuntimeVersionsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.ListRuntimeVersionsRequest): + request = cloud_tpu.ListRuntimeVersionsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_runtime_versions] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListRuntimeVersionsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def get_runtime_version( + self, + request: Union[cloud_tpu.GetRuntimeVersionRequest, dict] = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> cloud_tpu.RuntimeVersion: + r"""Gets a runtime version. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.GetRuntimeVersionRequest, dict]): + The request object. Request for + [GetRuntimeVersion][google.cloud.tpu.v2alpha1.Tpu.GetRuntimeVersion]. + name (str): + Required. The resource name. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.types.RuntimeVersion: + A runtime version that a Node can be + configured with. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.GetRuntimeVersionRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.GetRuntimeVersionRequest): + request = cloud_tpu.GetRuntimeVersionRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_runtime_version] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def get_guest_attributes( + self, + request: Union[cloud_tpu.GetGuestAttributesRequest, dict] = None, + *, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> cloud_tpu.GetGuestAttributesResponse: + r"""Retrieves the guest attributes for the node. + + Args: + request (Union[google.cloud.tpu_v2alpha1.types.GetGuestAttributesRequest, dict]): + The request object. Request for + [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.tpu_v2alpha1.types.GetGuestAttributesResponse: + Response for + [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes]. + + """ + # Create or coerce a protobuf request object. + # Minor optimization to avoid making a copy if the user passes + # in a cloud_tpu.GetGuestAttributesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, cloud_tpu.GetGuestAttributesRequest): + request = cloud_tpu.GetGuestAttributesRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_guest_attributes] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution("google-cloud-tpu",).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("TpuClient",) diff --git a/google/cloud/tpu_v2alpha1/services/tpu/pagers.py b/google/cloud/tpu_v2alpha1/services/tpu/pagers.py new file mode 100644 index 0000000..c1859a1 --- /dev/null +++ b/google/cloud/tpu_v2alpha1/services/tpu/pagers.py @@ -0,0 +1,411 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Sequence, + Tuple, + Optional, + Iterator, +) + +from google.cloud.tpu_v2alpha1.types import cloud_tpu + + +class ListNodesPager: + """A pager for iterating through ``list_nodes`` requests. + + This class thinly wraps an initial + :class:`google.cloud.tpu_v2alpha1.types.ListNodesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``nodes`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListNodes`` requests and continue to iterate + through the ``nodes`` field on the + corresponding responses. + + All the usual :class:`google.cloud.tpu_v2alpha1.types.ListNodesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., cloud_tpu.ListNodesResponse], + request: cloud_tpu.ListNodesRequest, + response: cloud_tpu.ListNodesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.tpu_v2alpha1.types.ListNodesRequest): + The initial request object. + response (google.cloud.tpu_v2alpha1.types.ListNodesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = cloud_tpu.ListNodesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[cloud_tpu.ListNodesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterator[cloud_tpu.Node]: + for page in self.pages: + yield from page.nodes + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListNodesAsyncPager: + """A pager for iterating through ``list_nodes`` requests. + + This class thinly wraps an initial + :class:`google.cloud.tpu_v2alpha1.types.ListNodesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``nodes`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListNodes`` requests and continue to iterate + through the ``nodes`` field on the + corresponding responses. + + All the usual :class:`google.cloud.tpu_v2alpha1.types.ListNodesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[cloud_tpu.ListNodesResponse]], + request: cloud_tpu.ListNodesRequest, + response: cloud_tpu.ListNodesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.tpu_v2alpha1.types.ListNodesRequest): + The initial request object. + response (google.cloud.tpu_v2alpha1.types.ListNodesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = cloud_tpu.ListNodesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[cloud_tpu.ListNodesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterator[cloud_tpu.Node]: + async def async_generator(): + async for page in self.pages: + for response in page.nodes: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListAcceleratorTypesPager: + """A pager for iterating through ``list_accelerator_types`` requests. + + This class thinly wraps an initial + :class:`google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``accelerator_types`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListAcceleratorTypes`` requests and continue to iterate + through the ``accelerator_types`` field on the + corresponding responses. + + All the usual :class:`google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., cloud_tpu.ListAcceleratorTypesResponse], + request: cloud_tpu.ListAcceleratorTypesRequest, + response: cloud_tpu.ListAcceleratorTypesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesRequest): + The initial request object. + response (google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = cloud_tpu.ListAcceleratorTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[cloud_tpu.ListAcceleratorTypesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterator[cloud_tpu.AcceleratorType]: + for page in self.pages: + yield from page.accelerator_types + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListAcceleratorTypesAsyncPager: + """A pager for iterating through ``list_accelerator_types`` requests. + + This class thinly wraps an initial + :class:`google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``accelerator_types`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListAcceleratorTypes`` requests and continue to iterate + through the ``accelerator_types`` field on the + corresponding responses. + + All the usual :class:`google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[cloud_tpu.ListAcceleratorTypesResponse]], + request: cloud_tpu.ListAcceleratorTypesRequest, + response: cloud_tpu.ListAcceleratorTypesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesRequest): + The initial request object. + response (google.cloud.tpu_v2alpha1.types.ListAcceleratorTypesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = cloud_tpu.ListAcceleratorTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[cloud_tpu.ListAcceleratorTypesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterator[cloud_tpu.AcceleratorType]: + async def async_generator(): + async for page in self.pages: + for response in page.accelerator_types: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListRuntimeVersionsPager: + """A pager for iterating through ``list_runtime_versions`` requests. + + This class thinly wraps an initial + :class:`google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``runtime_versions`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListRuntimeVersions`` requests and continue to iterate + through the ``runtime_versions`` field on the + corresponding responses. + + All the usual :class:`google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., cloud_tpu.ListRuntimeVersionsResponse], + request: cloud_tpu.ListRuntimeVersionsRequest, + response: cloud_tpu.ListRuntimeVersionsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsRequest): + The initial request object. + response (google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = cloud_tpu.ListRuntimeVersionsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[cloud_tpu.ListRuntimeVersionsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterator[cloud_tpu.RuntimeVersion]: + for page in self.pages: + yield from page.runtime_versions + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListRuntimeVersionsAsyncPager: + """A pager for iterating through ``list_runtime_versions`` requests. + + This class thinly wraps an initial + :class:`google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``runtime_versions`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListRuntimeVersions`` requests and continue to iterate + through the ``runtime_versions`` field on the + corresponding responses. + + All the usual :class:`google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[cloud_tpu.ListRuntimeVersionsResponse]], + request: cloud_tpu.ListRuntimeVersionsRequest, + response: cloud_tpu.ListRuntimeVersionsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsRequest): + The initial request object. + response (google.cloud.tpu_v2alpha1.types.ListRuntimeVersionsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = cloud_tpu.ListRuntimeVersionsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[cloud_tpu.ListRuntimeVersionsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterator[cloud_tpu.RuntimeVersion]: + async def async_generator(): + async for page in self.pages: + for response in page.runtime_versions: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/tpu_v2alpha1/services/tpu/transports/__init__.py b/google/cloud/tpu_v2alpha1/services/tpu/transports/__init__.py new file mode 100644 index 0000000..d3ede28 --- /dev/null +++ b/google/cloud/tpu_v2alpha1/services/tpu/transports/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import TpuTransport +from .grpc import TpuGrpcTransport +from .grpc_asyncio import TpuGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[TpuTransport]] +_transport_registry["grpc"] = TpuGrpcTransport +_transport_registry["grpc_asyncio"] = TpuGrpcAsyncIOTransport + +__all__ = ( + "TpuTransport", + "TpuGrpcTransport", + "TpuGrpcAsyncIOTransport", +) diff --git a/google/cloud/tpu_v2alpha1/services/tpu/transports/base.py b/google/cloud/tpu_v2alpha1/services/tpu/transports/base.py new file mode 100644 index 0000000..6cc209d --- /dev/null +++ b/google/cloud/tpu_v2alpha1/services/tpu/transports/base.py @@ -0,0 +1,351 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union +import packaging.version +import pkg_resources + +import google.auth # type: ignore +import google.api_core # type: ignore +from google.api_core import exceptions as core_exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.cloud.tpu_v2alpha1.types import cloud_tpu +from google.longrunning import operations_pb2 # type: ignore + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution("google-cloud-tpu",).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + +try: + # google.auth.__version__ was added in 1.26.0 + _GOOGLE_AUTH_VERSION = google.auth.__version__ +except AttributeError: + try: # try pkg_resources if it is available + _GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version + except pkg_resources.DistributionNotFound: # pragma: NO COVER + _GOOGLE_AUTH_VERSION = None + + +class TpuTransport(abc.ABC): + """Abstract transport class for Tpu.""" + + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + + DEFAULT_HOST: str = "tpu.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: ga_credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + scopes_kwargs = self._get_scopes_kwargs(self._host, scopes) + + # Save the scopes. + self._scopes = scopes + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # TODO(busunkim): This method is in the base transport + # to avoid duplicating code across the transport classes. These functions + # should be deleted once the minimum required versions of google-auth is increased. + + # TODO: Remove this function once google-auth >= 1.25.0 is required + @classmethod + def _get_scopes_kwargs( + cls, host: str, scopes: Optional[Sequence[str]] + ) -> Dict[str, Optional[Sequence[str]]]: + """Returns scopes kwargs to pass to google-auth methods depending on the google-auth version""" + + scopes_kwargs = {} + + if _GOOGLE_AUTH_VERSION and ( + packaging.version.parse(_GOOGLE_AUTH_VERSION) + >= packaging.version.parse("1.25.0") + ): + scopes_kwargs = {"scopes": scopes, "default_scopes": cls.AUTH_SCOPES} + else: + scopes_kwargs = {"scopes": scopes or cls.AUTH_SCOPES} + + return scopes_kwargs + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.list_nodes: gapic_v1.method.wrap_method( + self.list_nodes, default_timeout=None, client_info=client_info, + ), + self.get_node: gapic_v1.method.wrap_method( + self.get_node, default_timeout=None, client_info=client_info, + ), + self.create_node: gapic_v1.method.wrap_method( + self.create_node, default_timeout=None, client_info=client_info, + ), + self.delete_node: gapic_v1.method.wrap_method( + self.delete_node, default_timeout=None, client_info=client_info, + ), + self.stop_node: gapic_v1.method.wrap_method( + self.stop_node, default_timeout=None, client_info=client_info, + ), + self.start_node: gapic_v1.method.wrap_method( + self.start_node, default_timeout=None, client_info=client_info, + ), + self.update_node: gapic_v1.method.wrap_method( + self.update_node, default_timeout=None, client_info=client_info, + ), + self.generate_service_identity: gapic_v1.method.wrap_method( + self.generate_service_identity, + default_timeout=None, + client_info=client_info, + ), + self.list_accelerator_types: gapic_v1.method.wrap_method( + self.list_accelerator_types, + default_timeout=None, + client_info=client_info, + ), + self.get_accelerator_type: gapic_v1.method.wrap_method( + self.get_accelerator_type, + default_timeout=None, + client_info=client_info, + ), + self.list_runtime_versions: gapic_v1.method.wrap_method( + self.list_runtime_versions, + default_timeout=None, + client_info=client_info, + ), + self.get_runtime_version: gapic_v1.method.wrap_method( + self.get_runtime_version, default_timeout=None, client_info=client_info, + ), + self.get_guest_attributes: gapic_v1.method.wrap_method( + self.get_guest_attributes, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def list_nodes( + self, + ) -> Callable[ + [cloud_tpu.ListNodesRequest], + Union[cloud_tpu.ListNodesResponse, Awaitable[cloud_tpu.ListNodesResponse]], + ]: + raise NotImplementedError() + + @property + def get_node( + self, + ) -> Callable[ + [cloud_tpu.GetNodeRequest], Union[cloud_tpu.Node, Awaitable[cloud_tpu.Node]] + ]: + raise NotImplementedError() + + @property + def create_node( + self, + ) -> Callable[ + [cloud_tpu.CreateNodeRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def delete_node( + self, + ) -> Callable[ + [cloud_tpu.DeleteNodeRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def stop_node( + self, + ) -> Callable[ + [cloud_tpu.StopNodeRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def start_node( + self, + ) -> Callable[ + [cloud_tpu.StartNodeRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def update_node( + self, + ) -> Callable[ + [cloud_tpu.UpdateNodeRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def generate_service_identity( + self, + ) -> Callable[ + [cloud_tpu.GenerateServiceIdentityRequest], + Union[ + cloud_tpu.GenerateServiceIdentityResponse, + Awaitable[cloud_tpu.GenerateServiceIdentityResponse], + ], + ]: + raise NotImplementedError() + + @property + def list_accelerator_types( + self, + ) -> Callable[ + [cloud_tpu.ListAcceleratorTypesRequest], + Union[ + cloud_tpu.ListAcceleratorTypesResponse, + Awaitable[cloud_tpu.ListAcceleratorTypesResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_accelerator_type( + self, + ) -> Callable[ + [cloud_tpu.GetAcceleratorTypeRequest], + Union[cloud_tpu.AcceleratorType, Awaitable[cloud_tpu.AcceleratorType]], + ]: + raise NotImplementedError() + + @property + def list_runtime_versions( + self, + ) -> Callable[ + [cloud_tpu.ListRuntimeVersionsRequest], + Union[ + cloud_tpu.ListRuntimeVersionsResponse, + Awaitable[cloud_tpu.ListRuntimeVersionsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_runtime_version( + self, + ) -> Callable[ + [cloud_tpu.GetRuntimeVersionRequest], + Union[cloud_tpu.RuntimeVersion, Awaitable[cloud_tpu.RuntimeVersion]], + ]: + raise NotImplementedError() + + @property + def get_guest_attributes( + self, + ) -> Callable[ + [cloud_tpu.GetGuestAttributesRequest], + Union[ + cloud_tpu.GetGuestAttributesResponse, + Awaitable[cloud_tpu.GetGuestAttributesResponse], + ], + ]: + raise NotImplementedError() + + +__all__ = ("TpuTransport",) diff --git a/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc.py b/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc.py new file mode 100644 index 0000000..e31acf6 --- /dev/null +++ b/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc.py @@ -0,0 +1,597 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.tpu_v2alpha1.types import cloud_tpu +from google.longrunning import operations_pb2 # type: ignore +from .base import TpuTransport, DEFAULT_CLIENT_INFO + + +class TpuGrpcTransport(TpuTransport): + """gRPC backend transport for Tpu. + + Manages TPU nodes and other resources + TPU API v2alpha1 + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "tpu.googleapis.com", + credentials: ga_credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "tpu.googleapis.com", + credentials: ga_credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + + # Return the client from cache. + return self._operations_client + + @property + def list_nodes( + self, + ) -> Callable[[cloud_tpu.ListNodesRequest], cloud_tpu.ListNodesResponse]: + r"""Return a callable for the list nodes method over gRPC. + + Lists nodes. + + Returns: + Callable[[~.ListNodesRequest], + ~.ListNodesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_nodes" not in self._stubs: + self._stubs["list_nodes"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/ListNodes", + request_serializer=cloud_tpu.ListNodesRequest.serialize, + response_deserializer=cloud_tpu.ListNodesResponse.deserialize, + ) + return self._stubs["list_nodes"] + + @property + def get_node(self) -> Callable[[cloud_tpu.GetNodeRequest], cloud_tpu.Node]: + r"""Return a callable for the get node method over gRPC. + + Gets the details of a node. + + Returns: + Callable[[~.GetNodeRequest], + ~.Node]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_node" not in self._stubs: + self._stubs["get_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/GetNode", + request_serializer=cloud_tpu.GetNodeRequest.serialize, + response_deserializer=cloud_tpu.Node.deserialize, + ) + return self._stubs["get_node"] + + @property + def create_node( + self, + ) -> Callable[[cloud_tpu.CreateNodeRequest], operations_pb2.Operation]: + r"""Return a callable for the create node method over gRPC. + + Creates a node. + + Returns: + Callable[[~.CreateNodeRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_node" not in self._stubs: + self._stubs["create_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/CreateNode", + request_serializer=cloud_tpu.CreateNodeRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["create_node"] + + @property + def delete_node( + self, + ) -> Callable[[cloud_tpu.DeleteNodeRequest], operations_pb2.Operation]: + r"""Return a callable for the delete node method over gRPC. + + Deletes a node. + + Returns: + Callable[[~.DeleteNodeRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_node" not in self._stubs: + self._stubs["delete_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/DeleteNode", + request_serializer=cloud_tpu.DeleteNodeRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_node"] + + @property + def stop_node( + self, + ) -> Callable[[cloud_tpu.StopNodeRequest], operations_pb2.Operation]: + r"""Return a callable for the stop node method over gRPC. + + Stops a node. This operation is only available with + single TPU nodes. + + Returns: + Callable[[~.StopNodeRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "stop_node" not in self._stubs: + self._stubs["stop_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/StopNode", + request_serializer=cloud_tpu.StopNodeRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["stop_node"] + + @property + def start_node( + self, + ) -> Callable[[cloud_tpu.StartNodeRequest], operations_pb2.Operation]: + r"""Return a callable for the start node method over gRPC. + + Starts a node. + + Returns: + Callable[[~.StartNodeRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "start_node" not in self._stubs: + self._stubs["start_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/StartNode", + request_serializer=cloud_tpu.StartNodeRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["start_node"] + + @property + def update_node( + self, + ) -> Callable[[cloud_tpu.UpdateNodeRequest], operations_pb2.Operation]: + r"""Return a callable for the update node method over gRPC. + + Updates the configurations of a node. + + Returns: + Callable[[~.UpdateNodeRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_node" not in self._stubs: + self._stubs["update_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/UpdateNode", + request_serializer=cloud_tpu.UpdateNodeRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["update_node"] + + @property + def generate_service_identity( + self, + ) -> Callable[ + [cloud_tpu.GenerateServiceIdentityRequest], + cloud_tpu.GenerateServiceIdentityResponse, + ]: + r"""Return a callable for the generate service identity method over gRPC. + + Generates the Cloud TPU service identity for the + project. + + Returns: + Callable[[~.GenerateServiceIdentityRequest], + ~.GenerateServiceIdentityResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "generate_service_identity" not in self._stubs: + self._stubs["generate_service_identity"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/GenerateServiceIdentity", + request_serializer=cloud_tpu.GenerateServiceIdentityRequest.serialize, + response_deserializer=cloud_tpu.GenerateServiceIdentityResponse.deserialize, + ) + return self._stubs["generate_service_identity"] + + @property + def list_accelerator_types( + self, + ) -> Callable[ + [cloud_tpu.ListAcceleratorTypesRequest], cloud_tpu.ListAcceleratorTypesResponse + ]: + r"""Return a callable for the list accelerator types method over gRPC. + + Lists accelerator types supported by this API. + + Returns: + Callable[[~.ListAcceleratorTypesRequest], + ~.ListAcceleratorTypesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_accelerator_types" not in self._stubs: + self._stubs["list_accelerator_types"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/ListAcceleratorTypes", + request_serializer=cloud_tpu.ListAcceleratorTypesRequest.serialize, + response_deserializer=cloud_tpu.ListAcceleratorTypesResponse.deserialize, + ) + return self._stubs["list_accelerator_types"] + + @property + def get_accelerator_type( + self, + ) -> Callable[[cloud_tpu.GetAcceleratorTypeRequest], cloud_tpu.AcceleratorType]: + r"""Return a callable for the get accelerator type method over gRPC. + + Gets AcceleratorType. + + Returns: + Callable[[~.GetAcceleratorTypeRequest], + ~.AcceleratorType]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_accelerator_type" not in self._stubs: + self._stubs["get_accelerator_type"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/GetAcceleratorType", + request_serializer=cloud_tpu.GetAcceleratorTypeRequest.serialize, + response_deserializer=cloud_tpu.AcceleratorType.deserialize, + ) + return self._stubs["get_accelerator_type"] + + @property + def list_runtime_versions( + self, + ) -> Callable[ + [cloud_tpu.ListRuntimeVersionsRequest], cloud_tpu.ListRuntimeVersionsResponse + ]: + r"""Return a callable for the list runtime versions method over gRPC. + + Lists runtime versions supported by this API. + + Returns: + Callable[[~.ListRuntimeVersionsRequest], + ~.ListRuntimeVersionsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_runtime_versions" not in self._stubs: + self._stubs["list_runtime_versions"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/ListRuntimeVersions", + request_serializer=cloud_tpu.ListRuntimeVersionsRequest.serialize, + response_deserializer=cloud_tpu.ListRuntimeVersionsResponse.deserialize, + ) + return self._stubs["list_runtime_versions"] + + @property + def get_runtime_version( + self, + ) -> Callable[[cloud_tpu.GetRuntimeVersionRequest], cloud_tpu.RuntimeVersion]: + r"""Return a callable for the get runtime version method over gRPC. + + Gets a runtime version. + + Returns: + Callable[[~.GetRuntimeVersionRequest], + ~.RuntimeVersion]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_runtime_version" not in self._stubs: + self._stubs["get_runtime_version"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/GetRuntimeVersion", + request_serializer=cloud_tpu.GetRuntimeVersionRequest.serialize, + response_deserializer=cloud_tpu.RuntimeVersion.deserialize, + ) + return self._stubs["get_runtime_version"] + + @property + def get_guest_attributes( + self, + ) -> Callable[ + [cloud_tpu.GetGuestAttributesRequest], cloud_tpu.GetGuestAttributesResponse + ]: + r"""Return a callable for the get guest attributes method over gRPC. + + Retrieves the guest attributes for the node. + + Returns: + Callable[[~.GetGuestAttributesRequest], + ~.GetGuestAttributesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_guest_attributes" not in self._stubs: + self._stubs["get_guest_attributes"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/GetGuestAttributes", + request_serializer=cloud_tpu.GetGuestAttributesRequest.serialize, + response_deserializer=cloud_tpu.GetGuestAttributesResponse.deserialize, + ) + return self._stubs["get_guest_attributes"] + + def close(self): + self.grpc_channel.close() + + +__all__ = ("TpuGrpcTransport",) diff --git a/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc_asyncio.py b/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc_asyncio.py new file mode 100644 index 0000000..9ef622e --- /dev/null +++ b/google/cloud/tpu_v2alpha1/services/tpu/transports/grpc_asyncio.py @@ -0,0 +1,611 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +import packaging.version + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.tpu_v2alpha1.types import cloud_tpu +from google.longrunning import operations_pb2 # type: ignore +from .base import TpuTransport, DEFAULT_CLIENT_INFO +from .grpc import TpuGrpcTransport + + +class TpuGrpcAsyncIOTransport(TpuTransport): + """gRPC AsyncIO backend transport for Tpu. + + Manages TPU nodes and other resources + TPU API v2alpha1 + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "tpu.googleapis.com", + credentials: ga_credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "tpu.googleapis.com", + credentials: ga_credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + credentials=self._credentials, + credentials_file=credentials_file, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def list_nodes( + self, + ) -> Callable[[cloud_tpu.ListNodesRequest], Awaitable[cloud_tpu.ListNodesResponse]]: + r"""Return a callable for the list nodes method over gRPC. + + Lists nodes. + + Returns: + Callable[[~.ListNodesRequest], + Awaitable[~.ListNodesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_nodes" not in self._stubs: + self._stubs["list_nodes"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/ListNodes", + request_serializer=cloud_tpu.ListNodesRequest.serialize, + response_deserializer=cloud_tpu.ListNodesResponse.deserialize, + ) + return self._stubs["list_nodes"] + + @property + def get_node( + self, + ) -> Callable[[cloud_tpu.GetNodeRequest], Awaitable[cloud_tpu.Node]]: + r"""Return a callable for the get node method over gRPC. + + Gets the details of a node. + + Returns: + Callable[[~.GetNodeRequest], + Awaitable[~.Node]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_node" not in self._stubs: + self._stubs["get_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/GetNode", + request_serializer=cloud_tpu.GetNodeRequest.serialize, + response_deserializer=cloud_tpu.Node.deserialize, + ) + return self._stubs["get_node"] + + @property + def create_node( + self, + ) -> Callable[[cloud_tpu.CreateNodeRequest], Awaitable[operations_pb2.Operation]]: + r"""Return a callable for the create node method over gRPC. + + Creates a node. + + Returns: + Callable[[~.CreateNodeRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_node" not in self._stubs: + self._stubs["create_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/CreateNode", + request_serializer=cloud_tpu.CreateNodeRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["create_node"] + + @property + def delete_node( + self, + ) -> Callable[[cloud_tpu.DeleteNodeRequest], Awaitable[operations_pb2.Operation]]: + r"""Return a callable for the delete node method over gRPC. + + Deletes a node. + + Returns: + Callable[[~.DeleteNodeRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_node" not in self._stubs: + self._stubs["delete_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/DeleteNode", + request_serializer=cloud_tpu.DeleteNodeRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["delete_node"] + + @property + def stop_node( + self, + ) -> Callable[[cloud_tpu.StopNodeRequest], Awaitable[operations_pb2.Operation]]: + r"""Return a callable for the stop node method over gRPC. + + Stops a node. This operation is only available with + single TPU nodes. + + Returns: + Callable[[~.StopNodeRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "stop_node" not in self._stubs: + self._stubs["stop_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/StopNode", + request_serializer=cloud_tpu.StopNodeRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["stop_node"] + + @property + def start_node( + self, + ) -> Callable[[cloud_tpu.StartNodeRequest], Awaitable[operations_pb2.Operation]]: + r"""Return a callable for the start node method over gRPC. + + Starts a node. + + Returns: + Callable[[~.StartNodeRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "start_node" not in self._stubs: + self._stubs["start_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/StartNode", + request_serializer=cloud_tpu.StartNodeRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["start_node"] + + @property + def update_node( + self, + ) -> Callable[[cloud_tpu.UpdateNodeRequest], Awaitable[operations_pb2.Operation]]: + r"""Return a callable for the update node method over gRPC. + + Updates the configurations of a node. + + Returns: + Callable[[~.UpdateNodeRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_node" not in self._stubs: + self._stubs["update_node"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/UpdateNode", + request_serializer=cloud_tpu.UpdateNodeRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["update_node"] + + @property + def generate_service_identity( + self, + ) -> Callable[ + [cloud_tpu.GenerateServiceIdentityRequest], + Awaitable[cloud_tpu.GenerateServiceIdentityResponse], + ]: + r"""Return a callable for the generate service identity method over gRPC. + + Generates the Cloud TPU service identity for the + project. + + Returns: + Callable[[~.GenerateServiceIdentityRequest], + Awaitable[~.GenerateServiceIdentityResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "generate_service_identity" not in self._stubs: + self._stubs["generate_service_identity"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/GenerateServiceIdentity", + request_serializer=cloud_tpu.GenerateServiceIdentityRequest.serialize, + response_deserializer=cloud_tpu.GenerateServiceIdentityResponse.deserialize, + ) + return self._stubs["generate_service_identity"] + + @property + def list_accelerator_types( + self, + ) -> Callable[ + [cloud_tpu.ListAcceleratorTypesRequest], + Awaitable[cloud_tpu.ListAcceleratorTypesResponse], + ]: + r"""Return a callable for the list accelerator types method over gRPC. + + Lists accelerator types supported by this API. + + Returns: + Callable[[~.ListAcceleratorTypesRequest], + Awaitable[~.ListAcceleratorTypesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_accelerator_types" not in self._stubs: + self._stubs["list_accelerator_types"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/ListAcceleratorTypes", + request_serializer=cloud_tpu.ListAcceleratorTypesRequest.serialize, + response_deserializer=cloud_tpu.ListAcceleratorTypesResponse.deserialize, + ) + return self._stubs["list_accelerator_types"] + + @property + def get_accelerator_type( + self, + ) -> Callable[ + [cloud_tpu.GetAcceleratorTypeRequest], Awaitable[cloud_tpu.AcceleratorType] + ]: + r"""Return a callable for the get accelerator type method over gRPC. + + Gets AcceleratorType. + + Returns: + Callable[[~.GetAcceleratorTypeRequest], + Awaitable[~.AcceleratorType]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_accelerator_type" not in self._stubs: + self._stubs["get_accelerator_type"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/GetAcceleratorType", + request_serializer=cloud_tpu.GetAcceleratorTypeRequest.serialize, + response_deserializer=cloud_tpu.AcceleratorType.deserialize, + ) + return self._stubs["get_accelerator_type"] + + @property + def list_runtime_versions( + self, + ) -> Callable[ + [cloud_tpu.ListRuntimeVersionsRequest], + Awaitable[cloud_tpu.ListRuntimeVersionsResponse], + ]: + r"""Return a callable for the list runtime versions method over gRPC. + + Lists runtime versions supported by this API. + + Returns: + Callable[[~.ListRuntimeVersionsRequest], + Awaitable[~.ListRuntimeVersionsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_runtime_versions" not in self._stubs: + self._stubs["list_runtime_versions"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/ListRuntimeVersions", + request_serializer=cloud_tpu.ListRuntimeVersionsRequest.serialize, + response_deserializer=cloud_tpu.ListRuntimeVersionsResponse.deserialize, + ) + return self._stubs["list_runtime_versions"] + + @property + def get_runtime_version( + self, + ) -> Callable[ + [cloud_tpu.GetRuntimeVersionRequest], Awaitable[cloud_tpu.RuntimeVersion] + ]: + r"""Return a callable for the get runtime version method over gRPC. + + Gets a runtime version. + + Returns: + Callable[[~.GetRuntimeVersionRequest], + Awaitable[~.RuntimeVersion]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_runtime_version" not in self._stubs: + self._stubs["get_runtime_version"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/GetRuntimeVersion", + request_serializer=cloud_tpu.GetRuntimeVersionRequest.serialize, + response_deserializer=cloud_tpu.RuntimeVersion.deserialize, + ) + return self._stubs["get_runtime_version"] + + @property + def get_guest_attributes( + self, + ) -> Callable[ + [cloud_tpu.GetGuestAttributesRequest], + Awaitable[cloud_tpu.GetGuestAttributesResponse], + ]: + r"""Return a callable for the get guest attributes method over gRPC. + + Retrieves the guest attributes for the node. + + Returns: + Callable[[~.GetGuestAttributesRequest], + Awaitable[~.GetGuestAttributesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_guest_attributes" not in self._stubs: + self._stubs["get_guest_attributes"] = self.grpc_channel.unary_unary( + "/google.cloud.tpu.v2alpha1.Tpu/GetGuestAttributes", + request_serializer=cloud_tpu.GetGuestAttributesRequest.serialize, + response_deserializer=cloud_tpu.GetGuestAttributesResponse.deserialize, + ) + return self._stubs["get_guest_attributes"] + + def close(self): + return self.grpc_channel.close() + + +__all__ = ("TpuGrpcAsyncIOTransport",) diff --git a/google/cloud/tpu_v2alpha1/types/__init__.py b/google/cloud/tpu_v2alpha1/types/__init__.py new file mode 100644 index 0000000..fd80d1e --- /dev/null +++ b/google/cloud/tpu_v2alpha1/types/__init__.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .cloud_tpu import ( + AcceleratorType, + AccessConfig, + AttachedDisk, + CreateNodeRequest, + DeleteNodeRequest, + GenerateServiceIdentityRequest, + GenerateServiceIdentityResponse, + GetAcceleratorTypeRequest, + GetGuestAttributesRequest, + GetGuestAttributesResponse, + GetNodeRequest, + GetRuntimeVersionRequest, + GuestAttributes, + GuestAttributesEntry, + GuestAttributesValue, + ListAcceleratorTypesRequest, + ListAcceleratorTypesResponse, + ListNodesRequest, + ListNodesResponse, + ListRuntimeVersionsRequest, + ListRuntimeVersionsResponse, + NetworkConfig, + NetworkEndpoint, + Node, + OperationMetadata, + RuntimeVersion, + SchedulingConfig, + ServiceAccount, + ServiceIdentity, + StartNodeRequest, + StopNodeRequest, + Symptom, + UpdateNodeRequest, +) + +__all__ = ( + "AcceleratorType", + "AccessConfig", + "AttachedDisk", + "CreateNodeRequest", + "DeleteNodeRequest", + "GenerateServiceIdentityRequest", + "GenerateServiceIdentityResponse", + "GetAcceleratorTypeRequest", + "GetGuestAttributesRequest", + "GetGuestAttributesResponse", + "GetNodeRequest", + "GetRuntimeVersionRequest", + "GuestAttributes", + "GuestAttributesEntry", + "GuestAttributesValue", + "ListAcceleratorTypesRequest", + "ListAcceleratorTypesResponse", + "ListNodesRequest", + "ListNodesResponse", + "ListRuntimeVersionsRequest", + "ListRuntimeVersionsResponse", + "NetworkConfig", + "NetworkEndpoint", + "Node", + "OperationMetadata", + "RuntimeVersion", + "SchedulingConfig", + "ServiceAccount", + "ServiceIdentity", + "StartNodeRequest", + "StopNodeRequest", + "Symptom", + "UpdateNodeRequest", +) diff --git a/google/cloud/tpu_v2alpha1/types/cloud_tpu.py b/google/cloud/tpu_v2alpha1/types/cloud_tpu.py new file mode 100644 index 0000000..e6d937c --- /dev/null +++ b/google/cloud/tpu_v2alpha1/types/cloud_tpu.py @@ -0,0 +1,766 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import proto # type: ignore + +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.tpu.v2alpha1", + manifest={ + "GuestAttributes", + "GuestAttributesValue", + "GuestAttributesEntry", + "AttachedDisk", + "SchedulingConfig", + "NetworkEndpoint", + "AccessConfig", + "NetworkConfig", + "ServiceAccount", + "Node", + "ListNodesRequest", + "ListNodesResponse", + "GetNodeRequest", + "CreateNodeRequest", + "DeleteNodeRequest", + "StopNodeRequest", + "StartNodeRequest", + "UpdateNodeRequest", + "ServiceIdentity", + "GenerateServiceIdentityRequest", + "GenerateServiceIdentityResponse", + "AcceleratorType", + "GetAcceleratorTypeRequest", + "ListAcceleratorTypesRequest", + "ListAcceleratorTypesResponse", + "OperationMetadata", + "RuntimeVersion", + "GetRuntimeVersionRequest", + "ListRuntimeVersionsRequest", + "ListRuntimeVersionsResponse", + "Symptom", + "GetGuestAttributesRequest", + "GetGuestAttributesResponse", + }, +) + + +class GuestAttributes(proto.Message): + r"""A guest attributes. + + Attributes: + query_path (str): + The path to be queried. This can be the + default namespace ('/') or a nested namespace + ('/\/') or a specified key + ('/\/\') + query_value (google.cloud.tpu_v2alpha1.types.GuestAttributesValue): + The value of the requested queried path. + """ + + query_path = proto.Field(proto.STRING, number=1,) + query_value = proto.Field(proto.MESSAGE, number=2, message="GuestAttributesValue",) + + +class GuestAttributesValue(proto.Message): + r"""Array of guest attribute namespace/key/value tuples. + + Attributes: + items (Sequence[google.cloud.tpu_v2alpha1.types.GuestAttributesEntry]): + The list of guest attributes entries. + """ + + items = proto.RepeatedField( + proto.MESSAGE, number=1, message="GuestAttributesEntry", + ) + + +class GuestAttributesEntry(proto.Message): + r"""A guest attributes namespace/key/value entry. + + Attributes: + namespace (str): + Namespace for the guest attribute entry. + key (str): + Key for the guest attribute entry. + value (str): + Value for the guest attribute entry. + """ + + namespace = proto.Field(proto.STRING, number=1,) + key = proto.Field(proto.STRING, number=2,) + value = proto.Field(proto.STRING, number=3,) + + +class AttachedDisk(proto.Message): + r"""A node-attached disk resource. + Next ID: 8; + + Attributes: + source_disk (str): + Specifies the full path to an existing disk. + For example: "projects/my-project/zones/us- + central1-c/disks/my-disk". + mode (google.cloud.tpu_v2alpha1.types.AttachedDisk.DiskMode): + The mode in which to attach this disk. If not specified, the + default is READ_WRITE mode. Only applicable to data_disks. + """ + + class DiskMode(proto.Enum): + r"""The different mode of the attached disk.""" + DISK_MODE_UNSPECIFIED = 0 + READ_WRITE = 1 + READ_ONLY = 2 + + source_disk = proto.Field(proto.STRING, number=3,) + mode = proto.Field(proto.ENUM, number=4, enum=DiskMode,) + + +class SchedulingConfig(proto.Message): + r"""Sets the scheduling options for this node. + + Attributes: + preemptible (bool): + Defines whether the node is preemptible. + reserved (bool): + Whether the node is created under a + reservation. + """ + + preemptible = proto.Field(proto.BOOL, number=1,) + reserved = proto.Field(proto.BOOL, number=2,) + + +class NetworkEndpoint(proto.Message): + r"""A network endpoint over which a TPU worker can be reached. + + Attributes: + ip_address (str): + The internal IP address of this network + endpoint. + port (int): + The port of this network endpoint. + access_config (google.cloud.tpu_v2alpha1.types.AccessConfig): + The access config for the TPU worker. + """ + + ip_address = proto.Field(proto.STRING, number=1,) + port = proto.Field(proto.INT32, number=2,) + access_config = proto.Field(proto.MESSAGE, number=5, message="AccessConfig",) + + +class AccessConfig(proto.Message): + r"""An access config attached to the TPU worker. + + Attributes: + external_ip (str): + Output only. An external IP address + associated with the TPU worker. + """ + + external_ip = proto.Field(proto.STRING, number=1,) + + +class NetworkConfig(proto.Message): + r"""Network related configurations. + + Attributes: + network (str): + The name of the network for the TPU node. It + must be a preexisting Google Compute Engine + network. If none is provided, "default" will be + used. + subnetwork (str): + The name of the subnetwork for the TPU node. + It must be a preexisting Google Compute Engine + subnetwork. If none is provided, "default" will + be used. + enable_external_ips (bool): + Indicates that external IP addresses would be + associated with the TPU workers. If set to + false, the specified subnetwork or network + should have Private Google Access enabled. + """ + + network = proto.Field(proto.STRING, number=1,) + subnetwork = proto.Field(proto.STRING, number=2,) + enable_external_ips = proto.Field(proto.BOOL, number=3,) + + +class ServiceAccount(proto.Message): + r"""A service account. + + Attributes: + email (str): + Email address of the service account. If + empty, default Compute service account will be + used. + scope (Sequence[str]): + The list of scopes to be made available for + this service account. If empty, access to all + Cloud APIs will be allowed. + """ + + email = proto.Field(proto.STRING, number=1,) + scope = proto.RepeatedField(proto.STRING, number=2,) + + +class Node(proto.Message): + r"""A TPU instance. + + Attributes: + name (str): + Output only. Immutable. The name of the TPU. + description (str): + The user-supplied description of the TPU. + Maximum of 512 characters. + accelerator_type (str): + Required. The type of hardware accelerators + associated with this node. + state (google.cloud.tpu_v2alpha1.types.Node.State): + Output only. The current state for the TPU + Node. + health_description (str): + Output only. If this field is populated, it + contains a description of why the TPU Node is + unhealthy. + runtime_version (str): + Required. The runtime version running in the + Node. + network_config (google.cloud.tpu_v2alpha1.types.NetworkConfig): + Network configurations for the TPU node. + cidr_block (str): + The CIDR block that the TPU node will use + when selecting an IP address. This CIDR block + must be a /29 block; the Compute Engine networks + API forbids a smaller block, and using a larger + block would be wasteful (a node can only consume + one IP address). Errors will occur if the CIDR + block has already been used for a currently + existing TPU node, the CIDR block conflicts with + any subnetworks in the user's provided network, + or the provided network is peered with another + network that is using that CIDR block. + service_account (google.cloud.tpu_v2alpha1.types.ServiceAccount): + The Google Cloud Platform Service Account to + be used by the TPU node VMs. If None is + specified, the default compute service account + will be used. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The time when the node was + created. + scheduling_config (google.cloud.tpu_v2alpha1.types.SchedulingConfig): + The scheduling options for this node. + network_endpoints (Sequence[google.cloud.tpu_v2alpha1.types.NetworkEndpoint]): + Output only. The network endpoints where TPU + workers can be accessed and sent work. It is + recommended that runtime clients of the node + reach out to the 0th entry in this map first. + health (google.cloud.tpu_v2alpha1.types.Node.Health): + The health status of the TPU node. + labels (Sequence[google.cloud.tpu_v2alpha1.types.Node.LabelsEntry]): + Resource labels to represent user-provided + metadata. + metadata (Sequence[google.cloud.tpu_v2alpha1.types.Node.MetadataEntry]): + Custom metadata to apply to the TPU Node. + Can set startup-script and shutdown-script + tags (Sequence[str]): + Tags to apply to the TPU Node. Tags are used + to identify valid sources or targets for network + firewalls. + id (int): + Output only. The unique identifier for the + TPU Node. + data_disks (Sequence[google.cloud.tpu_v2alpha1.types.AttachedDisk]): + The additional data disks for the Node. + api_version (google.cloud.tpu_v2alpha1.types.Node.ApiVersion): + Output only. The API version that created + this Node. + symptoms (Sequence[google.cloud.tpu_v2alpha1.types.Symptom]): + Output only. The Symptoms that have occurred + to the TPU Node. + """ + + class State(proto.Enum): + r"""Represents the different states of a TPU node during its + lifecycle. + """ + STATE_UNSPECIFIED = 0 + CREATING = 1 + READY = 2 + RESTARTING = 3 + REIMAGING = 4 + DELETING = 5 + REPAIRING = 6 + STOPPED = 8 + STOPPING = 9 + STARTING = 10 + PREEMPTED = 11 + TERMINATED = 12 + HIDING = 13 + HIDDEN = 14 + UNHIDING = 15 + + class Health(proto.Enum): + r"""Health defines the status of a TPU node as reported by + Health Monitor. + """ + HEALTH_UNSPECIFIED = 0 + HEALTHY = 1 + TIMEOUT = 3 + UNHEALTHY_TENSORFLOW = 4 + UNHEALTHY_MAINTENANCE = 5 + + class ApiVersion(proto.Enum): + r"""TPU API Version.""" + API_VERSION_UNSPECIFIED = 0 + V1_ALPHA1 = 1 + V1 = 2 + V2_ALPHA1 = 3 + + name = proto.Field(proto.STRING, number=1,) + description = proto.Field(proto.STRING, number=3,) + accelerator_type = proto.Field(proto.STRING, number=5,) + state = proto.Field(proto.ENUM, number=9, enum=State,) + health_description = proto.Field(proto.STRING, number=10,) + runtime_version = proto.Field(proto.STRING, number=11,) + network_config = proto.Field(proto.MESSAGE, number=36, message="NetworkConfig",) + cidr_block = proto.Field(proto.STRING, number=13,) + service_account = proto.Field(proto.MESSAGE, number=37, message="ServiceAccount",) + create_time = proto.Field( + proto.MESSAGE, number=16, message=timestamp_pb2.Timestamp, + ) + scheduling_config = proto.Field( + proto.MESSAGE, number=17, message="SchedulingConfig", + ) + network_endpoints = proto.RepeatedField( + proto.MESSAGE, number=21, message="NetworkEndpoint", + ) + health = proto.Field(proto.ENUM, number=22, enum=Health,) + labels = proto.MapField(proto.STRING, proto.STRING, number=24,) + metadata = proto.MapField(proto.STRING, proto.STRING, number=34,) + tags = proto.RepeatedField(proto.STRING, number=40,) + id = proto.Field(proto.INT64, number=33,) + data_disks = proto.RepeatedField(proto.MESSAGE, number=41, message="AttachedDisk",) + api_version = proto.Field(proto.ENUM, number=38, enum=ApiVersion,) + symptoms = proto.RepeatedField(proto.MESSAGE, number=39, message="Symptom",) + + +class ListNodesRequest(proto.Message): + r"""Request for [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes]. + + Attributes: + parent (str): + Required. The parent resource name. + page_size (int): + The maximum number of items to return. + page_token (str): + The next_page_token value returned from a previous List + request, if any. + """ + + parent = proto.Field(proto.STRING, number=1,) + page_size = proto.Field(proto.INT32, number=2,) + page_token = proto.Field(proto.STRING, number=3,) + + +class ListNodesResponse(proto.Message): + r"""Response for [ListNodes][google.cloud.tpu.v2alpha1.Tpu.ListNodes]. + + Attributes: + nodes (Sequence[google.cloud.tpu_v2alpha1.types.Node]): + The listed nodes. + next_page_token (str): + The next page token or empty if none. + unreachable (Sequence[str]): + Locations that could not be reached. + """ + + @property + def raw_page(self): + return self + + nodes = proto.RepeatedField(proto.MESSAGE, number=1, message="Node",) + next_page_token = proto.Field(proto.STRING, number=2,) + unreachable = proto.RepeatedField(proto.STRING, number=3,) + + +class GetNodeRequest(proto.Message): + r"""Request for [GetNode][google.cloud.tpu.v2alpha1.Tpu.GetNode]. + + Attributes: + name (str): + Required. The resource name. + """ + + name = proto.Field(proto.STRING, number=1,) + + +class CreateNodeRequest(proto.Message): + r"""Request for [CreateNode][google.cloud.tpu.v2alpha1.Tpu.CreateNode]. + + Attributes: + parent (str): + Required. The parent resource name. + node_id (str): + The unqualified resource name. + node (google.cloud.tpu_v2alpha1.types.Node): + Required. The node. + """ + + parent = proto.Field(proto.STRING, number=1,) + node_id = proto.Field(proto.STRING, number=2,) + node = proto.Field(proto.MESSAGE, number=3, message="Node",) + + +class DeleteNodeRequest(proto.Message): + r"""Request for [DeleteNode][google.cloud.tpu.v2alpha1.Tpu.DeleteNode]. + + Attributes: + name (str): + Required. The resource name. + """ + + name = proto.Field(proto.STRING, number=1,) + + +class StopNodeRequest(proto.Message): + r"""Request for [StopNode][google.cloud.tpu.v2alpha1.Tpu.StopNode]. + + Attributes: + name (str): + The resource name. + """ + + name = proto.Field(proto.STRING, number=1,) + + +class StartNodeRequest(proto.Message): + r"""Request for [StartNode][google.cloud.tpu.v2alpha1.Tpu.StartNode]. + + Attributes: + name (str): + The resource name. + """ + + name = proto.Field(proto.STRING, number=1,) + + +class UpdateNodeRequest(proto.Message): + r"""Request for [UpdateNode][google.cloud.tpu.v2alpha1.Tpu.UpdateNode]. + + Attributes: + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. Mask of fields from [Node][Tpu.Node] to update. + Supported fields: None. + node (google.cloud.tpu_v2alpha1.types.Node): + Required. The node. Only fields specified in update_mask are + updated. + """ + + update_mask = proto.Field( + proto.MESSAGE, number=1, message=field_mask_pb2.FieldMask, + ) + node = proto.Field(proto.MESSAGE, number=2, message="Node",) + + +class ServiceIdentity(proto.Message): + r"""The per-product per-project service identity for Cloud TPU + service. + + Attributes: + email (str): + The email address of the service identity. + """ + + email = proto.Field(proto.STRING, number=1,) + + +class GenerateServiceIdentityRequest(proto.Message): + r"""Request for + [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity]. + + Attributes: + parent (str): + Required. The parent resource name. + """ + + parent = proto.Field(proto.STRING, number=1,) + + +class GenerateServiceIdentityResponse(proto.Message): + r"""Response for + [GenerateServiceIdentity][google.cloud.tpu.v2alpha1.Tpu.GenerateServiceIdentity]. + + Attributes: + identity (google.cloud.tpu_v2alpha1.types.ServiceIdentity): + ServiceIdentity that was created or + retrieved. + """ + + identity = proto.Field(proto.MESSAGE, number=1, message="ServiceIdentity",) + + +class AcceleratorType(proto.Message): + r"""A accelerator type that a Node can be configured with. + + Attributes: + name (str): + The resource name. + type_ (str): + the accelerator type. + """ + + name = proto.Field(proto.STRING, number=1,) + type_ = proto.Field(proto.STRING, number=2,) + + +class GetAcceleratorTypeRequest(proto.Message): + r"""Request for + [GetAcceleratorType][google.cloud.tpu.v2alpha1.Tpu.GetAcceleratorType]. + + Attributes: + name (str): + Required. The resource name. + """ + + name = proto.Field(proto.STRING, number=1,) + + +class ListAcceleratorTypesRequest(proto.Message): + r"""Request for + [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes]. + + Attributes: + parent (str): + Required. The parent resource name. + page_size (int): + The maximum number of items to return. + page_token (str): + The next_page_token value returned from a previous List + request, if any. + filter (str): + List filter. + order_by (str): + Sort results. + """ + + parent = proto.Field(proto.STRING, number=1,) + page_size = proto.Field(proto.INT32, number=2,) + page_token = proto.Field(proto.STRING, number=3,) + filter = proto.Field(proto.STRING, number=5,) + order_by = proto.Field(proto.STRING, number=6,) + + +class ListAcceleratorTypesResponse(proto.Message): + r"""Response for + [ListAcceleratorTypes][google.cloud.tpu.v2alpha1.Tpu.ListAcceleratorTypes]. + + Attributes: + accelerator_types (Sequence[google.cloud.tpu_v2alpha1.types.AcceleratorType]): + The listed nodes. + next_page_token (str): + The next page token or empty if none. + unreachable (Sequence[str]): + Locations that could not be reached. + """ + + @property + def raw_page(self): + return self + + accelerator_types = proto.RepeatedField( + proto.MESSAGE, number=1, message="AcceleratorType", + ) + next_page_token = proto.Field(proto.STRING, number=2,) + unreachable = proto.RepeatedField(proto.STRING, number=3,) + + +class OperationMetadata(proto.Message): + r"""Metadata describing an [Operation][google.longrunning.Operation] + + Attributes: + create_time (google.protobuf.timestamp_pb2.Timestamp): + The time the operation was created. + end_time (google.protobuf.timestamp_pb2.Timestamp): + The time the operation finished running. + target (str): + Target of the operation - for example + projects/project-1/connectivityTests/test-1 + verb (str): + Name of the verb executed by the operation. + status_detail (str): + Human-readable status of the operation, if + any. + cancel_requested (bool): + Specifies if cancellation was requested for + the operation. + api_version (str): + API version. + """ + + create_time = proto.Field(proto.MESSAGE, number=1, message=timestamp_pb2.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=2, message=timestamp_pb2.Timestamp,) + target = proto.Field(proto.STRING, number=3,) + verb = proto.Field(proto.STRING, number=4,) + status_detail = proto.Field(proto.STRING, number=5,) + cancel_requested = proto.Field(proto.BOOL, number=6,) + api_version = proto.Field(proto.STRING, number=7,) + + +class RuntimeVersion(proto.Message): + r"""A runtime version that a Node can be configured with. + + Attributes: + name (str): + The resource name. + version (str): + The runtime version. + """ + + name = proto.Field(proto.STRING, number=1,) + version = proto.Field(proto.STRING, number=2,) + + +class GetRuntimeVersionRequest(proto.Message): + r"""Request for + [GetRuntimeVersion][google.cloud.tpu.v2alpha1.Tpu.GetRuntimeVersion]. + + Attributes: + name (str): + Required. The resource name. + """ + + name = proto.Field(proto.STRING, number=1,) + + +class ListRuntimeVersionsRequest(proto.Message): + r"""Request for + [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions]. + + Attributes: + parent (str): + Required. The parent resource name. + page_size (int): + The maximum number of items to return. + page_token (str): + The next_page_token value returned from a previous List + request, if any. + filter (str): + List filter. + order_by (str): + Sort results. + """ + + parent = proto.Field(proto.STRING, number=1,) + page_size = proto.Field(proto.INT32, number=2,) + page_token = proto.Field(proto.STRING, number=3,) + filter = proto.Field(proto.STRING, number=5,) + order_by = proto.Field(proto.STRING, number=6,) + + +class ListRuntimeVersionsResponse(proto.Message): + r"""Response for + [ListRuntimeVersions][google.cloud.tpu.v2alpha1.Tpu.ListRuntimeVersions]. + + Attributes: + runtime_versions (Sequence[google.cloud.tpu_v2alpha1.types.RuntimeVersion]): + The listed nodes. + next_page_token (str): + The next page token or empty if none. + unreachable (Sequence[str]): + Locations that could not be reached. + """ + + @property + def raw_page(self): + return self + + runtime_versions = proto.RepeatedField( + proto.MESSAGE, number=1, message="RuntimeVersion", + ) + next_page_token = proto.Field(proto.STRING, number=2,) + unreachable = proto.RepeatedField(proto.STRING, number=3,) + + +class Symptom(proto.Message): + r"""A Symptom instance. + + Attributes: + create_time (google.protobuf.timestamp_pb2.Timestamp): + Timestamp when the Symptom is created. + symptom_type (google.cloud.tpu_v2alpha1.types.Symptom.SymptomType): + Type of the Symptom. + details (str): + Detailed information of the current Symptom. + worker_id (str): + A string used to uniquely distinguish a + worker within a TPU node. + """ + + class SymptomType(proto.Enum): + r"""SymptomType represents the different types of Symptoms that a + TPU can be at. + """ + SYMPTOM_TYPE_UNSPECIFIED = 0 + LOW_MEMORY = 1 + OUT_OF_MEMORY = 2 + EXECUTE_TIMED_OUT = 3 + MESH_BUILD_FAIL = 4 + HBM_OUT_OF_MEMORY = 5 + PROJECT_ABUSE = 6 + + create_time = proto.Field(proto.MESSAGE, number=1, message=timestamp_pb2.Timestamp,) + symptom_type = proto.Field(proto.ENUM, number=2, enum=SymptomType,) + details = proto.Field(proto.STRING, number=3,) + worker_id = proto.Field(proto.STRING, number=4,) + + +class GetGuestAttributesRequest(proto.Message): + r"""Request for + [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes]. + + Attributes: + name (str): + Required. The resource name. + query_path (str): + The guest attributes path to be queried. + worker_ids (Sequence[str]): + The 0-based worker ID. If it is empty, all + workers' GuestAttributes will be returned. + """ + + name = proto.Field(proto.STRING, number=1,) + query_path = proto.Field(proto.STRING, number=2,) + worker_ids = proto.RepeatedField(proto.STRING, number=3,) + + +class GetGuestAttributesResponse(proto.Message): + r"""Response for + [GetGuestAttributes][google.cloud.tpu.v2alpha1.Tpu.GetGuestAttributes]. + + Attributes: + guest_attributes (Sequence[google.cloud.tpu_v2alpha1.types.GuestAttributes]): + The guest attributes for the TPU workers. + """ + + guest_attributes = proto.RepeatedField( + proto.MESSAGE, number=1, message="GuestAttributes", + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/scripts/fixup_tpu_v2alpha1_keywords.py b/scripts/fixup_tpu_v2alpha1_keywords.py new file mode 100644 index 0000000..10a2b39 --- /dev/null +++ b/scripts/fixup_tpu_v2alpha1_keywords.py @@ -0,0 +1,188 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import argparse +import os +import libcst as cst +import pathlib +import sys +from typing import (Any, Callable, Dict, List, Sequence, Tuple) + + +def partition( + predicate: Callable[[Any], bool], + iterator: Sequence[Any] +) -> Tuple[List[Any], List[Any]]: + """A stable, out-of-place partition.""" + results = ([], []) + + for i in iterator: + results[int(predicate(i))].append(i) + + # Returns trueList, falseList + return results[1], results[0] + + +class tpuCallTransformer(cst.CSTTransformer): + CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') + METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { + 'create_node': ('parent', 'node', 'node_id', ), + 'delete_node': ('name', ), + 'generate_service_identity': ('parent', ), + 'get_accelerator_type': ('name', ), + 'get_guest_attributes': ('name', 'query_path', 'worker_ids', ), + 'get_node': ('name', ), + 'get_runtime_version': ('name', ), + 'list_accelerator_types': ('parent', 'page_size', 'page_token', 'filter', 'order_by', ), + 'list_nodes': ('parent', 'page_size', 'page_token', ), + 'list_runtime_versions': ('parent', 'page_size', 'page_token', 'filter', 'order_by', ), + 'start_node': ('name', ), + 'stop_node': ('name', ), + 'update_node': ('update_mask', 'node', ), + } + + def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: + try: + key = original.func.attr.value + kword_params = self.METHOD_TO_PARAMS[key] + except (AttributeError, KeyError): + # Either not a method from the API or too convoluted to be sure. + return updated + + # If the existing code is valid, keyword args come after positional args. + # Therefore, all positional args must map to the first parameters. + args, kwargs = partition(lambda a: not bool(a.keyword), updated.args) + if any(k.keyword.value == "request" for k in kwargs): + # We've already fixed this file, don't fix it again. + return updated + + kwargs, ctrl_kwargs = partition( + lambda a: a.keyword.value not in self.CTRL_PARAMS, + kwargs + ) + + args, ctrl_args = args[:len(kword_params)], args[len(kword_params):] + ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl)) + for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS)) + + request_arg = cst.Arg( + value=cst.Dict([ + cst.DictElement( + cst.SimpleString("'{}'".format(name)), +cst.Element(value=arg.value) + ) + # Note: the args + kwargs looks silly, but keep in mind that + # the control parameters had to be stripped out, and that + # those could have been passed positionally or by keyword. + for name, arg in zip(kword_params, args + kwargs)]), + keyword=cst.Name("request") + ) + + return updated.with_changes( + args=[request_arg] + ctrl_kwargs + ) + + +def fix_files( + in_dir: pathlib.Path, + out_dir: pathlib.Path, + *, + transformer=tpuCallTransformer(), +): + """Duplicate the input dir to the output dir, fixing file method calls. + + Preconditions: + * in_dir is a real directory + * out_dir is a real, empty directory + """ + pyfile_gen = ( + pathlib.Path(os.path.join(root, f)) + for root, _, files in os.walk(in_dir) + for f in files if os.path.splitext(f)[1] == ".py" + ) + + for fpath in pyfile_gen: + with open(fpath, 'r') as f: + src = f.read() + + # Parse the code and insert method call fixes. + tree = cst.parse_module(src) + updated = tree.visit(transformer) + + # Create the path and directory structure for the new file. + updated_path = out_dir.joinpath(fpath.relative_to(in_dir)) + updated_path.parent.mkdir(parents=True, exist_ok=True) + + # Generate the updated source file at the corresponding path. + with open(updated_path, 'w') as f: + f.write(updated.code) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description="""Fix up source that uses the tpu client library. + +The existing sources are NOT overwritten but are copied to output_dir with changes made. + +Note: This tool operates at a best-effort level at converting positional + parameters in client method calls to keyword based parameters. + Cases where it WILL FAIL include + A) * or ** expansion in a method call. + B) Calls via function or method alias (includes free function calls) + C) Indirect or dispatched calls (e.g. the method is looked up dynamically) + + These all constitute false negatives. The tool will also detect false + positives when an API method shares a name with another method. +""") + parser.add_argument( + '-d', + '--input-directory', + required=True, + dest='input_dir', + help='the input directory to walk for python files to fix up', + ) + parser.add_argument( + '-o', + '--output-directory', + required=True, + dest='output_dir', + help='the directory to output files fixed via un-flattening', + ) + args = parser.parse_args() + input_dir = pathlib.Path(args.input_dir) + output_dir = pathlib.Path(args.output_dir) + if not input_dir.is_dir(): + print( + f"input directory '{input_dir}' does not exist or is not a directory", + file=sys.stderr, + ) + sys.exit(-1) + + if not output_dir.is_dir(): + print( + f"output directory '{output_dir}' does not exist or is not a directory", + file=sys.stderr, + ) + sys.exit(-1) + + if os.listdir(output_dir): + print( + f"output directory '{output_dir}' is not empty", + file=sys.stderr, + ) + sys.exit(-1) + + fix_files(input_dir, output_dir) diff --git a/tests/unit/gapic/tpu_v2alpha1/__init__.py b/tests/unit/gapic/tpu_v2alpha1/__init__.py new file mode 100644 index 0000000..4de6597 --- /dev/null +++ b/tests/unit/gapic/tpu_v2alpha1/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/tpu_v2alpha1/test_tpu.py b/tests/unit/gapic/tpu_v2alpha1/test_tpu.py new file mode 100644 index 0000000..d6f8b7b --- /dev/null +++ b/tests/unit/gapic/tpu_v2alpha1/test_tpu.py @@ -0,0 +1,4026 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import mock +import packaging.version + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + + +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.api_core import path_template +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.tpu_v2alpha1.services.tpu import TpuAsyncClient +from google.cloud.tpu_v2alpha1.services.tpu import TpuClient +from google.cloud.tpu_v2alpha1.services.tpu import pagers +from google.cloud.tpu_v2alpha1.services.tpu import transports +from google.cloud.tpu_v2alpha1.services.tpu.transports.base import _GOOGLE_AUTH_VERSION +from google.cloud.tpu_v2alpha1.types import cloud_tpu +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +import google.auth + + +# TODO(busunkim): Once google-auth >= 1.25.0 is required transitively +# through google-api-core: +# - Delete the auth "less than" test cases +# - Delete these pytest markers (Make the "greater than or equal to" tests the default). +requires_google_auth_lt_1_25_0 = pytest.mark.skipif( + packaging.version.parse(_GOOGLE_AUTH_VERSION) >= packaging.version.parse("1.25.0"), + reason="This test requires google-auth < 1.25.0", +) +requires_google_auth_gte_1_25_0 = pytest.mark.skipif( + packaging.version.parse(_GOOGLE_AUTH_VERSION) < packaging.version.parse("1.25.0"), + reason="This test requires google-auth >= 1.25.0", +) + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert TpuClient._get_default_mtls_endpoint(None) is None + assert TpuClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert TpuClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert ( + TpuClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + ) + assert ( + TpuClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert TpuClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [TpuClient, TpuAsyncClient,]) +def test_tpu_client_from_service_account_info(client_class): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "tpu.googleapis.com:443" + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.TpuGrpcTransport, "grpc"), + (transports.TpuGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_tpu_client_service_account_always_use_jwt(transport_class, transport_name): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize("client_class", [TpuClient, TpuAsyncClient,]) +def test_tpu_client_from_service_account_file(client_class): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == "tpu.googleapis.com:443" + + +def test_tpu_client_get_transport_class(): + transport = TpuClient.get_transport_class() + available_transports = [ + transports.TpuGrpcTransport, + ] + assert transport in available_transports + + transport = TpuClient.get_transport_class("grpc") + assert transport == transports.TpuGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (TpuClient, transports.TpuGrpcTransport, "grpc"), + (TpuAsyncClient, transports.TpuGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +@mock.patch.object(TpuClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TpuClient)) +@mock.patch.object( + TpuAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TpuAsyncClient) +) +def test_tpu_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(TpuClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(TpuClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (TpuClient, transports.TpuGrpcTransport, "grpc", "true"), + (TpuAsyncClient, transports.TpuGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (TpuClient, transports.TpuGrpcTransport, "grpc", "false"), + (TpuAsyncClient, transports.TpuGrpcAsyncIOTransport, "grpc_asyncio", "false"), + ], +) +@mock.patch.object(TpuClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TpuClient)) +@mock.patch.object( + TpuAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TpuAsyncClient) +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_tpu_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (TpuClient, transports.TpuGrpcTransport, "grpc"), + (TpuAsyncClient, transports.TpuGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_tpu_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (TpuClient, transports.TpuGrpcTransport, "grpc"), + (TpuAsyncClient, transports.TpuGrpcAsyncIOTransport, "grpc_asyncio"), + ], +) +def test_tpu_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +def test_tpu_client_client_options_from_dict(): + with mock.patch( + "google.cloud.tpu_v2alpha1.services.tpu.transports.TpuGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = TpuClient(client_options={"api_endpoint": "squid.clam.whelk"}) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + ) + + +def test_list_nodes(transport: str = "grpc", request_type=cloud_tpu.ListNodesRequest): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_nodes), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.ListNodesResponse( + next_page_token="next_page_token_value", unreachable=["unreachable_value"], + ) + response = client.list_nodes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.ListNodesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListNodesPager) + assert response.next_page_token == "next_page_token_value" + assert response.unreachable == ["unreachable_value"] + + +def test_list_nodes_from_dict(): + test_list_nodes(request_type=dict) + + +def test_list_nodes_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_nodes), "__call__") as call: + client.list_nodes() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.ListNodesRequest() + + +@pytest.mark.asyncio +async def test_list_nodes_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.ListNodesRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_nodes), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.ListNodesResponse( + next_page_token="next_page_token_value", + unreachable=["unreachable_value"], + ) + ) + response = await client.list_nodes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.ListNodesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListNodesAsyncPager) + assert response.next_page_token == "next_page_token_value" + assert response.unreachable == ["unreachable_value"] + + +@pytest.mark.asyncio +async def test_list_nodes_async_from_dict(): + await test_list_nodes_async(request_type=dict) + + +def test_list_nodes_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.ListNodesRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_nodes), "__call__") as call: + call.return_value = cloud_tpu.ListNodesResponse() + client.list_nodes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_nodes_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.ListNodesRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_nodes), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.ListNodesResponse() + ) + await client.list_nodes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_nodes_flattened(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_nodes), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.ListNodesResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_nodes(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_list_nodes_flattened_error(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_nodes( + cloud_tpu.ListNodesRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_nodes_flattened_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_nodes), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.ListNodesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.ListNodesResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_nodes(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_nodes_flattened_error_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_nodes( + cloud_tpu.ListNodesRequest(), parent="parent_value", + ) + + +def test_list_nodes_pager(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_nodes), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListNodesResponse( + nodes=[cloud_tpu.Node(), cloud_tpu.Node(), cloud_tpu.Node(),], + next_page_token="abc", + ), + cloud_tpu.ListNodesResponse(nodes=[], next_page_token="def",), + cloud_tpu.ListNodesResponse( + nodes=[cloud_tpu.Node(),], next_page_token="ghi", + ), + cloud_tpu.ListNodesResponse(nodes=[cloud_tpu.Node(), cloud_tpu.Node(),],), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_nodes(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, cloud_tpu.Node) for i in results) + + +def test_list_nodes_pages(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_nodes), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListNodesResponse( + nodes=[cloud_tpu.Node(), cloud_tpu.Node(), cloud_tpu.Node(),], + next_page_token="abc", + ), + cloud_tpu.ListNodesResponse(nodes=[], next_page_token="def",), + cloud_tpu.ListNodesResponse( + nodes=[cloud_tpu.Node(),], next_page_token="ghi", + ), + cloud_tpu.ListNodesResponse(nodes=[cloud_tpu.Node(), cloud_tpu.Node(),],), + RuntimeError, + ) + pages = list(client.list_nodes(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_nodes_async_pager(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_nodes), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListNodesResponse( + nodes=[cloud_tpu.Node(), cloud_tpu.Node(), cloud_tpu.Node(),], + next_page_token="abc", + ), + cloud_tpu.ListNodesResponse(nodes=[], next_page_token="def",), + cloud_tpu.ListNodesResponse( + nodes=[cloud_tpu.Node(),], next_page_token="ghi", + ), + cloud_tpu.ListNodesResponse(nodes=[cloud_tpu.Node(), cloud_tpu.Node(),],), + RuntimeError, + ) + async_pager = await client.list_nodes(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, cloud_tpu.Node) for i in responses) + + +@pytest.mark.asyncio +async def test_list_nodes_async_pages(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_nodes), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListNodesResponse( + nodes=[cloud_tpu.Node(), cloud_tpu.Node(), cloud_tpu.Node(),], + next_page_token="abc", + ), + cloud_tpu.ListNodesResponse(nodes=[], next_page_token="def",), + cloud_tpu.ListNodesResponse( + nodes=[cloud_tpu.Node(),], next_page_token="ghi", + ), + cloud_tpu.ListNodesResponse(nodes=[cloud_tpu.Node(), cloud_tpu.Node(),],), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_nodes(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_get_node(transport: str = "grpc", request_type=cloud_tpu.GetNodeRequest): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.Node( + name="name_value", + description="description_value", + accelerator_type="accelerator_type_value", + state=cloud_tpu.Node.State.CREATING, + health_description="health_description_value", + runtime_version="runtime_version_value", + cidr_block="cidr_block_value", + health=cloud_tpu.Node.Health.HEALTHY, + tags=["tags_value"], + id=205, + api_version=cloud_tpu.Node.ApiVersion.V1_ALPHA1, + ) + response = client.get_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, cloud_tpu.Node) + assert response.name == "name_value" + assert response.description == "description_value" + assert response.accelerator_type == "accelerator_type_value" + assert response.state == cloud_tpu.Node.State.CREATING + assert response.health_description == "health_description_value" + assert response.runtime_version == "runtime_version_value" + assert response.cidr_block == "cidr_block_value" + assert response.health == cloud_tpu.Node.Health.HEALTHY + assert response.tags == ["tags_value"] + assert response.id == 205 + assert response.api_version == cloud_tpu.Node.ApiVersion.V1_ALPHA1 + + +def test_get_node_from_dict(): + test_get_node(request_type=dict) + + +def test_get_node_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_node), "__call__") as call: + client.get_node() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetNodeRequest() + + +@pytest.mark.asyncio +async def test_get_node_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.GetNodeRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.Node( + name="name_value", + description="description_value", + accelerator_type="accelerator_type_value", + state=cloud_tpu.Node.State.CREATING, + health_description="health_description_value", + runtime_version="runtime_version_value", + cidr_block="cidr_block_value", + health=cloud_tpu.Node.Health.HEALTHY, + tags=["tags_value"], + id=205, + api_version=cloud_tpu.Node.ApiVersion.V1_ALPHA1, + ) + ) + response = await client.get_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, cloud_tpu.Node) + assert response.name == "name_value" + assert response.description == "description_value" + assert response.accelerator_type == "accelerator_type_value" + assert response.state == cloud_tpu.Node.State.CREATING + assert response.health_description == "health_description_value" + assert response.runtime_version == "runtime_version_value" + assert response.cidr_block == "cidr_block_value" + assert response.health == cloud_tpu.Node.Health.HEALTHY + assert response.tags == ["tags_value"] + assert response.id == 205 + assert response.api_version == cloud_tpu.Node.ApiVersion.V1_ALPHA1 + + +@pytest.mark.asyncio +async def test_get_node_async_from_dict(): + await test_get_node_async(request_type=dict) + + +def test_get_node_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.GetNodeRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_node), "__call__") as call: + call.return_value = cloud_tpu.Node() + client.get_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_node_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.GetNodeRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_node), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(cloud_tpu.Node()) + await client.get_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_node_flattened(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.Node() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_node(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_node_flattened_error(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_node( + cloud_tpu.GetNodeRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_node_flattened_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.Node() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(cloud_tpu.Node()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_node(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_node_flattened_error_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_node( + cloud_tpu.GetNodeRequest(), name="name_value", + ) + + +def test_create_node(transport: str = "grpc", request_type=cloud_tpu.CreateNodeRequest): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.create_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.CreateNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_node_from_dict(): + test_create_node(request_type=dict) + + +def test_create_node_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_node), "__call__") as call: + client.create_node() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.CreateNodeRequest() + + +@pytest.mark.asyncio +async def test_create_node_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.CreateNodeRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.create_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.CreateNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_node_async_from_dict(): + await test_create_node_async(request_type=dict) + + +def test_create_node_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.CreateNodeRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_node), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.create_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_node_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.CreateNodeRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_node), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.create_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_node_flattened(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_node( + parent="parent_value", + node=cloud_tpu.Node(name="name_value"), + node_id="node_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + assert args[0].node == cloud_tpu.Node(name="name_value") + assert args[0].node_id == "node_id_value" + + +def test_create_node_flattened_error(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_node( + cloud_tpu.CreateNodeRequest(), + parent="parent_value", + node=cloud_tpu.Node(name="name_value"), + node_id="node_id_value", + ) + + +@pytest.mark.asyncio +async def test_create_node_flattened_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_node( + parent="parent_value", + node=cloud_tpu.Node(name="name_value"), + node_id="node_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + assert args[0].node == cloud_tpu.Node(name="name_value") + assert args[0].node_id == "node_id_value" + + +@pytest.mark.asyncio +async def test_create_node_flattened_error_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_node( + cloud_tpu.CreateNodeRequest(), + parent="parent_value", + node=cloud_tpu.Node(name="name_value"), + node_id="node_id_value", + ) + + +def test_delete_node(transport: str = "grpc", request_type=cloud_tpu.DeleteNodeRequest): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.delete_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.DeleteNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_node_from_dict(): + test_delete_node(request_type=dict) + + +def test_delete_node_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_node), "__call__") as call: + client.delete_node() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.DeleteNodeRequest() + + +@pytest.mark.asyncio +async def test_delete_node_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.DeleteNodeRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.delete_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.DeleteNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_node_async_from_dict(): + await test_delete_node_async(request_type=dict) + + +def test_delete_node_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.DeleteNodeRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_node), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.delete_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_node_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.DeleteNodeRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_node), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.delete_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_node_flattened(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_node(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_node_flattened_error(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_node( + cloud_tpu.DeleteNodeRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_node_flattened_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_node(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_node_flattened_error_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_node( + cloud_tpu.DeleteNodeRequest(), name="name_value", + ) + + +def test_stop_node(transport: str = "grpc", request_type=cloud_tpu.StopNodeRequest): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.stop_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.stop_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.StopNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_stop_node_from_dict(): + test_stop_node(request_type=dict) + + +def test_stop_node_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.stop_node), "__call__") as call: + client.stop_node() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.StopNodeRequest() + + +@pytest.mark.asyncio +async def test_stop_node_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.StopNodeRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.stop_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.stop_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.StopNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_stop_node_async_from_dict(): + await test_stop_node_async(request_type=dict) + + +def test_stop_node_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.StopNodeRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.stop_node), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.stop_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_stop_node_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.StopNodeRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.stop_node), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.stop_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_start_node(transport: str = "grpc", request_type=cloud_tpu.StartNodeRequest): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.start_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.start_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.StartNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_start_node_from_dict(): + test_start_node(request_type=dict) + + +def test_start_node_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.start_node), "__call__") as call: + client.start_node() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.StartNodeRequest() + + +@pytest.mark.asyncio +async def test_start_node_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.StartNodeRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.start_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.start_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.StartNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_start_node_async_from_dict(): + await test_start_node_async(request_type=dict) + + +def test_start_node_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.StartNodeRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.start_node), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.start_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_start_node_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.StartNodeRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.start_node), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.start_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_update_node(transport: str = "grpc", request_type=cloud_tpu.UpdateNodeRequest): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.update_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.UpdateNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_update_node_from_dict(): + test_update_node(request_type=dict) + + +def test_update_node_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_node), "__call__") as call: + client.update_node() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.UpdateNodeRequest() + + +@pytest.mark.asyncio +async def test_update_node_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.UpdateNodeRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.update_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.UpdateNodeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_update_node_async_from_dict(): + await test_update_node_async(request_type=dict) + + +def test_update_node_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.UpdateNodeRequest() + + request.node.name = "node.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_node), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.update_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "node.name=node.name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_node_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.UpdateNodeRequest() + + request.node.name = "node.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_node), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.update_node(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "node.name=node.name/value",) in kw["metadata"] + + +def test_update_node_flattened(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_node( + node=cloud_tpu.Node(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].node == cloud_tpu.Node(name="name_value") + assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"]) + + +def test_update_node_flattened_error(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_node( + cloud_tpu.UpdateNodeRequest(), + node=cloud_tpu.Node(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_node_flattened_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_node), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_node( + node=cloud_tpu.Node(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].node == cloud_tpu.Node(name="name_value") + assert args[0].update_mask == field_mask_pb2.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_node_flattened_error_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_node( + cloud_tpu.UpdateNodeRequest(), + node=cloud_tpu.Node(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_generate_service_identity( + transport: str = "grpc", request_type=cloud_tpu.GenerateServiceIdentityRequest +): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_service_identity), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.GenerateServiceIdentityResponse() + response = client.generate_service_identity(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GenerateServiceIdentityRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, cloud_tpu.GenerateServiceIdentityResponse) + + +def test_generate_service_identity_from_dict(): + test_generate_service_identity(request_type=dict) + + +def test_generate_service_identity_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_service_identity), "__call__" + ) as call: + client.generate_service_identity() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GenerateServiceIdentityRequest() + + +@pytest.mark.asyncio +async def test_generate_service_identity_async( + transport: str = "grpc_asyncio", + request_type=cloud_tpu.GenerateServiceIdentityRequest, +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_service_identity), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.GenerateServiceIdentityResponse() + ) + response = await client.generate_service_identity(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GenerateServiceIdentityRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, cloud_tpu.GenerateServiceIdentityResponse) + + +@pytest.mark.asyncio +async def test_generate_service_identity_async_from_dict(): + await test_generate_service_identity_async(request_type=dict) + + +def test_generate_service_identity_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.GenerateServiceIdentityRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_service_identity), "__call__" + ) as call: + call.return_value = cloud_tpu.GenerateServiceIdentityResponse() + client.generate_service_identity(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_generate_service_identity_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.GenerateServiceIdentityRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_service_identity), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.GenerateServiceIdentityResponse() + ) + await client.generate_service_identity(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_accelerator_types( + transport: str = "grpc", request_type=cloud_tpu.ListAcceleratorTypesRequest +): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_accelerator_types), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.ListAcceleratorTypesResponse( + next_page_token="next_page_token_value", unreachable=["unreachable_value"], + ) + response = client.list_accelerator_types(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.ListAcceleratorTypesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListAcceleratorTypesPager) + assert response.next_page_token == "next_page_token_value" + assert response.unreachable == ["unreachable_value"] + + +def test_list_accelerator_types_from_dict(): + test_list_accelerator_types(request_type=dict) + + +def test_list_accelerator_types_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_accelerator_types), "__call__" + ) as call: + client.list_accelerator_types() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.ListAcceleratorTypesRequest() + + +@pytest.mark.asyncio +async def test_list_accelerator_types_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.ListAcceleratorTypesRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_accelerator_types), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.ListAcceleratorTypesResponse( + next_page_token="next_page_token_value", + unreachable=["unreachable_value"], + ) + ) + response = await client.list_accelerator_types(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.ListAcceleratorTypesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListAcceleratorTypesAsyncPager) + assert response.next_page_token == "next_page_token_value" + assert response.unreachable == ["unreachable_value"] + + +@pytest.mark.asyncio +async def test_list_accelerator_types_async_from_dict(): + await test_list_accelerator_types_async(request_type=dict) + + +def test_list_accelerator_types_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.ListAcceleratorTypesRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_accelerator_types), "__call__" + ) as call: + call.return_value = cloud_tpu.ListAcceleratorTypesResponse() + client.list_accelerator_types(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_accelerator_types_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.ListAcceleratorTypesRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_accelerator_types), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.ListAcceleratorTypesResponse() + ) + await client.list_accelerator_types(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_accelerator_types_flattened(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_accelerator_types), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.ListAcceleratorTypesResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_accelerator_types(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_list_accelerator_types_flattened_error(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_accelerator_types( + cloud_tpu.ListAcceleratorTypesRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_accelerator_types_flattened_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_accelerator_types), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.ListAcceleratorTypesResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.ListAcceleratorTypesResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_accelerator_types(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_accelerator_types_flattened_error_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_accelerator_types( + cloud_tpu.ListAcceleratorTypesRequest(), parent="parent_value", + ) + + +def test_list_accelerator_types_pager(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_accelerator_types), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[ + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + ], + next_page_token="abc", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[], next_page_token="def", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[cloud_tpu.AcceleratorType(),], next_page_token="ghi", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[ + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_accelerator_types(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, cloud_tpu.AcceleratorType) for i in results) + + +def test_list_accelerator_types_pages(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_accelerator_types), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[ + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + ], + next_page_token="abc", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[], next_page_token="def", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[cloud_tpu.AcceleratorType(),], next_page_token="ghi", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[ + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + ], + ), + RuntimeError, + ) + pages = list(client.list_accelerator_types(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_accelerator_types_async_pager(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_accelerator_types), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[ + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + ], + next_page_token="abc", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[], next_page_token="def", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[cloud_tpu.AcceleratorType(),], next_page_token="ghi", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[ + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_accelerator_types(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, cloud_tpu.AcceleratorType) for i in responses) + + +@pytest.mark.asyncio +async def test_list_accelerator_types_async_pages(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_accelerator_types), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[ + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + ], + next_page_token="abc", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[], next_page_token="def", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[cloud_tpu.AcceleratorType(),], next_page_token="ghi", + ), + cloud_tpu.ListAcceleratorTypesResponse( + accelerator_types=[ + cloud_tpu.AcceleratorType(), + cloud_tpu.AcceleratorType(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_accelerator_types(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_get_accelerator_type( + transport: str = "grpc", request_type=cloud_tpu.GetAcceleratorTypeRequest +): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_accelerator_type), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.AcceleratorType( + name="name_value", type_="type__value", + ) + response = client.get_accelerator_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetAcceleratorTypeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, cloud_tpu.AcceleratorType) + assert response.name == "name_value" + assert response.type_ == "type__value" + + +def test_get_accelerator_type_from_dict(): + test_get_accelerator_type(request_type=dict) + + +def test_get_accelerator_type_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_accelerator_type), "__call__" + ) as call: + client.get_accelerator_type() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetAcceleratorTypeRequest() + + +@pytest.mark.asyncio +async def test_get_accelerator_type_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.GetAcceleratorTypeRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_accelerator_type), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.AcceleratorType(name="name_value", type_="type__value",) + ) + response = await client.get_accelerator_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetAcceleratorTypeRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, cloud_tpu.AcceleratorType) + assert response.name == "name_value" + assert response.type_ == "type__value" + + +@pytest.mark.asyncio +async def test_get_accelerator_type_async_from_dict(): + await test_get_accelerator_type_async(request_type=dict) + + +def test_get_accelerator_type_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.GetAcceleratorTypeRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_accelerator_type), "__call__" + ) as call: + call.return_value = cloud_tpu.AcceleratorType() + client.get_accelerator_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_accelerator_type_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.GetAcceleratorTypeRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_accelerator_type), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.AcceleratorType() + ) + await client.get_accelerator_type(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_accelerator_type_flattened(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_accelerator_type), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.AcceleratorType() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_accelerator_type(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_accelerator_type_flattened_error(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_accelerator_type( + cloud_tpu.GetAcceleratorTypeRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_accelerator_type_flattened_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_accelerator_type), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.AcceleratorType() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.AcceleratorType() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_accelerator_type(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_accelerator_type_flattened_error_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_accelerator_type( + cloud_tpu.GetAcceleratorTypeRequest(), name="name_value", + ) + + +def test_list_runtime_versions( + transport: str = "grpc", request_type=cloud_tpu.ListRuntimeVersionsRequest +): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_runtime_versions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.ListRuntimeVersionsResponse( + next_page_token="next_page_token_value", unreachable=["unreachable_value"], + ) + response = client.list_runtime_versions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.ListRuntimeVersionsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListRuntimeVersionsPager) + assert response.next_page_token == "next_page_token_value" + assert response.unreachable == ["unreachable_value"] + + +def test_list_runtime_versions_from_dict(): + test_list_runtime_versions(request_type=dict) + + +def test_list_runtime_versions_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_runtime_versions), "__call__" + ) as call: + client.list_runtime_versions() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.ListRuntimeVersionsRequest() + + +@pytest.mark.asyncio +async def test_list_runtime_versions_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.ListRuntimeVersionsRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_runtime_versions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.ListRuntimeVersionsResponse( + next_page_token="next_page_token_value", + unreachable=["unreachable_value"], + ) + ) + response = await client.list_runtime_versions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.ListRuntimeVersionsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListRuntimeVersionsAsyncPager) + assert response.next_page_token == "next_page_token_value" + assert response.unreachable == ["unreachable_value"] + + +@pytest.mark.asyncio +async def test_list_runtime_versions_async_from_dict(): + await test_list_runtime_versions_async(request_type=dict) + + +def test_list_runtime_versions_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.ListRuntimeVersionsRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_runtime_versions), "__call__" + ) as call: + call.return_value = cloud_tpu.ListRuntimeVersionsResponse() + client.list_runtime_versions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_runtime_versions_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.ListRuntimeVersionsRequest() + + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_runtime_versions), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.ListRuntimeVersionsResponse() + ) + await client.list_runtime_versions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_runtime_versions_flattened(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_runtime_versions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.ListRuntimeVersionsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_runtime_versions(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_list_runtime_versions_flattened_error(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_runtime_versions( + cloud_tpu.ListRuntimeVersionsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_runtime_versions_flattened_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_runtime_versions), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.ListRuntimeVersionsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.ListRuntimeVersionsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_runtime_versions(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_runtime_versions_flattened_error_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_runtime_versions( + cloud_tpu.ListRuntimeVersionsRequest(), parent="parent_value", + ) + + +def test_list_runtime_versions_pager(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_runtime_versions), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[ + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + ], + next_page_token="abc", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[], next_page_token="def", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[cloud_tpu.RuntimeVersion(),], next_page_token="ghi", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[ + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_runtime_versions(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, cloud_tpu.RuntimeVersion) for i in results) + + +def test_list_runtime_versions_pages(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_runtime_versions), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[ + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + ], + next_page_token="abc", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[], next_page_token="def", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[cloud_tpu.RuntimeVersion(),], next_page_token="ghi", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[ + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + ], + ), + RuntimeError, + ) + pages = list(client.list_runtime_versions(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_runtime_versions_async_pager(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_runtime_versions), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[ + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + ], + next_page_token="abc", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[], next_page_token="def", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[cloud_tpu.RuntimeVersion(),], next_page_token="ghi", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[ + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_runtime_versions(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, cloud_tpu.RuntimeVersion) for i in responses) + + +@pytest.mark.asyncio +async def test_list_runtime_versions_async_pages(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_runtime_versions), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[ + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + ], + next_page_token="abc", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[], next_page_token="def", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[cloud_tpu.RuntimeVersion(),], next_page_token="ghi", + ), + cloud_tpu.ListRuntimeVersionsResponse( + runtime_versions=[ + cloud_tpu.RuntimeVersion(), + cloud_tpu.RuntimeVersion(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_runtime_versions(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_get_runtime_version( + transport: str = "grpc", request_type=cloud_tpu.GetRuntimeVersionRequest +): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_runtime_version), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.RuntimeVersion( + name="name_value", version="version_value", + ) + response = client.get_runtime_version(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetRuntimeVersionRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, cloud_tpu.RuntimeVersion) + assert response.name == "name_value" + assert response.version == "version_value" + + +def test_get_runtime_version_from_dict(): + test_get_runtime_version(request_type=dict) + + +def test_get_runtime_version_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_runtime_version), "__call__" + ) as call: + client.get_runtime_version() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetRuntimeVersionRequest() + + +@pytest.mark.asyncio +async def test_get_runtime_version_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.GetRuntimeVersionRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_runtime_version), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.RuntimeVersion(name="name_value", version="version_value",) + ) + response = await client.get_runtime_version(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetRuntimeVersionRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, cloud_tpu.RuntimeVersion) + assert response.name == "name_value" + assert response.version == "version_value" + + +@pytest.mark.asyncio +async def test_get_runtime_version_async_from_dict(): + await test_get_runtime_version_async(request_type=dict) + + +def test_get_runtime_version_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.GetRuntimeVersionRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_runtime_version), "__call__" + ) as call: + call.return_value = cloud_tpu.RuntimeVersion() + client.get_runtime_version(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_runtime_version_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.GetRuntimeVersionRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_runtime_version), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.RuntimeVersion() + ) + await client.get_runtime_version(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_runtime_version_flattened(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_runtime_version), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.RuntimeVersion() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_runtime_version(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_runtime_version_flattened_error(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_runtime_version( + cloud_tpu.GetRuntimeVersionRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_runtime_version_flattened_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_runtime_version), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.RuntimeVersion() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.RuntimeVersion() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_runtime_version(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_runtime_version_flattened_error_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_runtime_version( + cloud_tpu.GetRuntimeVersionRequest(), name="name_value", + ) + + +def test_get_guest_attributes( + transport: str = "grpc", request_type=cloud_tpu.GetGuestAttributesRequest +): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_guest_attributes), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cloud_tpu.GetGuestAttributesResponse() + response = client.get_guest_attributes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetGuestAttributesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, cloud_tpu.GetGuestAttributesResponse) + + +def test_get_guest_attributes_from_dict(): + test_get_guest_attributes(request_type=dict) + + +def test_get_guest_attributes_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_guest_attributes), "__call__" + ) as call: + client.get_guest_attributes() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetGuestAttributesRequest() + + +@pytest.mark.asyncio +async def test_get_guest_attributes_async( + transport: str = "grpc_asyncio", request_type=cloud_tpu.GetGuestAttributesRequest +): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_guest_attributes), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.GetGuestAttributesResponse() + ) + response = await client.get_guest_attributes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == cloud_tpu.GetGuestAttributesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, cloud_tpu.GetGuestAttributesResponse) + + +@pytest.mark.asyncio +async def test_get_guest_attributes_async_from_dict(): + await test_get_guest_attributes_async(request_type=dict) + + +def test_get_guest_attributes_field_headers(): + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.GetGuestAttributesRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_guest_attributes), "__call__" + ) as call: + call.return_value = cloud_tpu.GetGuestAttributesResponse() + client.get_guest_attributes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_guest_attributes_field_headers_async(): + client = TpuAsyncClient(credentials=ga_credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cloud_tpu.GetGuestAttributesRequest() + + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_guest_attributes), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cloud_tpu.GetGuestAttributesResponse() + ) + await client.get_guest_attributes(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.TpuGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.TpuGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TpuClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.TpuGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TpuClient(client_options={"scopes": ["1", "2"]}, transport=transport,) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.TpuGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = TpuClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.TpuGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.TpuGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport,], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = TpuClient(credentials=ga_credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.TpuGrpcTransport,) + + +def test_tpu_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.TpuTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_tpu_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.tpu_v2alpha1.services.tpu.transports.TpuTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.TpuTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "list_nodes", + "get_node", + "create_node", + "delete_node", + "stop_node", + "start_node", + "update_node", + "generate_service_identity", + "list_accelerator_types", + "get_accelerator_type", + "list_runtime_versions", + "get_runtime_version", + "get_guest_attributes", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +@requires_google_auth_gte_1_25_0 +def test_tpu_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.cloud.tpu_v2alpha1.services.tpu.transports.TpuTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.TpuTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@requires_google_auth_lt_1_25_0 +def test_tpu_base_transport_with_credentials_file_old_google_auth(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.cloud.tpu_v2alpha1.services.tpu.transports.TpuTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.TpuTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_tpu_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.cloud.tpu_v2alpha1.services.tpu.transports.TpuTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.TpuTransport() + adc.assert_called_once() + + +@requires_google_auth_gte_1_25_0 +def test_tpu_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + TpuClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +@requires_google_auth_lt_1_25_0 +def test_tpu_auth_adc_old_google_auth(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + TpuClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport,], +) +@requires_google_auth_gte_1_25_0 +def test_tpu_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport,], +) +@requires_google_auth_lt_1_25_0 +def test_tpu_transport_auth_adc_old_google_auth(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus") + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.TpuGrpcTransport, grpc_helpers), + (transports.TpuGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_tpu_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "tpu.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=("https://www.googleapis.com/auth/cloud-platform",), + scopes=["1", "2"], + default_host="tpu.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport] +) +def test_tpu_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_tpu_host_no_port(): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint="tpu.googleapis.com"), + ) + assert client.transport._host == "tpu.googleapis.com:443" + + +def test_tpu_host_with_port(): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="tpu.googleapis.com:8000" + ), + ) + assert client.transport._host == "tpu.googleapis.com:8000" + + +def test_tpu_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.TpuGrpcTransport(host="squid.clam.whelk", channel=channel,) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_tpu_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.TpuGrpcAsyncIOTransport( + host="squid.clam.whelk", channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport] +) +def test_tpu_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", [transports.TpuGrpcTransport, transports.TpuGrpcAsyncIOTransport] +) +def test_tpu_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_tpu_grpc_lro_client(): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_tpu_grpc_lro_async_client(): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_accelerator_type_path(): + project = "squid" + location = "clam" + accelerator_type = "whelk" + expected = "projects/{project}/locations/{location}/acceleratorTypes/{accelerator_type}".format( + project=project, location=location, accelerator_type=accelerator_type, + ) + actual = TpuClient.accelerator_type_path(project, location, accelerator_type) + assert expected == actual + + +def test_parse_accelerator_type_path(): + expected = { + "project": "octopus", + "location": "oyster", + "accelerator_type": "nudibranch", + } + path = TpuClient.accelerator_type_path(**expected) + + # Check that the path construction is reversible. + actual = TpuClient.parse_accelerator_type_path(path) + assert expected == actual + + +def test_node_path(): + project = "cuttlefish" + location = "mussel" + node = "winkle" + expected = "projects/{project}/locations/{location}/nodes/{node}".format( + project=project, location=location, node=node, + ) + actual = TpuClient.node_path(project, location, node) + assert expected == actual + + +def test_parse_node_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "node": "abalone", + } + path = TpuClient.node_path(**expected) + + # Check that the path construction is reversible. + actual = TpuClient.parse_node_path(path) + assert expected == actual + + +def test_runtime_version_path(): + project = "squid" + location = "clam" + runtime_version = "whelk" + expected = "projects/{project}/locations/{location}/runtimeVersions/{runtime_version}".format( + project=project, location=location, runtime_version=runtime_version, + ) + actual = TpuClient.runtime_version_path(project, location, runtime_version) + assert expected == actual + + +def test_parse_runtime_version_path(): + expected = { + "project": "octopus", + "location": "oyster", + "runtime_version": "nudibranch", + } + path = TpuClient.runtime_version_path(**expected) + + # Check that the path construction is reversible. + actual = TpuClient.parse_runtime_version_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "cuttlefish" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = TpuClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "mussel", + } + path = TpuClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = TpuClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "winkle" + expected = "folders/{folder}".format(folder=folder,) + actual = TpuClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + } + path = TpuClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = TpuClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "scallop" + expected = "organizations/{organization}".format(organization=organization,) + actual = TpuClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + } + path = TpuClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = TpuClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "squid" + expected = "projects/{project}".format(project=project,) + actual = TpuClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + } + path = TpuClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = TpuClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "whelk" + location = "octopus" + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = TpuClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + } + path = TpuClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = TpuClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.TpuTransport, "_prep_wrapped_messages") as prep: + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.TpuTransport, "_prep_wrapped_messages") as prep: + transport_class = TpuClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = TpuAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "grpc", + ] + for transport in transports: + client = TpuClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called()