From 4ddc426a6b4b8cd319fa885e363c94b35ef777d9 Mon Sep 17 00:00:00 2001
From: Yu-Han Liu
Date: Wed, 23 Sep 2020 15:18:26 -0700
Subject: [PATCH] feat: regenerate v1beta1 (#4)

---
 docs/aiplatform_v1beta1/services.rst | 24 +-
 docs/aiplatform_v1beta1/types.rst | 4 +-
 google/cloud/aiplatform/__init__.py | 29 +-
 google/cloud/aiplatform_v1beta1/__init__.py | 5 +-
 .../services/dataset_service/__init__.py | 6 +-
 .../services/dataset_service/async_client.py | 976 ---
 .../services/dataset_service/client.py | 433 +-
 .../services/dataset_service/pagers.py | 235 +-
 .../dataset_service/transports/__init__.py | 3 -
 .../dataset_service/transports/base.py | 163 +-
 .../dataset_service/transports/grpc.py | 153 +-
 .../transports/grpc_asyncio.py | 530 --
 .../services/endpoint_service/__init__.py | 6 +-
 .../services/endpoint_service/async_client.py | 782 ---
 .../services/endpoint_service/client.py | 371 +-
 .../services/endpoint_service/pagers.py | 80 +-
 .../endpoint_service/transports/__init__.py | 3 -
 .../endpoint_service/transports/base.py | 125 +-
 .../endpoint_service/transports/grpc.py | 153 +-
 .../transports/grpc_asyncio.py | 449 --
 .../services/job_service/__init__.py | 6 +-
 .../services/job_service/async_client.py | 1803 -----
 .../services/job_service/client.py | 772 +--
 .../services/job_service/pagers.py | 319 +-
 .../job_service/transports/__init__.py | 3 -
 .../services/job_service/transports/base.py | 284 +-
 .../services/job_service/transports/grpc.py | 153 +-
 .../job_service/transports/grpc_asyncio.py | 885 ---
 .../services/model_service/__init__.py | 6 +-
 .../services/model_service/async_client.py | 963 ---
 .../services/model_service/client.py | 431 +-
 .../services/model_service/pagers.py | 239 +-
 .../model_service/transports/__init__.py | 3 -
 .../services/model_service/transports/base.py | 168 +-
 .../services/model_service/transports/grpc.py | 153 +-
 .../model_service/transports/grpc_asyncio.py | 534 --
 .../services/pipeline_service/__init__.py | 6 +-
 .../services/pipeline_service/async_client.py | 551 --
 .../services/pipeline_service/client.py | 310 +-
 .../services/pipeline_service/pagers.py | 84 +-
 .../pipeline_service/transports/__init__.py | 3 -
 .../pipeline_service/transports/base.py | 118 +-
 .../pipeline_service/transports/grpc.py | 153 +-
 .../transports/grpc_asyncio.py | 413 --
 .../services/prediction_service/__init__.py | 6 +-
 .../prediction_service/async_client.py | 341 -
 .../services/prediction_service/client.py | 221 +-
 .../prediction_service/transports/__init__.py | 3 -
 .../prediction_service/transports/base.py | 80 +-
 .../prediction_service/transports/grpc.py | 153 +-
 .../transports/grpc_asyncio.py | 300 -
 .../specialist_pool_service/__init__.py | 6 +-
 .../specialist_pool_service/async_client.py | 592 --
 .../specialist_pool_service/client.py | 299 +-
 .../specialist_pool_service/pagers.py | 84 +-
 .../transports/__init__.py | 3 -
 .../transports/base.py | 112 +-
 .../transports/grpc.py | 153 +-
 .../transports/grpc_asyncio.py | 403 --
 .../aiplatform_v1beta1/types/__init__.py | 338 +-
 .../aiplatform_v1beta1/types/annotation.py | 7 -
 .../types/annotation_spec.py | 4 -
 .../types/batch_prediction_job.py | 45 +-
 .../types/completion_stats.py | 2 -
 .../aiplatform_v1beta1/types/custom_job.py | 26 +-
 .../aiplatform_v1beta1/types/data_item.py | 5 -
 .../types/data_labeling_job.py | 36 +-
 .../cloud/aiplatform_v1beta1/types/dataset.py | 18 +-
 .../types/dataset_service.py | 25 -
 .../types/deployed_model_ref.py | 1 -
 .../aiplatform_v1beta1/types/endpoint.py | 26 +-
 .../types/endpoint_service.py | 11 -
 .../cloud/aiplatform_v1beta1/types/env_var.py | 1 -
 .../aiplatform_v1beta1/types/explanation.py | 6 -
 .../types/explanation_metadata.py | 9 +-
 .../types/hyperparameter_tuning_job.py | 14 -
 .../aiplatform_v1beta1/types/job_service.py | 25 -
 .../types/machine_resources.py | 7 -
 .../cloud/aiplatform_v1beta1/types/model.py | 27 -
 .../types/model_evaluation.py | 5 -
 .../types/model_evaluation_slice.py | 9 +-
 .../aiplatform_v1beta1/types/model_service.py | 22 -
 .../aiplatform_v1beta1/types/operation.py | 2 -
 .../types/pipeline_service.py | 6 -
 .../types/prediction_service.py | 7 -
 .../types/specialist_pool.py | 4 -
 .../types/specialist_pool_service.py | 8 -
 .../cloud/aiplatform_v1beta1/types/study.py | 31 +-
 .../types/training_pipeline.py | 47 +-
 .../types/user_action_reference.py | 6 +-
 mypy.ini | 2 +-
 noxfile.py | 6 +-
 synth.metadata | 15 +-
 synth.py | 51 +-
 .../test_dataset_service.py | 1140 ++++
 .../test_endpoint_service.py | 771 +++
 .../aiplatform_v1beta1/test_job_service.py | 2118 ++++
 .../aiplatform_v1beta1/test_model_service.py | 1223 ++++
 .../test_pipeline_service.py | 675 ++
 .../test_prediction_service.py | 309 +
 .../test_specialist_pool_service.py | 681 ++
 .../unit/gapic/aiplatform_v1beta1/__init__.py | 1 -
 .../test_dataset_service.py | 3299 ----
 .../test_endpoint_service.py | 2447 -------
 .../aiplatform_v1beta1/test_job_service.py | 5808 -----
 .../aiplatform_v1beta1/test_model_service.py | 3415 ----
 .../test_pipeline_service.py | 2065 ------
 .../test_prediction_service.py | 1217 ----
 .../test_specialist_pool_service.py | 2121 ------
 109 files changed, 8323 insertions(+), 35467 deletions(-)
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/job_service/async_client.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/model_service/async_client.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py
 delete mode 100644 google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py
 create mode 100644 tests/unit/aiplatform_v1beta1/test_dataset_service.py
 create mode 100644 tests/unit/aiplatform_v1beta1/test_endpoint_service.py
 create mode 100644 tests/unit/aiplatform_v1beta1/test_job_service.py
 create mode 100644 tests/unit/aiplatform_v1beta1/test_model_service.py
 create mode 100644 tests/unit/aiplatform_v1beta1/test_pipeline_service.py
 create mode 100644 tests/unit/aiplatform_v1beta1/test_prediction_service.py
 create mode 100644 tests/unit/aiplatform_v1beta1/test_specialist_pool_service.py
 delete mode 100644 tests/unit/gapic/aiplatform_v1beta1/__init__.py
 delete mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py
 delete mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py
 delete mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_job_service.py
 delete mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
 delete mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
 delete mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py
 delete mode 100644 tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py

diff --git a/docs/aiplatform_v1beta1/services.rst b/docs/aiplatform_v1beta1/services.rst
index f8509ec112..a23ba7675f 100644
--- a/docs/aiplatform_v1beta1/services.rst
+++ b/docs/aiplatform_v1beta1/services.rst
@@ -1,24 +1,6 @@
-Services for Google Cloud Aiplatform v1beta1 API
-================================================
+Client for Google Cloud Aiplatform API
+======================================
 
-.. automodule:: google.cloud.aiplatform_v1beta1.services.dataset_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.endpoint_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.job_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.model_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.pipeline_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.prediction_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.specialist_pool_service
+.. automodule:: google.cloud.aiplatform_v1beta1
     :members:
     :inherited-members:
diff --git a/docs/aiplatform_v1beta1/types.rst b/docs/aiplatform_v1beta1/types.rst
index 3f8a7c9d65..df8cb24970 100644
--- a/docs/aiplatform_v1beta1/types.rst
+++ b/docs/aiplatform_v1beta1/types.rst
@@ -1,5 +1,5 @@
-Types for Google Cloud Aiplatform v1beta1 API
-=============================================
+Types for Google Cloud Aiplatform API
+=====================================
 
 .. automodule:: google.cloud.aiplatform_v1beta1.types
     :members:
diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py
index b55957d37f..0b9376cce7 100644
--- a/google/cloud/aiplatform/__init__.py
+++ b/google/cloud/aiplatform/__init__.py
@@ -15,43 +15,23 @@
 # limitations under the License.
# -from google.cloud.aiplatform_v1beta1.services.dataset_service.async_client import ( - DatasetServiceAsyncClient, -) + from google.cloud.aiplatform_v1beta1.services.dataset_service.client import ( DatasetServiceClient, ) -from google.cloud.aiplatform_v1beta1.services.endpoint_service.async_client import ( - EndpointServiceAsyncClient, -) from google.cloud.aiplatform_v1beta1.services.endpoint_service.client import ( EndpointServiceClient, ) -from google.cloud.aiplatform_v1beta1.services.job_service.async_client import ( - JobServiceAsyncClient, -) from google.cloud.aiplatform_v1beta1.services.job_service.client import JobServiceClient -from google.cloud.aiplatform_v1beta1.services.model_service.async_client import ( - ModelServiceAsyncClient, -) from google.cloud.aiplatform_v1beta1.services.model_service.client import ( ModelServiceClient, ) -from google.cloud.aiplatform_v1beta1.services.pipeline_service.async_client import ( - PipelineServiceAsyncClient, -) from google.cloud.aiplatform_v1beta1.services.pipeline_service.client import ( PipelineServiceClient, ) -from google.cloud.aiplatform_v1beta1.services.prediction_service.async_client import ( - PredictionServiceAsyncClient, -) from google.cloud.aiplatform_v1beta1.services.prediction_service.client import ( PredictionServiceClient, ) -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service.async_client import ( - SpecialistPoolServiceAsyncClient, -) from google.cloud.aiplatform_v1beta1.services.specialist_pool_service.client import ( SpecialistPoolServiceClient, ) @@ -352,7 +332,6 @@ "DataItem", "DataLabelingJob", "Dataset", - "DatasetServiceAsyncClient", "DatasetServiceClient", "DedicatedResources", "DeleteBatchPredictionJobRequest", @@ -371,7 +350,6 @@ "DeployedModel", "DeployedModelRef", "Endpoint", - "EndpointServiceAsyncClient", "EndpointServiceClient", "EnvVar", "ExplainRequest", @@ -410,7 +388,6 @@ "ImportDataRequest", "ImportDataResponse", "InputDataConfig", - "JobServiceAsyncClient", "JobServiceClient", "JobState", "ListAnnotationsRequest", @@ -447,9 +424,7 @@ "ModelEvaluation", "ModelEvaluationSlice", "ModelExplanation", - "ModelServiceAsyncClient", "ModelServiceClient", - "PipelineServiceAsyncClient", "PipelineServiceClient", "PipelineState", "Port", @@ -457,7 +432,6 @@ "PredictRequest", "PredictResponse", "PredictSchemata", - "PredictionServiceAsyncClient", "PredictionServiceClient", "PythonPackageSpec", "ResourcesConsumed", @@ -465,7 +439,6 @@ "SampledShapleyAttribution", "Scheduling", "SpecialistPool", - "SpecialistPoolServiceAsyncClient", "SpecialistPoolServiceClient", "StudySpec", "TimestampSplit", diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index f23576efca..b99a73164d 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -15,6 +15,7 @@ # limitations under the License. 
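[Editor's sketch] With every ``*AsyncClient`` import and ``__all__`` entry dropped by the hunk above, the package root now exposes only the synchronous surface. A minimal sketch of what still imports after this change (construction assumes application default credentials are available in the environment):

    from google.cloud.aiplatform import DatasetServiceClient

    # Only the sync client survives this regeneration; the former
    # DatasetServiceAsyncClient import path no longer exists.
    client = DatasetServiceClient()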
# + from .services.dataset_service import DatasetServiceClient from .services.endpoint_service import EndpointServiceClient from .services.job_service import JobServiceClient @@ -212,7 +213,6 @@ "DataItem", "DataLabelingJob", "Dataset", - "DatasetServiceClient", "DedicatedResources", "DeleteBatchPredictionJobRequest", "DeleteCustomJobRequest", @@ -319,6 +319,7 @@ "SampledShapleyAttribution", "Scheduling", "SpecialistPool", + "SpecialistPoolServiceClient", "StudySpec", "TimestampSplit", "TrainingConfig", @@ -337,5 +338,5 @@ "UploadModelResponse", "UserActionReference", "WorkerPoolSpec", - "SpecialistPoolServiceClient", + "DatasetServiceClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py index 597f654cb9..8b973db167 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py @@ -16,9 +16,5 @@ # from .client import DatasetServiceClient -from .async_client import DatasetServiceAsyncClient -__all__ = ( - "DatasetServiceClient", - "DatasetServiceAsyncClient", -) +__all__ = ("DatasetServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py deleted file mode 100644 index 4e4c72da1b..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py +++ /dev/null @@ -1,976 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from collections import OrderedDict -import functools -import re -from typing import Dict, Sequence, Tuple, Type, Union -import pkg_resources - -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation # type: ignore -from google.api_core import operation_async # type: ignore -from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers -from google.cloud.aiplatform_v1beta1.types import annotation -from google.cloud.aiplatform_v1beta1.types import annotation_spec -from google.cloud.aiplatform_v1beta1.types import data_item -from google.cloud.aiplatform_v1beta1.types import dataset -from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset -from google.cloud.aiplatform_v1beta1.types import dataset_service -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.protobuf import empty_pb2 as empty # type: ignore -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - -from .transports.base import DatasetServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import DatasetServiceGrpcAsyncIOTransport -from .client import DatasetServiceClient - - -class DatasetServiceAsyncClient: - """""" - - _client: DatasetServiceClient - - DEFAULT_ENDPOINT = DatasetServiceClient.DEFAULT_ENDPOINT - DEFAULT_MTLS_ENDPOINT = DatasetServiceClient.DEFAULT_MTLS_ENDPOINT - - dataset_path = staticmethod(DatasetServiceClient.dataset_path) - parse_dataset_path = staticmethod(DatasetServiceClient.parse_dataset_path) - - from_service_account_file = DatasetServiceClient.from_service_account_file - from_service_account_json = from_service_account_file - - get_transport_class = functools.partial( - type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient) - ) - - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, DatasetServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the dataset service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, ~.DatasetServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). 
However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - """ - - self._client = DatasetServiceClient( - credentials=credentials, - transport=transport, - client_options=client_options, - client_info=client_info, - ) - - async def create_dataset( - self, - request: dataset_service.CreateDatasetRequest = None, - *, - parent: str = None, - dataset: gca_dataset.Dataset = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Creates a Dataset. - - Args: - request (:class:`~.dataset_service.CreateDatasetRequest`): - The request object. Request message for - [DatasetService.CreateDataset][google.cloud.aiplatform.v1beta1.DatasetService.CreateDataset]. - parent (:class:`str`): - Required. The resource name of the Location to create - the Dataset in. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - dataset (:class:`~.gca_dataset.Dataset`): - Required. The Dataset to create. - This corresponds to the ``dataset`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.gca_dataset.Dataset``: A collection of - DataItems and Annotations on them. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent, dataset]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = dataset_service.CreateDatasetRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if dataset is not None: - request.dataset = dataset - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_dataset, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. 
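[Editor's sketch] The constructor docstring above spells out the endpoint selection rules: GOOGLE_API_USE_MTLS_ENDPOINT chooses between the regular and mTLS endpoints, and an explicit ``client_options.api_endpoint`` takes precedence over both. A short sketch of that override; the regional endpoint value is illustrative, not taken from this patch:

    from google.api_core.client_options import ClientOptions
    from google.cloud.aiplatform_v1beta1 import DatasetServiceClient

    # An explicit api_endpoint wins over whatever the
    # GOOGLE_API_USE_MTLS_ENDPOINT environment variable would choose.
    options = ClientOptions(api_endpoint="us-central1-aiplatform.googleapis.com")
    client = DatasetServiceClient(client_options=options)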
- response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - gca_dataset.Dataset, - metadata_type=dataset_service.CreateDatasetOperationMetadata, - ) - - # Done; return the response. - return response - - async def get_dataset( - self, - request: dataset_service.GetDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> dataset.Dataset: - r"""Gets a Dataset. - - Args: - request (:class:`~.dataset_service.GetDatasetRequest`): - The request object. Request message for - [DatasetService.GetDataset][google.cloud.aiplatform.v1beta1.DatasetService.GetDataset]. - name (:class:`str`): - Required. The name of the Dataset - resource. - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.dataset.Dataset: - A collection of DataItems and - Annotations on them. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = dataset_service.GetDatasetRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_dataset, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def update_dataset( - self, - request: dataset_service.UpdateDatasetRequest = None, - *, - dataset: gca_dataset.Dataset = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_dataset.Dataset: - r"""Updates a Dataset. - - Args: - request (:class:`~.dataset_service.UpdateDatasetRequest`): - The request object. Request message for - [DatasetService.UpdateDataset][google.cloud.aiplatform.v1beta1.DatasetService.UpdateDataset]. - dataset (:class:`~.gca_dataset.Dataset`): - Required. The Dataset which replaces - the resource on the server. - This corresponds to the ``dataset`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - update_mask (:class:`~.field_mask.FieldMask`): - Required. The update mask applies to the resource. For - the ``FieldMask`` definition, see - - [FieldMask](https: - //tinyurl.com/dev-google-protobuf#google.protobuf.FieldMask). 
- Updatable fields: - - - ``display_name`` - - ``description`` - - ``labels`` - This corresponds to the ``update_mask`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.gca_dataset.Dataset: - A collection of DataItems and - Annotations on them. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([dataset, update_mask]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = dataset_service.UpdateDatasetRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if dataset is not None: - request.dataset = dataset - if update_mask is not None: - request.update_mask = update_mask - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_dataset, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("dataset.name", request.dataset.name),) - ), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def list_datasets( - self, - request: dataset_service.ListDatasetsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDatasetsAsyncPager: - r"""Lists Datasets in a Location. - - Args: - request (:class:`~.dataset_service.ListDatasetsRequest`): - The request object. Request message for - [DatasetService.ListDatasets][google.cloud.aiplatform.v1beta1.DatasetService.ListDatasets]. - parent (:class:`str`): - Required. The name of the Dataset's parent resource. - Format: ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListDatasetsAsyncPager: - Response message for - [DatasetService.ListDatasets][google.cloud.aiplatform.v1beta1.DatasetService.ListDatasets]. - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." 
- ) - - request = dataset_service.ListDatasetsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_datasets, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListDatasetsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. - return response - - async def delete_dataset( - self, - request: dataset_service.DeleteDatasetRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Deletes a Dataset. - - Args: - request (:class:`~.dataset_service.DeleteDatasetRequest`): - The request object. Request message for - [DatasetService.DeleteDataset][google.cloud.aiplatform.v1beta1.DatasetService.DeleteDataset]. - name (:class:`str`): - Required. The resource name of the Dataset to delete. - Format: - ``projects/{project}/locations/{location}/datasets/{dataset}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.empty.Empty``: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: - - :: - - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } - - The JSON representation for ``Empty`` is empty JSON - object ``{}``. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = dataset_service.DeleteDatasetRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_dataset, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. 
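[Editor's sketch] ``list_datasets`` hands results back through a pager (``ListDatasetsAsyncPager`` in the async client being deleted here, ``ListDatasetsPager`` in the sync client), so callers never touch page tokens directly. A sketch of the surviving sync iteration pattern; the parent value is a placeholder:

    from google.cloud.aiplatform_v1beta1 import DatasetServiceClient

    client = DatasetServiceClient()  # assumes default credentials
    # The pager requests additional pages lazily as iteration proceeds.
    for ds in client.list_datasets(
        parent="projects/my-project/locations/us-central1",  # placeholder
    ):
        print(ds.display_name)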
- metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - empty.Empty, - metadata_type=gca_operation.DeleteOperationMetadata, - ) - - # Done; return the response. - return response - - async def import_data( - self, - request: dataset_service.ImportDataRequest = None, - *, - name: str = None, - import_configs: Sequence[dataset.ImportDataConfig] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Imports data into a Dataset. - - Args: - request (:class:`~.dataset_service.ImportDataRequest`): - The request object. Request message for - [DatasetService.ImportData][google.cloud.aiplatform.v1beta1.DatasetService.ImportData]. - name (:class:`str`): - Required. The name of the Dataset resource. Format: - ``projects/{project}/locations/{location}/datasets/{dataset}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - import_configs (:class:`Sequence[~.dataset.ImportDataConfig]`): - Required. The desired input - locations. The contents of all input - locations will be imported in one batch. - This corresponds to the ``import_configs`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.dataset_service.ImportDataResponse``: - Response message for - [DatasetService.ImportData][google.cloud.aiplatform.v1beta1.DatasetService.ImportData]. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name, import_configs]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = dataset_service.ImportDataRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - if import_configs is not None: - request.import_configs = import_configs - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.import_data, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. 
- response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - dataset_service.ImportDataResponse, - metadata_type=dataset_service.ImportDataOperationMetadata, - ) - - # Done; return the response. - return response - - async def export_data( - self, - request: dataset_service.ExportDataRequest = None, - *, - name: str = None, - export_config: dataset.ExportDataConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Exports data from a Dataset. - - Args: - request (:class:`~.dataset_service.ExportDataRequest`): - The request object. Request message for - [DatasetService.ExportData][google.cloud.aiplatform.v1beta1.DatasetService.ExportData]. - name (:class:`str`): - Required. The name of the Dataset resource. Format: - ``projects/{project}/locations/{location}/datasets/{dataset}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - export_config (:class:`~.dataset.ExportDataConfig`): - Required. The desired output - location. - This corresponds to the ``export_config`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.dataset_service.ExportDataResponse``: - Response message for - [DatasetService.ExportData][google.cloud.aiplatform.v1beta1.DatasetService.ExportData]. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name, export_config]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = dataset_service.ExportDataRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - if export_config is not None: - request.export_config = export_config - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.export_data, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - dataset_service.ExportDataResponse, - metadata_type=dataset_service.ExportDataOperationMetadata, - ) - - # Done; return the response. 
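[Editor's sketch] ``import_data`` and ``export_data`` both return long-running operations rather than finished results: the deleted async client wrapped them in an ``AsyncOperation`` to be awaited, while the surviving sync client yields a ``google.api_core.operation.Operation``. A sketch of blocking on the sync variant; the resource name and the empty config are placeholders:

    from google.cloud.aiplatform_v1beta1 import DatasetServiceClient
    from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset

    client = DatasetServiceClient()  # assumes default credentials
    operation = client.import_data(
        name="projects/my-project/locations/us-central1/datasets/123",  # placeholder
        import_configs=[gca_dataset.ImportDataConfig()],  # placeholder config
    )
    response = operation.result()  # blocks until the ImportDataResponse arrives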
- return response - - async def list_data_items( - self, - request: dataset_service.ListDataItemsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataItemsAsyncPager: - r"""Lists DataItems in a Dataset. - - Args: - request (:class:`~.dataset_service.ListDataItemsRequest`): - The request object. Request message for - [DatasetService.ListDataItems][google.cloud.aiplatform.v1beta1.DatasetService.ListDataItems]. - parent (:class:`str`): - Required. The resource name of the Dataset to list - DataItems from. Format: - ``projects/{project}/locations/{location}/datasets/{dataset}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListDataItemsAsyncPager: - Response message for - [DatasetService.ListDataItems][google.cloud.aiplatform.v1beta1.DatasetService.ListDataItems]. - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = dataset_service.ListDataItemsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_data_items, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListDataItemsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. - return response - - async def get_annotation_spec( - self, - request: dataset_service.GetAnnotationSpecRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> annotation_spec.AnnotationSpec: - r"""Gets an AnnotationSpec. - - Args: - request (:class:`~.dataset_service.GetAnnotationSpecRequest`): - The request object. Request message for - [DatasetService.GetAnnotationSpec][google.cloud.aiplatform.v1beta1.DatasetService.GetAnnotationSpec]. - name (:class:`str`): - Required. The name of the AnnotationSpec resource. 
- Format: - - ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.annotation_spec.AnnotationSpec: - Identifies a concept with which - DataItems may be annotated with. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = dataset_service.GetAnnotationSpecRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_annotation_spec, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def list_annotations( - self, - request: dataset_service.ListAnnotationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListAnnotationsAsyncPager: - r"""Lists Annotations belongs to a dataitem - - Args: - request (:class:`~.dataset_service.ListAnnotationsRequest`): - The request object. Request message for - [DatasetService.ListAnnotations][google.cloud.aiplatform.v1beta1.DatasetService.ListAnnotations]. - parent (:class:`str`): - Required. The resource name of the DataItem to list - Annotations from. Format: - - ``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListAnnotationsAsyncPager: - Response message for - [DatasetService.ListAnnotations][google.cloud.aiplatform.v1beta1.DatasetService.ListAnnotations]. - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." 
- ) - - request = dataset_service.ListAnnotationsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_annotations, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListAnnotationsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. - return response - - -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -__all__ = ("DatasetServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index fa6ce59d50..46b78e540d 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -16,24 +16,17 @@ # from collections import OrderedDict -from distutils import util -import os -import re -from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -from google.api_core import client_options as client_options_lib # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore -from google.api_core import operation_async # type: ignore +from google.api_core import operation as ga_operation from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers from google.cloud.aiplatform_v1beta1.types import annotation from google.cloud.aiplatform_v1beta1.types import annotation_spec @@ -47,9 +40,8 @@ from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import DatasetServiceTransport, DEFAULT_CLIENT_INFO +from .transports.base import DatasetServiceTransport from .transports.grpc import DatasetServiceGrpcTransport -from .transports.grpc_asyncio import DatasetServiceGrpcAsyncIOTransport class DatasetServiceClientMeta(type): @@ -64,7 +56,6 @@ class DatasetServiceClientMeta(type): OrderedDict() ) # type: Dict[str, Type[DatasetServiceTransport]] 
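[Editor's sketch] The client.py hunk that begins above also strips ``DatasetServiceGrpcAsyncIOTransport`` out of the metaclass registry, leaving "grpc" as the only label ``get_transport_class`` can resolve. A sketch of the lookup, assuming the no-argument call falls back to the sole registered transport, as the docstring's "chosen automatically" wording suggests:

    from google.cloud.aiplatform_v1beta1 import DatasetServiceClient

    # Both calls now resolve to DatasetServiceGrpcTransport, since the
    # grpc_asyncio entry has been removed from _transport_registry.
    default_cls = DatasetServiceClient.get_transport_class()
    grpc_cls = DatasetServiceClient.get_transport_class("grpc")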
_transport_registry["grpc"] = DatasetServiceGrpcTransport - _transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]: """Return an appropriate transport class. @@ -88,38 +79,8 @@ def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport class DatasetServiceClient(metaclass=DatasetServiceClientMeta): """""" - @staticmethod - def _get_default_mtls_endpoint(api_endpoint): - """Convert api endpoint to mTLS endpoint. - Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to - "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. - Args: - api_endpoint (Optional[str]): the api endpoint to convert. - Returns: - str: converted mTLS api endpoint. - """ - if not api_endpoint: - return api_endpoint - - mtls_endpoint_re = re.compile( - r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" - ) - - m = mtls_endpoint_re.match(api_endpoint) - name, mtls, sandbox, googledomain = m.groups() - if mtls or not googledomain: - return api_endpoint - - if sandbox: - return api_endpoint.replace( - "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" - ) - - return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" - DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore - DEFAULT_ENDPOINT + DEFAULT_OPTIONS = ClientOptions.ClientOptions( + api_endpoint="aiplatform.googleapis.com" ) @classmethod @@ -149,22 +110,12 @@ def dataset_path(project: str, location: str, dataset: str,) -> str: project=project, location=location, dataset=dataset, ) - @staticmethod - def parse_dataset_path(path: str) -> Dict[str, str]: - """Parse a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) - return m.groupdict() if m else {} - def __init__( self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, DatasetServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + credentials: credentials.Credentials = None, + transport: Union[str, DatasetServiceTransport] = None, + client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS, ) -> None: """Instantiate the dataset service client. @@ -177,102 +128,26 @@ def __init__( transport (Union[str, ~.DatasetServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. 
If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. + client_options (ClientOptions): Custom options for the client. """ if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - - # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) - - ssl_credentials = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - is_mtls = True - else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" - ) + client_options = ClientOptions.from_dict(client_options) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, DatasetServiceTransport): - # transport is a DatasetServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) - if client_options.scopes: - raise ValueError( - "When providing a transport instance, " - "provide its scopes directly." - ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, - quota_project_id=client_options.quota_project_id, - client_info=client_info, + host=client_options.api_endpoint or "aiplatform.googleapis.com", ) def create_dataset( @@ -322,36 +197,28 @@ def create_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent, dataset]) - if request is not None and has_flattened_params: + if request is not None and any([parent, dataset]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.CreateDatasetRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, dataset_service.CreateDatasetRequest): - request = dataset_service.CreateDatasetRequest(request) + request = dataset_service.CreateDatasetRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if dataset is not None: - request.dataset = dataset + if parent is not None: + request.parent = parent + if dataset is not None: + request.dataset = dataset # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_dataset] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + rpc = gapic_v1.method.wrap_method( + self._transport.create_dataset, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -405,29 +272,25 @@ def get_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.GetDatasetRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, dataset_service.GetDatasetRequest): - request = dataset_service.GetDatasetRequest(request) + request = dataset_service.GetDatasetRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_dataset] + rpc = gapic_v1.method.wrap_method( + self._transport.get_dataset, default_timeout=None, client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -493,38 +356,28 @@ def update_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([dataset, update_mask]) - if request is not None and has_flattened_params: + if request is not None and any([dataset, update_mask]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.UpdateDatasetRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
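[Editor's sketch] Before and after the regeneration, these method bodies enforce the same contract: callers pass either a fully-formed request object or the flattened convenience fields, never both. A sketch of the two calling styles; the resource name is a placeholder:

    from google.cloud.aiplatform_v1beta1 import DatasetServiceClient
    from google.cloud.aiplatform_v1beta1.types import dataset_service

    client = DatasetServiceClient()  # assumes default credentials
    name = "projects/my-project/locations/us-central1/datasets/123"  # placeholder

    # Style 1: an explicit request object.
    ds = client.get_dataset(request=dataset_service.GetDatasetRequest(name=name))

    # Style 2: flattened fields; supplying both styles raises ValueError.
    ds = client.get_dataset(name=name)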
- if not isinstance(request, dataset_service.UpdateDatasetRequest): - request = dataset_service.UpdateDatasetRequest(request) + request = dataset_service.UpdateDatasetRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if dataset is not None: - request.dataset = dataset - if update_mask is not None: - request.update_mask = update_mask + if dataset is not None: + request.dataset = dataset + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.update_dataset] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("dataset.name", request.dataset.name),) - ), + rpc = gapic_v1.method.wrap_method( + self._transport.update_dataset, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -573,29 +426,27 @@ def list_datasets( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListDatasetsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, dataset_service.ListDatasetsRequest): - request = dataset_service.ListDatasetsRequest(request) + request = dataset_service.ListDatasetsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_datasets] + rpc = gapic_v1.method.wrap_method( + self._transport.list_datasets, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -609,7 +460,7 @@ def list_datasets( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDatasetsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. @@ -668,34 +519,26 @@ def delete_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.DeleteDatasetRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, dataset_service.DeleteDatasetRequest): - request = dataset_service.DeleteDatasetRequest(request) + request = dataset_service.DeleteDatasetRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_dataset] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.delete_dataset, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -761,36 +604,26 @@ def import_data( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name, import_configs]) - if request is not None and has_flattened_params: + if request is not None and any([name, import_configs]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ImportDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, dataset_service.ImportDataRequest): - request = dataset_service.ImportDataRequest(request) + request = dataset_service.ImportDataRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name - if import_configs is not None: - request.import_configs = import_configs + if name is not None: + request.name = name + if import_configs is not None: + request.import_configs = import_configs # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.import_data] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.import_data, default_timeout=None, client_info=_client_info, ) # Send the request. @@ -855,36 +688,26 @@ def export_data( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name, export_config]) - if request is not None and has_flattened_params: + if request is not None and any([name, export_config]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ExportDataRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
- if not isinstance(request, dataset_service.ExportDataRequest): - request = dataset_service.ExportDataRequest(request) + request = dataset_service.ExportDataRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name - if export_config is not None: - request.export_config = export_config + if name is not None: + request.name = name + if export_config is not None: + request.export_config = export_config # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.export_data] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.export_data, default_timeout=None, client_info=_client_info, ) # Send the request. @@ -942,29 +765,27 @@ def list_data_items( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListDataItemsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, dataset_service.ListDataItemsRequest): - request = dataset_service.ListDataItemsRequest(request) + request = dataset_service.ListDataItemsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_data_items] + rpc = gapic_v1.method.wrap_method( + self._transport.list_data_items, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -978,7 +799,7 @@ def list_data_items( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataItemsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. @@ -1023,29 +844,27 @@ def get_annotation_spec( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.GetAnnotationSpecRequest. 
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, dataset_service.GetAnnotationSpecRequest): - request = dataset_service.GetAnnotationSpecRequest(request) + request = dataset_service.GetAnnotationSpecRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_annotation_spec] + rpc = gapic_v1.method.wrap_method( + self._transport.get_annotation_spec, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -1101,29 +920,27 @@ def list_annotations( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a dataset_service.ListAnnotationsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, dataset_service.ListAnnotationsRequest): - request = dataset_service.ListAnnotationsRequest(request) + request = dataset_service.ListAnnotationsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_annotations] + rpc = gapic_v1.method.wrap_method( + self._transport.list_annotations, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -1137,7 +954,7 @@ def list_annotations( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListAnnotationsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. 
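Every method hunk above reduces to the same four steps: reject mixed usage of a request object and flattened keyword arguments, coerce the request into its protobuf type, copy any flattened fields onto it, and wrap the transport method with gapic_v1.method.wrap_method before sending. The caller-facing surface is unchanged. A minimal sketch of exercising the regenerated client (the project, location, and schema URI below are placeholders, not values taken from this patch):

    from google.cloud import aiplatform_v1beta1

    client = aiplatform_v1beta1.DatasetServiceClient()
    parent = "projects/my-project/locations/us-central1"

    # create_dataset returns a long-running operation; result() blocks
    # until the Dataset resource exists.
    lro = client.create_dataset(
        parent=parent,
        dataset=aiplatform_v1beta1.Dataset(
            display_name="my-dataset",
            metadata_schema_uri="gs://example-bucket/schema.yaml",  # placeholder
        ),
    )
    created = lro.result()

    # list_datasets returns a ListDatasetsPager; iterating it fetches
    # follow-up pages transparently (see the pagers.py hunks below).
    for ds in client.list_datasets(parent=parent):
        print(ds.display_name)

Passing both a request object and flattened keyword arguments still raises ValueError, per the sanity check retained in each method.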
@@ -1145,13 +962,13 @@ def list_annotations( try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + _client_info = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + _client_info = gapic_v1.client_info.ClientInfo() __all__ = ("DatasetServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py index 43c3156caf..0dd2e668cc 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import Any, Callable, Iterable from google.cloud.aiplatform_v1beta1.types import annotation from google.cloud.aiplatform_v1beta1.types import data_item @@ -43,11 +43,11 @@ class ListDatasetsPager: def __init__( self, - method: Callable[..., dataset_service.ListDatasetsResponse], + method: Callable[ + [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse + ], request: dataset_service.ListDatasetsRequest, response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -58,13 +58,10 @@ def __init__( The initial request object. response (:class:`~.dataset_service.ListDatasetsResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. """ self._method = method self._request = dataset_service.ListDatasetsRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -74,7 +71,7 @@ def pages(self) -> Iterable[dataset_service.ListDatasetsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[dataset.Dataset]: @@ -85,72 +82,6 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) -class ListDatasetsAsyncPager: - """A pager for iterating through ``list_datasets`` requests. - - This class thinly wraps an initial - :class:`~.dataset_service.ListDatasetsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``datasets`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListDatasets`` requests and continue to iterate - through the ``datasets`` field on the - corresponding responses. - - All the usual :class:`~.dataset_service.ListDatasetsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - - def __init__( - self, - method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], - request: dataset_service.ListDatasetsRequest, - response: dataset_service.ListDatasetsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. 
- request (:class:`~.dataset_service.ListDatasetsRequest`): - The initial request object. - response (:class:`~.dataset_service.ListDatasetsResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - self._method = method - self._request = dataset_service.ListDatasetsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages(self) -> AsyncIterable[dataset_service.ListDatasetsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[dataset.Dataset]: - async def async_generator(): - async for page in self.pages: - for response in page.datasets: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - class ListDataItemsPager: """A pager for iterating through ``list_data_items`` requests. @@ -171,11 +102,12 @@ class ListDataItemsPager: def __init__( self, - method: Callable[..., dataset_service.ListDataItemsResponse], + method: Callable[ + [dataset_service.ListDataItemsRequest], + dataset_service.ListDataItemsResponse, + ], request: dataset_service.ListDataItemsRequest, response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -186,13 +118,10 @@ def __init__( The initial request object. response (:class:`~.dataset_service.ListDataItemsResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. """ self._method = method self._request = dataset_service.ListDataItemsRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -202,7 +131,7 @@ def pages(self) -> Iterable[dataset_service.ListDataItemsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[data_item.DataItem]: @@ -213,72 +142,6 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) -class ListDataItemsAsyncPager: - """A pager for iterating through ``list_data_items`` requests. - - This class thinly wraps an initial - :class:`~.dataset_service.ListDataItemsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``data_items`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListDataItems`` requests and continue to iterate - through the ``data_items`` field on the - corresponding responses. - - All the usual :class:`~.dataset_service.ListDataItemsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. 
- """ - - def __init__( - self, - method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], - request: dataset_service.ListDataItemsRequest, - response: dataset_service.ListDataItemsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.dataset_service.ListDataItemsRequest`): - The initial request object. - response (:class:`~.dataset_service.ListDataItemsResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - self._method = method - self._request = dataset_service.ListDataItemsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages(self) -> AsyncIterable[dataset_service.ListDataItemsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[data_item.DataItem]: - async def async_generator(): - async for page in self.pages: - for response in page.data_items: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - class ListAnnotationsPager: """A pager for iterating through ``list_annotations`` requests. @@ -299,11 +162,12 @@ class ListAnnotationsPager: def __init__( self, - method: Callable[..., dataset_service.ListAnnotationsResponse], + method: Callable[ + [dataset_service.ListAnnotationsRequest], + dataset_service.ListAnnotationsResponse, + ], request: dataset_service.ListAnnotationsRequest, response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -314,13 +178,10 @@ def __init__( The initial request object. response (:class:`~.dataset_service.ListAnnotationsResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. """ self._method = method self._request = dataset_service.ListAnnotationsRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -330,7 +191,7 @@ def pages(self) -> Iterable[dataset_service.ListAnnotationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[annotation.Annotation]: @@ -339,69 +200,3 @@ def __iter__(self) -> Iterable[annotation.Annotation]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - -class ListAnnotationsAsyncPager: - """A pager for iterating through ``list_annotations`` requests. - - This class thinly wraps an initial - :class:`~.dataset_service.ListAnnotationsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``annotations`` field. 
- - If there are more pages, the ``__aiter__`` method will make additional - ``ListAnnotations`` requests and continue to iterate - through the ``annotations`` field on the - corresponding responses. - - All the usual :class:`~.dataset_service.ListAnnotationsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - - def __init__( - self, - method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], - request: dataset_service.ListAnnotationsRequest, - response: dataset_service.ListAnnotationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.dataset_service.ListAnnotationsRequest`): - The initial request object. - response (:class:`~.dataset_service.ListAnnotationsResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - self._method = method - self._request = dataset_service.ListAnnotationsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages(self) -> AsyncIterable[dataset_service.ListAnnotationsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[annotation.Annotation]: - async def async_generator(): - async for page in self.pages: - for response in page.annotations: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py index f8496b801c..7f1cb8ca21 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py @@ -20,17 +20,14 @@ from .base import DatasetServiceTransport from .grpc import DatasetServiceGrpcTransport -from .grpc_asyncio import DatasetServiceGrpcAsyncIOTransport # Compile a registry of transports. 
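With the AsyncIO transport module deleted, the registry below carries a single "grpc" entry; it is the lookup table behind the get_transport_class call seen in the client __init__ hunk above. A short sketch of the resolution, assuming nothing beyond what the registry itself implies:

    from google.cloud.aiplatform_v1beta1 import DatasetServiceClient
    from google.cloud.aiplatform_v1beta1.services.dataset_service.transports import (
        DatasetServiceGrpcTransport,
    )

    # "grpc" is now the only registered label.
    transport_cls = DatasetServiceClient.get_transport_class("grpc")
    assert transport_cls is DatasetServiceGrpcTransport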
_transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] _transport_registry["grpc"] = DatasetServiceGrpcTransport -_transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport __all__ = ( "DatasetServiceTransport", "DatasetServiceGrpcTransport", - "DatasetServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py index 56f567959a..f00538959f 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py @@ -17,12 +17,8 @@ import abc import typing -import pkg_resources -from google import auth # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore +from google import auth from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -33,17 +29,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -class DatasetServiceTransport(abc.ABC): +class DatasetServiceTransport(metaclass=abc.ABCMeta): """Abstract transport class for DatasetService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -53,11 +39,6 @@ def __init__( *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, ) -> None: """Instantiate the transport. @@ -68,17 +49,6 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scope (Optional[Sequence[str]]): A list of scopes. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -87,168 +57,85 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. 
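The hunk below is all that remains of the credential plumbing in the base transport: with credentials_file, scopes, and quota_project_id removed, the transport either receives explicit credentials or falls back to application default credentials under the class-level AUTH_SCOPES. The equivalent standalone logic, shown for reference under that assumption:

    from google import auth

    # Mirrors the retained fallback: ADC resolved with the cloud-platform scope.
    credentials, _ = auth.default(
        scopes=("https://www.googleapis.com/auth/cloud-platform",)
    )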
- if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) - - if credentials_file is not None: - credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) - - elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - - def _prep_wrapped_messages(self, client_info): - # Precompute the wrapped methods. - self._wrapped_methods = { - self.create_dataset: gapic_v1.method.wrap_method( - self.create_dataset, default_timeout=5.0, client_info=client_info, - ), - self.get_dataset: gapic_v1.method.wrap_method( - self.get_dataset, default_timeout=5.0, client_info=client_info, - ), - self.update_dataset: gapic_v1.method.wrap_method( - self.update_dataset, default_timeout=5.0, client_info=client_info, - ), - self.list_datasets: gapic_v1.method.wrap_method( - self.list_datasets, default_timeout=5.0, client_info=client_info, - ), - self.delete_dataset: gapic_v1.method.wrap_method( - self.delete_dataset, default_timeout=5.0, client_info=client_info, - ), - self.import_data: gapic_v1.method.wrap_method( - self.import_data, default_timeout=5.0, client_info=client_info, - ), - self.export_data: gapic_v1.method.wrap_method( - self.export_data, default_timeout=5.0, client_info=client_info, - ), - self.list_data_items: gapic_v1.method.wrap_method( - self.list_data_items, default_timeout=5.0, client_info=client_info, - ), - self.get_annotation_spec: gapic_v1.method.wrap_method( - self.get_annotation_spec, default_timeout=5.0, client_info=client_info, - ), - self.list_annotations: gapic_v1.method.wrap_method( - self.list_annotations, default_timeout=5.0, client_info=client_info, - ), - } - @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError() + raise NotImplementedError @property def create_dataset( self, - ) -> typing.Callable[ - [dataset_service.CreateDatasetRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[dataset_service.CreateDatasetRequest], operations.Operation]: + raise NotImplementedError @property def get_dataset( self, - ) -> typing.Callable[ - [dataset_service.GetDatasetRequest], - typing.Union[dataset.Dataset, typing.Awaitable[dataset.Dataset]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[dataset_service.GetDatasetRequest], dataset.Dataset]: + raise NotImplementedError @property def update_dataset( self, - ) -> typing.Callable[ - [dataset_service.UpdateDatasetRequest], - typing.Union[gca_dataset.Dataset, typing.Awaitable[gca_dataset.Dataset]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[dataset_service.UpdateDatasetRequest], gca_dataset.Dataset]: + raise NotImplementedError @property def list_datasets( self, ) -> typing.Callable[ - [dataset_service.ListDatasetsRequest], - typing.Union[ - dataset_service.ListDatasetsResponse, - typing.Awaitable[dataset_service.ListDatasetsResponse], - ], + [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse ]: - raise NotImplementedError() + 
raise NotImplementedError @property def delete_dataset( self, - ) -> typing.Callable[ - [dataset_service.DeleteDatasetRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[dataset_service.DeleteDatasetRequest], operations.Operation]: + raise NotImplementedError @property def import_data( self, - ) -> typing.Callable[ - [dataset_service.ImportDataRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[dataset_service.ImportDataRequest], operations.Operation]: + raise NotImplementedError @property def export_data( self, - ) -> typing.Callable[ - [dataset_service.ExportDataRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[dataset_service.ExportDataRequest], operations.Operation]: + raise NotImplementedError @property def list_data_items( self, ) -> typing.Callable[ - [dataset_service.ListDataItemsRequest], - typing.Union[ - dataset_service.ListDataItemsResponse, - typing.Awaitable[dataset_service.ListDataItemsResponse], - ], + [dataset_service.ListDataItemsRequest], dataset_service.ListDataItemsResponse ]: - raise NotImplementedError() + raise NotImplementedError @property def get_annotation_spec( self, ) -> typing.Callable[ - [dataset_service.GetAnnotationSpecRequest], - typing.Union[ - annotation_spec.AnnotationSpec, - typing.Awaitable[annotation_spec.AnnotationSpec], - ], + [dataset_service.GetAnnotationSpecRequest], annotation_spec.AnnotationSpec ]: - raise NotImplementedError() + raise NotImplementedError @property def list_annotations( self, ) -> typing.Callable[ [dataset_service.ListAnnotationsRequest], - typing.Union[ - dataset_service.ListAnnotationsResponse, - typing.Awaitable[dataset_service.ListAnnotationsResponse], - ], + dataset_service.ListAnnotationsResponse, ]: - raise NotImplementedError() + raise NotImplementedError __all__ = ("DatasetServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index 779b062b57..8110a97c2c 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -15,15 +15,11 @@ # limitations under the License. # -import warnings -from typing import Callable, Dict, Optional, Sequence, Tuple +from typing import Callable, Dict from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -33,7 +29,7 @@ from google.cloud.aiplatform_v1beta1.types import dataset_service from google.longrunning import operations_pb2 as operations # type: ignore -from .base import DatasetServiceTransport, DEFAULT_CLIENT_INFO +from .base import DatasetServiceTransport class DatasetServiceGrpcTransport(DatasetServiceTransport): @@ -47,21 +43,12 @@ class DatasetServiceGrpcTransport(DatasetServiceTransport): top of HTTP/2); the ``grpcio`` package must be installed. 
""" - _stubs: Dict[str, Callable] - def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + channel: grpc.Channel = None ) -> None: """Instantiate the transport. @@ -73,119 +60,28 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. """ + # Sanity check: Ensure that channel and credentials are not both + # provided. if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. credentials = False - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. 
- self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - + # Run the base constructor. + super().__init__(host=host, credentials=credentials) self._stubs = {} # type: Dict[str, Callable] - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # If a channel was explicitly provided, set it. + if channel: + self._grpc_channel = channel @classmethod def create_channel( cls, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, + **kwargs ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -195,31 +91,13 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. - - Raises: - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. """ - scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, + host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs ) @property @@ -229,6 +107,13 @@ def grpc_channel(self) -> grpc.Channel: This property caches on the instance; repeated calls return the same channel. """ + # Sanity check: Only create a new channel if we do not already + # have one. + if not hasattr(self, "_grpc_channel"): + self._grpc_channel = self.create_channel( + self._host, credentials=self._credentials, + ) + # Return the channel from cache. 
return self._grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py deleted file mode 100644 index c0067cb997..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py +++ /dev/null @@ -1,530 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple - -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.cloud.aiplatform_v1beta1.types import annotation_spec -from google.cloud.aiplatform_v1beta1.types import dataset -from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset -from google.cloud.aiplatform_v1beta1.types import dataset_service -from google.longrunning import operations_pb2 as operations # type: ignore - -from .base import DatasetServiceTransport, DEFAULT_CLIENT_INFO -from .grpc import DatasetServiceGrpcTransport - - -class DatasetServiceGrpcAsyncIOTransport(DatasetServiceTransport): - """gRPC AsyncIO backend transport for DatasetService. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - """ - - _grpc_channel: aio.Channel - _stubs: Dict[str, Callable] = {} - - @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: - """Create and return a gRPC AsyncIO channel object. - Args: - address (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. 
- quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - aio.Channel: A gRPC AsyncIO channel object. - """ - scopes = scopes or cls.AUTH_SCOPES - return grpc_helpers_async.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, - ) - - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. - credentials = False - - # If a channel was explicitly provided, set it. 
- self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - - @property - def grpc_channel(self) -> aio.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. - """ - # Return the channel from cache. - return self._grpc_channel - - @property - def operations_client(self) -> operations_v1.OperationsAsyncClient: - """Create the client designed to process long-running operations. - - This property caches on the instance; repeated calls return the same - client. - """ - # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( - self.grpc_channel - ) - - # Return the client from cache. - return self.__dict__["operations_client"] - - @property - def create_dataset( - self, - ) -> Callable[ - [dataset_service.CreateDatasetRequest], Awaitable[operations.Operation] - ]: - r"""Return a callable for the create dataset method over gRPC. - - Creates a Dataset. - - Returns: - Callable[[~.CreateDatasetRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "create_dataset" not in self._stubs: - self._stubs["create_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset", - request_serializer=dataset_service.CreateDatasetRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["create_dataset"] - - @property - def get_dataset( - self, - ) -> Callable[[dataset_service.GetDatasetRequest], Awaitable[dataset.Dataset]]: - r"""Return a callable for the get dataset method over gRPC. - - Gets a Dataset. - - Returns: - Callable[[~.GetDatasetRequest], - Awaitable[~.Dataset]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_dataset" not in self._stubs: - self._stubs["get_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset", - request_serializer=dataset_service.GetDatasetRequest.serialize, - response_deserializer=dataset.Dataset.deserialize, - ) - return self._stubs["get_dataset"] - - @property - def update_dataset( - self, - ) -> Callable[ - [dataset_service.UpdateDatasetRequest], Awaitable[gca_dataset.Dataset] - ]: - r"""Return a callable for the update dataset method over gRPC. - - Updates a Dataset. - - Returns: - Callable[[~.UpdateDatasetRequest], - Awaitable[~.Dataset]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "update_dataset" not in self._stubs: - self._stubs["update_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset", - request_serializer=dataset_service.UpdateDatasetRequest.serialize, - response_deserializer=gca_dataset.Dataset.deserialize, - ) - return self._stubs["update_dataset"] - - @property - def list_datasets( - self, - ) -> Callable[ - [dataset_service.ListDatasetsRequest], - Awaitable[dataset_service.ListDatasetsResponse], - ]: - r"""Return a callable for the list datasets method over gRPC. - - Lists Datasets in a Location. - - Returns: - Callable[[~.ListDatasetsRequest], - Awaitable[~.ListDatasetsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "list_datasets" not in self._stubs: - self._stubs["list_datasets"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets", - request_serializer=dataset_service.ListDatasetsRequest.serialize, - response_deserializer=dataset_service.ListDatasetsResponse.deserialize, - ) - return self._stubs["list_datasets"] - - @property - def delete_dataset( - self, - ) -> Callable[ - [dataset_service.DeleteDatasetRequest], Awaitable[operations.Operation] - ]: - r"""Return a callable for the delete dataset method over gRPC. - - Deletes a Dataset. - - Returns: - Callable[[~.DeleteDatasetRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. 
- """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "delete_dataset" not in self._stubs: - self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset", - request_serializer=dataset_service.DeleteDatasetRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["delete_dataset"] - - @property - def import_data( - self, - ) -> Callable[[dataset_service.ImportDataRequest], Awaitable[operations.Operation]]: - r"""Return a callable for the import data method over gRPC. - - Imports data into a Dataset. - - Returns: - Callable[[~.ImportDataRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "import_data" not in self._stubs: - self._stubs["import_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ImportData", - request_serializer=dataset_service.ImportDataRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["import_data"] - - @property - def export_data( - self, - ) -> Callable[[dataset_service.ExportDataRequest], Awaitable[operations.Operation]]: - r"""Return a callable for the export data method over gRPC. - - Exports data from a Dataset. - - Returns: - Callable[[~.ExportDataRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "export_data" not in self._stubs: - self._stubs["export_data"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ExportData", - request_serializer=dataset_service.ExportDataRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["export_data"] - - @property - def list_data_items( - self, - ) -> Callable[ - [dataset_service.ListDataItemsRequest], - Awaitable[dataset_service.ListDataItemsResponse], - ]: - r"""Return a callable for the list data items method over gRPC. - - Lists DataItems in a Dataset. - - Returns: - Callable[[~.ListDataItemsRequest], - Awaitable[~.ListDataItemsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "list_data_items" not in self._stubs: - self._stubs["list_data_items"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems", - request_serializer=dataset_service.ListDataItemsRequest.serialize, - response_deserializer=dataset_service.ListDataItemsResponse.deserialize, - ) - return self._stubs["list_data_items"] - - @property - def get_annotation_spec( - self, - ) -> Callable[ - [dataset_service.GetAnnotationSpecRequest], - Awaitable[annotation_spec.AnnotationSpec], - ]: - r"""Return a callable for the get annotation spec method over gRPC. - - Gets an AnnotationSpec. - - Returns: - Callable[[~.GetAnnotationSpecRequest], - Awaitable[~.AnnotationSpec]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_annotation_spec" not in self._stubs: - self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec", - request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, - response_deserializer=annotation_spec.AnnotationSpec.deserialize, - ) - return self._stubs["get_annotation_spec"] - - @property - def list_annotations( - self, - ) -> Callable[ - [dataset_service.ListAnnotationsRequest], - Awaitable[dataset_service.ListAnnotationsResponse], - ]: - r"""Return a callable for the list annotations method over gRPC. - - Lists Annotations belongs to a dataitem - - Returns: - Callable[[~.ListAnnotationsRequest], - Awaitable[~.ListAnnotationsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "list_annotations" not in self._stubs: - self._stubs["list_annotations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations", - request_serializer=dataset_service.ListAnnotationsRequest.serialize, - response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, - ) - return self._stubs["list_annotations"] - - -__all__ = ("DatasetServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py index 035a5b2388..af0b93f5a8 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py @@ -16,9 +16,5 @@ # from .client import EndpointServiceClient -from .async_client import EndpointServiceAsyncClient -__all__ = ( - "EndpointServiceClient", - "EndpointServiceAsyncClient", -) +__all__ = ("EndpointServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py deleted file mode 100644 index 5da5172bf1..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ /dev/null @@ -1,782 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from collections import OrderedDict -import functools -import re -from typing import Dict, Sequence, Tuple, Type, Union -import pkg_resources - -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation # type: ignore -from google.api_core import operation_async # type: ignore -from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers -from google.cloud.aiplatform_v1beta1.types import endpoint -from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint -from google.cloud.aiplatform_v1beta1.types import endpoint_service -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.protobuf import empty_pb2 as empty # type: ignore -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - -from .transports.base import EndpointServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import EndpointServiceGrpcAsyncIOTransport -from .client import EndpointServiceClient - - -class EndpointServiceAsyncClient: - """""" - - _client: EndpointServiceClient - - DEFAULT_ENDPOINT = EndpointServiceClient.DEFAULT_ENDPOINT - DEFAULT_MTLS_ENDPOINT = EndpointServiceClient.DEFAULT_MTLS_ENDPOINT - - endpoint_path = staticmethod(EndpointServiceClient.endpoint_path) - parse_endpoint_path = staticmethod(EndpointServiceClient.parse_endpoint_path) - - from_service_account_file = EndpointServiceClient.from_service_account_file - from_service_account_json = from_service_account_file - - get_transport_class = functools.partial( - type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient) - ) - - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, EndpointServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the endpoint service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, ~.EndpointServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - """ - - self._client = EndpointServiceClient( - credentials=credentials, - transport=transport, - client_options=client_options, - client_info=client_info, - ) - - async def create_endpoint( - self, - request: endpoint_service.CreateEndpointRequest = None, - *, - parent: str = None, - endpoint: gca_endpoint.Endpoint = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Creates an Endpoint. - - Args: - request (:class:`~.endpoint_service.CreateEndpointRequest`): - The request object. Request message for - [EndpointService.CreateEndpoint][google.cloud.aiplatform.v1beta1.EndpointService.CreateEndpoint]. - parent (:class:`str`): - Required. The resource name of the Location to create - the Endpoint in. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - endpoint (:class:`~.gca_endpoint.Endpoint`): - Required. The Endpoint to create. - This corresponds to the ``endpoint`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.gca_endpoint.Endpoint``: Models are deployed - into it, and afterwards Endpoint is called to obtain - predictions and explanations. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent, endpoint]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = endpoint_service.CreateEndpointRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if endpoint is not None: - request.endpoint = endpoint - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - gca_endpoint.Endpoint, - metadata_type=endpoint_service.CreateEndpointOperationMetadata, - ) - - # Done; return the response. - return response - - async def get_endpoint( - self, - request: endpoint_service.GetEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> endpoint.Endpoint: - r"""Gets an Endpoint. - - Args: - request (:class:`~.endpoint_service.GetEndpointRequest`): - The request object. Request message for - [EndpointService.GetEndpoint][google.cloud.aiplatform.v1beta1.EndpointService.GetEndpoint] - name (:class:`str`): - Required. The name of the Endpoint resource. Format: - ``projects/{project}/locations/{location}/endpoints/{endpoint}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.endpoint.Endpoint: - Models are deployed into it, and - afterwards Endpoint is called to obtain - predictions and explanations. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = endpoint_service.GetEndpointRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def list_endpoints( - self, - request: endpoint_service.ListEndpointsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListEndpointsAsyncPager: - r"""Lists Endpoints in a Location. - - Args: - request (:class:`~.endpoint_service.ListEndpointsRequest`): - The request object. 
Request message for - [EndpointService.ListEndpoints][google.cloud.aiplatform.v1beta1.EndpointService.ListEndpoints]. - parent (:class:`str`): - Required. The resource name of the Location from which - to list the Endpoints. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListEndpointsAsyncPager: - Response message for - [EndpointService.ListEndpoints][google.cloud.aiplatform.v1beta1.EndpointService.ListEndpoints]. - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = endpoint_service.ListEndpointsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_endpoints, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListEndpointsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. - return response - - async def update_endpoint( - self, - request: endpoint_service.UpdateEndpointRequest = None, - *, - endpoint: gca_endpoint.Endpoint = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_endpoint.Endpoint: - r"""Updates an Endpoint. - - Args: - request (:class:`~.endpoint_service.UpdateEndpointRequest`): - The request object. Request message for - [EndpointService.UpdateEndpoint][google.cloud.aiplatform.v1beta1.EndpointService.UpdateEndpoint]. - endpoint (:class:`~.gca_endpoint.Endpoint`): - Required. The Endpoint which replaces - the resource on the server. - This corresponds to the ``endpoint`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - update_mask (:class:`~.field_mask.FieldMask`): - Required. The update mask applies to - the resource. - This corresponds to the ``update_mask`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. 
- metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.gca_endpoint.Endpoint: - Models are deployed into it, and - afterwards Endpoint is called to obtain - predictions and explanations. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, update_mask]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = endpoint_service.UpdateEndpointRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if endpoint is not None: - request.endpoint = endpoint - if update_mask is not None: - request.update_mask = update_mask - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("endpoint.name", request.endpoint.name),) - ), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def delete_endpoint( - self, - request: endpoint_service.DeleteEndpointRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Deletes an Endpoint. - - Args: - request (:class:`~.endpoint_service.DeleteEndpointRequest`): - The request object. Request message for - [EndpointService.DeleteEndpoint][google.cloud.aiplatform.v1beta1.EndpointService.DeleteEndpoint]. - name (:class:`str`): - Required. The name of the Endpoint resource to be - deleted. Format: - ``projects/{project}/locations/{location}/endpoints/{endpoint}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.empty.Empty``: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: - - :: - - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } - - The JSON representation for ``Empty`` is empty JSON - object ``{}``. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." 
- ) - - request = endpoint_service.DeleteEndpointRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_endpoint, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - empty.Empty, - metadata_type=gca_operation.DeleteOperationMetadata, - ) - - # Done; return the response. - return response - - async def deploy_model( - self, - request: endpoint_service.DeployModelRequest = None, - *, - endpoint: str = None, - deployed_model: gca_endpoint.DeployedModel = None, - traffic_split: Sequence[ - endpoint_service.DeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Deploys a Model into this Endpoint, creating a - DeployedModel within it. - - Args: - request (:class:`~.endpoint_service.DeployModelRequest`): - The request object. Request message for - [EndpointService.DeployModel][google.cloud.aiplatform.v1beta1.EndpointService.DeployModel]. - endpoint (:class:`str`): - Required. The name of the Endpoint resource into which - to deploy a Model. Format: - ``projects/{project}/locations/{location}/endpoints/{endpoint}`` - This corresponds to the ``endpoint`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - deployed_model (:class:`~.gca_endpoint.DeployedModel`): - Required. The DeployedModel to be created within the - Endpoint. Note that - [Endpoint.traffic_split][google.cloud.aiplatform.v1beta1.Endpoint.traffic_split] - must be updated for the DeployedModel to start receiving - traffic, either as part of this call, or via - [EndpointService.UpdateEndpoint][google.cloud.aiplatform.v1beta1.EndpointService.UpdateEndpoint]. - This corresponds to the ``deployed_model`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - traffic_split (:class:`Sequence[~.endpoint_service.DeployModelRequest.TrafficSplitEntry]`): - A map from a DeployedModel's ID to the percentage of - this Endpoint's traffic that should be forwarded to that - DeployedModel. - - If this field is non-empty, then the Endpoint's - [traffic_split][google.cloud.aiplatform.v1beta1.Endpoint.traffic_split] - will be overwritten with it. To refer to the ID of the - Model being deployed by this call, use "0"; the - actual ID of the new DeployedModel will be filled in its - place by this method. The traffic percentage values must - add up to 100. - - If this field is empty, then the Endpoint's - [traffic_split][google.cloud.aiplatform.v1beta1.Endpoint.traffic_split] - is not updated. - This corresponds to the ``traffic_split`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set.
- - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.endpoint_service.DeployModelResponse``: - Response message for - [EndpointService.DeployModel][google.cloud.aiplatform.v1beta1.EndpointService.DeployModel]. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, deployed_model, traffic_split]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = endpoint_service.DeployModelRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if endpoint is not None: - request.endpoint = endpoint - if deployed_model is not None: - request.deployed_model = deployed_model - if traffic_split is not None: - request.traffic_split = traffic_split - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.deploy_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - endpoint_service.DeployModelResponse, - metadata_type=endpoint_service.DeployModelOperationMetadata, - ) - - # Done; return the response. - return response - - async def undeploy_model( - self, - request: endpoint_service.UndeployModelRequest = None, - *, - endpoint: str = None, - deployed_model_id: str = None, - traffic_split: Sequence[ - endpoint_service.UndeployModelRequest.TrafficSplitEntry - ] = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Undeploys a Model from an Endpoint, removing a - DeployedModel from it, and freeing all resources it's - using. - - Args: - request (:class:`~.endpoint_service.UndeployModelRequest`): - The request object. Request message for - [EndpointService.UndeployModel][google.cloud.aiplatform.v1beta1.EndpointService.UndeployModel]. - endpoint (:class:`str`): - Required. The name of the Endpoint resource from which - to undeploy a Model. Format: - ``projects/{project}/locations/{location}/endpoints/{endpoint}`` - This corresponds to the ``endpoint`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - deployed_model_id (:class:`str`): - Required. The ID of the DeployedModel - to be undeployed from the Endpoint. - This corresponds to the ``deployed_model_id`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. 
- traffic_split (:class:`Sequence[~.endpoint_service.UndeployModelRequest.TrafficSplitEntry]`): - If this field is provided, then the Endpoint's - [traffic_split][google.cloud.aiplatform.v1beta1.Endpoint.traffic_split] - will be overwritten with it. If the last DeployedModel is - being undeployed from the Endpoint, the - [Endpoint.traffic_split] will always end up empty when - this call returns. A DeployedModel will be successfully - undeployed only if it doesn't have any traffic assigned - to it when this method executes, or if this field - unassigns all traffic from it. - This corresponds to the ``traffic_split`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.endpoint_service.UndeployModelResponse``: - Response message for - [EndpointService.UndeployModel][google.cloud.aiplatform.v1beta1.EndpointService.UndeployModel]. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, deployed_model_id, traffic_split]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = endpoint_service.UndeployModelRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if endpoint is not None: - request.endpoint = endpoint - if deployed_model_id is not None: - request.deployed_model_id = deployed_model_id - if traffic_split is not None: - request.traffic_split = traffic_split - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.undeploy_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - endpoint_service.UndeployModelResponse, - metadata_type=endpoint_service.UndeployModelOperationMetadata, - ) - - # Done; return the response.
- return response - - -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -__all__ = ("EndpointServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 17d9cb84ab..d3135dc885 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -16,24 +16,17 @@ # from collections import OrderedDict -from distutils import util -import os -import re -from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -from google.api_core import client_options as client_options_lib # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore -from google.api_core import operation_async # type: ignore +from google.api_core import operation as ga_operation from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint @@ -43,9 +36,8 @@ from google.protobuf import field_mask_pb2 as field_mask # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import EndpointServiceTransport, DEFAULT_CLIENT_INFO +from .transports.base import EndpointServiceTransport from .transports.grpc import EndpointServiceGrpcTransport -from .transports.grpc_asyncio import EndpointServiceGrpcAsyncIOTransport class EndpointServiceClientMeta(type): @@ -60,7 +52,6 @@ class EndpointServiceClientMeta(type): OrderedDict() ) # type: Dict[str, Type[EndpointServiceTransport]] _transport_registry["grpc"] = EndpointServiceGrpcTransport - _transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]: """Return an appropriate transport class. @@ -84,38 +75,8 @@ def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTranspor class EndpointServiceClient(metaclass=EndpointServiceClientMeta): """""" - @staticmethod - def _get_default_mtls_endpoint(api_endpoint): - """Convert api endpoint to mTLS endpoint. - Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to - "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. - Args: - api_endpoint (Optional[str]): the api endpoint to convert. - Returns: - str: converted mTLS api endpoint. - """ - if not api_endpoint: - return api_endpoint - - mtls_endpoint_re = re.compile( - r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
- ) - - m = mtls_endpoint_re.match(api_endpoint) - name, mtls, sandbox, googledomain = m.groups() - if mtls or not googledomain: - return api_endpoint - - if sandbox: - return api_endpoint.replace( - "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" - ) - - return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" - DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore - DEFAULT_ENDPOINT + DEFAULT_OPTIONS = ClientOptions.ClientOptions( + api_endpoint="aiplatform.googleapis.com" ) @classmethod @@ -145,22 +106,12 @@ def endpoint_path(project: str, location: str, endpoint: str,) -> str: project=project, location=location, endpoint=endpoint, ) - @staticmethod - def parse_endpoint_path(path: str) -> Dict[str, str]: - """Parse an endpoint path into its component segments.""" - m = re.match( - r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$", - path, - ) - return m.groupdict() if m else {} - def __init__( self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, EndpointServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + credentials: credentials.Credentials = None, + transport: Union[str, EndpointServiceTransport] = None, + client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS, ) -> None: """Instantiate the endpoint service client. @@ -173,102 +124,26 @@ def __init__( transport (Union[str, ~.EndpointServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. + client_options (ClientOptions): Custom options for the client. """ if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - - # Create SSL credentials for mutual TLS if needed.
- use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) - - ssl_credentials = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - is_mtls = True - else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" - ) + client_options = ClientOptions.from_dict(client_options) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, EndpointServiceTransport): - # transport is a EndpointServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) - if client_options.scopes: - raise ValueError( - "When providing a transport instance, " - "provide its scopes directly." - ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, - quota_project_id=client_options.quota_project_id, - client_info=client_info, + host=client_options.api_endpoint or "aiplatform.googleapis.com", ) def create_endpoint( @@ -319,36 +194,28 @@ def create_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent, endpoint]) - if request is not None and has_flattened_params: + if request is not None and any([parent, endpoint]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.CreateEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, endpoint_service.CreateEndpointRequest): - request = endpoint_service.CreateEndpointRequest(request) + request = endpoint_service.CreateEndpointRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if parent is not None: - request.parent = parent - if endpoint is not None: - request.endpoint = endpoint + if parent is not None: + request.parent = parent + if endpoint is not None: + request.endpoint = endpoint # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_endpoint] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + rpc = gapic_v1.method.wrap_method( + self._transport.create_endpoint, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -403,29 +270,27 @@ def get_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.GetEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, endpoint_service.GetEndpointRequest): - request = endpoint_service.GetEndpointRequest(request) + request = endpoint_service.GetEndpointRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_endpoint] + rpc = gapic_v1.method.wrap_method( + self._transport.get_endpoint, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -480,29 +345,27 @@ def list_endpoints( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.ListEndpointsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, endpoint_service.ListEndpointsRequest): - request = endpoint_service.ListEndpointsRequest(request) + request = endpoint_service.ListEndpointsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = self._transport._wrapped_methods[self._transport.list_endpoints] + rpc = gapic_v1.method.wrap_method( + self._transport.list_endpoints, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -516,7 +379,7 @@ def list_endpoints( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListEndpointsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. @@ -567,38 +430,28 @@ def update_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([endpoint, update_mask]) - if request is not None and has_flattened_params: + if request is not None and any([endpoint, update_mask]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.UpdateEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, endpoint_service.UpdateEndpointRequest): - request = endpoint_service.UpdateEndpointRequest(request) + request = endpoint_service.UpdateEndpointRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if endpoint is not None: - request.endpoint = endpoint - if update_mask is not None: - request.update_mask = update_mask + if endpoint is not None: + request.endpoint = endpoint + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.update_endpoint] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("endpoint.name", request.endpoint.name),) - ), + rpc = gapic_v1.method.wrap_method( + self._transport.update_endpoint, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -660,34 +513,26 @@ def delete_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.DeleteEndpointRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, endpoint_service.DeleteEndpointRequest): - request = endpoint_service.DeleteEndpointRequest(request) + request = endpoint_service.DeleteEndpointRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_endpoint] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.delete_endpoint, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -780,38 +625,30 @@ def deploy_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([endpoint, deployed_model, traffic_split]) - if request is not None and has_flattened_params: + if request is not None and any([endpoint, deployed_model, traffic_split]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.DeployModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, endpoint_service.DeployModelRequest): - request = endpoint_service.DeployModelRequest(request) + request = endpoint_service.DeployModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if endpoint is not None: - request.endpoint = endpoint - if deployed_model is not None: - request.deployed_model = deployed_model - if traffic_split is not None: - request.traffic_split = traffic_split + if endpoint is not None: + request.endpoint = endpoint + if deployed_model is not None: + request.deployed_model = deployed_model + if traffic_split is not None: + request.traffic_split = traffic_split # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.deploy_model] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + rpc = gapic_v1.method.wrap_method( + self._transport.deploy_model, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -895,38 +732,30 @@ def undeploy_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) - if request is not None and has_flattened_params: + if request is not None and any([endpoint, deployed_model_id, traffic_split]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a endpoint_service.UndeployModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
- if not isinstance(request, endpoint_service.UndeployModelRequest): - request = endpoint_service.UndeployModelRequest(request) + request = endpoint_service.UndeployModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if endpoint is not None: - request.endpoint = endpoint - if deployed_model_id is not None: - request.deployed_model_id = deployed_model_id - if traffic_split is not None: - request.traffic_split = traffic_split + if endpoint is not None: + request.endpoint = endpoint + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id + if traffic_split is not None: + request.traffic_split = traffic_split # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.undeploy_model] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + rpc = gapic_v1.method.wrap_method( + self._transport.undeploy_model, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -945,13 +774,13 @@ def undeploy_model( try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + _client_info = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + _client_info = gapic_v1.client_info.ClientInfo() __all__ = ("EndpointServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py index 86320c2178..4c797e56fd 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import Any, Callable, Iterable from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service @@ -41,11 +41,12 @@ class ListEndpointsPager: def __init__( self, - method: Callable[..., endpoint_service.ListEndpointsResponse], + method: Callable[ + [endpoint_service.ListEndpointsRequest], + endpoint_service.ListEndpointsResponse, + ], request: endpoint_service.ListEndpointsRequest, response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -56,13 +57,10 @@ def __init__( The initial request object. response (:class:`~.endpoint_service.ListEndpointsResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
""" self._method = method self._request = endpoint_service.ListEndpointsRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -72,7 +70,7 @@ def pages(self) -> Iterable[endpoint_service.ListEndpointsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[endpoint.Endpoint]: @@ -81,69 +79,3 @@ def __iter__(self) -> Iterable[endpoint.Endpoint]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - -class ListEndpointsAsyncPager: - """A pager for iterating through ``list_endpoints`` requests. - - This class thinly wraps an initial - :class:`~.endpoint_service.ListEndpointsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``endpoints`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListEndpoints`` requests and continue to iterate - through the ``endpoints`` field on the - corresponding responses. - - All the usual :class:`~.endpoint_service.ListEndpointsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - - def __init__( - self, - method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], - request: endpoint_service.ListEndpointsRequest, - response: endpoint_service.ListEndpointsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.endpoint_service.ListEndpointsRequest`): - The initial request object. - response (:class:`~.endpoint_service.ListEndpointsResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
- """ - self._method = method - self._request = endpoint_service.ListEndpointsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages(self) -> AsyncIterable[endpoint_service.ListEndpointsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[endpoint.Endpoint]: - async def async_generator(): - async for page in self.pages: - for response in page.endpoints: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py index 70a87e920e..62eff450a6 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py @@ -20,17 +20,14 @@ from .base import EndpointServiceTransport from .grpc import EndpointServiceGrpcTransport -from .grpc_asyncio import EndpointServiceGrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] _transport_registry["grpc"] = EndpointServiceGrpcTransport -_transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport __all__ = ( "EndpointServiceTransport", "EndpointServiceGrpcTransport", - "EndpointServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py index e55589de8f..43baa080e0 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py @@ -17,12 +17,8 @@ import abc import typing -import pkg_resources -from google import auth # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore +from google import auth from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -32,17 +28,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -class EndpointServiceTransport(abc.ABC): +class EndpointServiceTransport(metaclass=abc.ABCMeta): """Abstract transport class for EndpointService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -52,11 +38,6 @@ def __init__( *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, ) -> None: """Instantiate the transport. 
@@ -67,17 +48,6 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scope (Optional[Sequence[str]]): A list of scopes. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -86,123 +56,66 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. - if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) - - if credentials_file is not None: - credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) - - elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - - def _prep_wrapped_messages(self, client_info): - # Precompute the wrapped methods. - self._wrapped_methods = { - self.create_endpoint: gapic_v1.method.wrap_method( - self.create_endpoint, default_timeout=5.0, client_info=client_info, - ), - self.get_endpoint: gapic_v1.method.wrap_method( - self.get_endpoint, default_timeout=5.0, client_info=client_info, - ), - self.list_endpoints: gapic_v1.method.wrap_method( - self.list_endpoints, default_timeout=5.0, client_info=client_info, - ), - self.update_endpoint: gapic_v1.method.wrap_method( - self.update_endpoint, default_timeout=5.0, client_info=client_info, - ), - self.delete_endpoint: gapic_v1.method.wrap_method( - self.delete_endpoint, default_timeout=5.0, client_info=client_info, - ), - self.deploy_model: gapic_v1.method.wrap_method( - self.deploy_model, default_timeout=5.0, client_info=client_info, - ), - self.undeploy_model: gapic_v1.method.wrap_method( - self.undeploy_model, default_timeout=5.0, client_info=client_info, - ), - } - @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError() + raise NotImplementedError @property def create_endpoint( self, ) -> typing.Callable[ - [endpoint_service.CreateEndpointRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + [endpoint_service.CreateEndpointRequest], operations.Operation ]: - raise NotImplementedError() + raise NotImplementedError @property def get_endpoint( self, - ) -> typing.Callable[ - [endpoint_service.GetEndpointRequest], - typing.Union[endpoint.Endpoint, typing.Awaitable[endpoint.Endpoint]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[endpoint_service.GetEndpointRequest], endpoint.Endpoint]: + raise NotImplementedError @property def list_endpoints( self, ) -> 
typing.Callable[ - [endpoint_service.ListEndpointsRequest], - typing.Union[ - endpoint_service.ListEndpointsResponse, - typing.Awaitable[endpoint_service.ListEndpointsResponse], - ], + [endpoint_service.ListEndpointsRequest], endpoint_service.ListEndpointsResponse ]: - raise NotImplementedError() + raise NotImplementedError @property def update_endpoint( self, ) -> typing.Callable[ - [endpoint_service.UpdateEndpointRequest], - typing.Union[gca_endpoint.Endpoint, typing.Awaitable[gca_endpoint.Endpoint]], + [endpoint_service.UpdateEndpointRequest], gca_endpoint.Endpoint ]: - raise NotImplementedError() + raise NotImplementedError @property def delete_endpoint( self, ) -> typing.Callable[ - [endpoint_service.DeleteEndpointRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + [endpoint_service.DeleteEndpointRequest], operations.Operation ]: - raise NotImplementedError() + raise NotImplementedError @property def deploy_model( self, - ) -> typing.Callable[ - [endpoint_service.DeployModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[endpoint_service.DeployModelRequest], operations.Operation]: + raise NotImplementedError @property def undeploy_model( self, - ) -> typing.Callable[ - [endpoint_service.UndeployModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[endpoint_service.UndeployModelRequest], operations.Operation]: + raise NotImplementedError __all__ = ("EndpointServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index ef023a8749..9bde20f31f 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -15,15 +15,11 @@ # limitations under the License. # -import warnings -from typing import Callable, Dict, Optional, Sequence, Tuple +from typing import Callable, Dict from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -32,7 +28,7 @@ from google.cloud.aiplatform_v1beta1.types import endpoint_service from google.longrunning import operations_pb2 as operations # type: ignore -from .base import EndpointServiceTransport, DEFAULT_CLIENT_INFO +from .base import EndpointServiceTransport class EndpointServiceGrpcTransport(EndpointServiceTransport): @@ -46,21 +42,12 @@ class EndpointServiceGrpcTransport(EndpointServiceTransport): top of HTTP/2); the ``grpcio`` package must be installed. 
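Each abstract property above loses the Awaitable half of its return annotation, since only the synchronous transport can satisfy it now. A small sketch of the before/after annotations, with stand-in request and result types:

import typing

class Request: ...
class Operation: ...

# Before: one annotation covered both the sync and asyncio transports.
SyncOrAsync = typing.Callable[
    [Request], typing.Union[Operation, typing.Awaitable[Operation]]
]

# After: with grpc_asyncio deleted, only the synchronous form remains.
SyncOnly = typing.Callable[[Request], Operation]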
""" - _stubs: Dict[str, Callable] - def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + channel: grpc.Channel = None ) -> None: """Instantiate the transport. @@ -72,119 +59,28 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. """ + # Sanity check: Ensure that channel and credentials are not both + # provided. if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. credentials = False - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. 
- self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - + # Run the base constructor. + super().__init__(host=host, credentials=credentials) self._stubs = {} # type: Dict[str, Callable] - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # If a channel was explicitly provided, set it. + if channel: + self._grpc_channel = channel @classmethod def create_channel( cls, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, + **kwargs ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -194,31 +90,13 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. - - Raises: - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. """ - scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, + host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs ) @property @@ -228,6 +106,13 @@ def grpc_channel(self) -> grpc.Channel: This property caches on the instance; repeated calls return the same channel. """ + # Sanity check: Only create a new channel if we do not already + # have one. + if not hasattr(self, "_grpc_channel"): + self._grpc_channel = self.create_channel( + self._host, credentials=self._credentials, + ) + # Return the channel from cache. 
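The reworked grpc_channel property builds the channel on first access and caches it on the instance, so repeated reads return the same object. A generic sketch of that lazy-caching pattern (a string stands in for the real grpc.Channel):

class LazyChannel:
    def __init__(self, host: str) -> None:
        self._host = host

    @property
    def channel(self) -> str:
        # Only create a new (stand-in) channel if we do not already have one.
        if not hasattr(self, "_channel"):
            self._channel = f"channel-to-{self._host}"
        return self._channel

c = LazyChannel("aiplatform.googleapis.com:443")
assert c.channel is c.channel  # cached after the first access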
return self._grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py deleted file mode 100644 index 7d743ebb56..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py +++ /dev/null @@ -1,449 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple - -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.cloud.aiplatform_v1beta1.types import endpoint -from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint -from google.cloud.aiplatform_v1beta1.types import endpoint_service -from google.longrunning import operations_pb2 as operations # type: ignore - -from .base import EndpointServiceTransport, DEFAULT_CLIENT_INFO -from .grpc import EndpointServiceGrpcTransport - - -class EndpointServiceGrpcAsyncIOTransport(EndpointServiceTransport): - """gRPC AsyncIO backend transport for EndpointService. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - """ - - _grpc_channel: aio.Channel - _stubs: Dict[str, Callable] = {} - - @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: - """Create and return a gRPC AsyncIO channel object. - Args: - address (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. 
- kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - aio.Channel: A gRPC AsyncIO channel object. - """ - scopes = scopes or cls.AUTH_SCOPES - return grpc_helpers_async.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, - ) - - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. - credentials = False - - # If a channel was explicitly provided, set it. 
- self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - - @property - def grpc_channel(self) -> aio.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. - """ - # Return the channel from cache. - return self._grpc_channel - - @property - def operations_client(self) -> operations_v1.OperationsAsyncClient: - """Create the client designed to process long-running operations. - - This property caches on the instance; repeated calls return the same - client. - """ - # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( - self.grpc_channel - ) - - # Return the client from cache. - return self.__dict__["operations_client"] - - @property - def create_endpoint( - self, - ) -> Callable[ - [endpoint_service.CreateEndpointRequest], Awaitable[operations.Operation] - ]: - r"""Return a callable for the create endpoint method over gRPC. - - Creates an Endpoint. - - Returns: - Callable[[~.CreateEndpointRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
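The deleted constructor's mTLS branch above builds SSL channel credentials from a client_cert_source callback when one is given; a simplified sketch using grpc directly (the fallback here substitutes plain grpc defaults for google.auth's SslCredentials):

import grpc

def build_ssl_credentials(client_cert_source=None) -> grpc.ChannelCredentials:
    if client_cert_source:
        # The callback supplies client certificate and key bytes, PEM format.
        cert, key = client_cert_source()
        return grpc.ssl_channel_credentials(
            certificate_chain=cert, private_key=key
        )
    # Simplified fallback; the deleted code used SslCredentials().ssl_credentials.
    return grpc.ssl_channel_credentials()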
- if "create_endpoint" not in self._stubs: - self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint", - request_serializer=endpoint_service.CreateEndpointRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["create_endpoint"] - - @property - def get_endpoint( - self, - ) -> Callable[[endpoint_service.GetEndpointRequest], Awaitable[endpoint.Endpoint]]: - r"""Return a callable for the get endpoint method over gRPC. - - Gets an Endpoint. - - Returns: - Callable[[~.GetEndpointRequest], - Awaitable[~.Endpoint]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_endpoint" not in self._stubs: - self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint", - request_serializer=endpoint_service.GetEndpointRequest.serialize, - response_deserializer=endpoint.Endpoint.deserialize, - ) - return self._stubs["get_endpoint"] - - @property - def list_endpoints( - self, - ) -> Callable[ - [endpoint_service.ListEndpointsRequest], - Awaitable[endpoint_service.ListEndpointsResponse], - ]: - r"""Return a callable for the list endpoints method over gRPC. - - Lists Endpoints in a Location. - - Returns: - Callable[[~.ListEndpointsRequest], - Awaitable[~.ListEndpointsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "list_endpoints" not in self._stubs: - self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints", - request_serializer=endpoint_service.ListEndpointsRequest.serialize, - response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, - ) - return self._stubs["list_endpoints"] - - @property - def update_endpoint( - self, - ) -> Callable[ - [endpoint_service.UpdateEndpointRequest], Awaitable[gca_endpoint.Endpoint] - ]: - r"""Return a callable for the update endpoint method over gRPC. - - Updates an Endpoint. - - Returns: - Callable[[~.UpdateEndpointRequest], - Awaitable[~.Endpoint]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "update_endpoint" not in self._stubs: - self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint", - request_serializer=endpoint_service.UpdateEndpointRequest.serialize, - response_deserializer=gca_endpoint.Endpoint.deserialize, - ) - return self._stubs["update_endpoint"] - - @property - def delete_endpoint( - self, - ) -> Callable[ - [endpoint_service.DeleteEndpointRequest], Awaitable[operations.Operation] - ]: - r"""Return a callable for the delete endpoint method over gRPC. - - Deletes an Endpoint. 
- - Returns: - Callable[[~.DeleteEndpointRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "delete_endpoint" not in self._stubs: - self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint", - request_serializer=endpoint_service.DeleteEndpointRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["delete_endpoint"] - - @property - def deploy_model( - self, - ) -> Callable[ - [endpoint_service.DeployModelRequest], Awaitable[operations.Operation] - ]: - r"""Return a callable for the deploy model method over gRPC. - - Deploys a Model into this Endpoint, creating a - DeployedModel within it. - - Returns: - Callable[[~.DeployModelRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "deploy_model" not in self._stubs: - self._stubs["deploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel", - request_serializer=endpoint_service.DeployModelRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["deploy_model"] - - @property - def undeploy_model( - self, - ) -> Callable[ - [endpoint_service.UndeployModelRequest], Awaitable[operations.Operation] - ]: - r"""Return a callable for the undeploy model method over gRPC. - - Undeploys a Model from an Endpoint, removing a - DeployedModel from it, and freeing all resources it's - using. - - Returns: - Callable[[~.UndeployModelRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "undeploy_model" not in self._stubs: - self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel", - request_serializer=endpoint_service.UndeployModelRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["undeploy_model"] - - -__all__ = ("EndpointServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py index 5f157047f5..bf1248d281 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py @@ -16,9 +16,5 @@ # from .client import JobServiceClient -from .async_client import JobServiceAsyncClient -__all__ = ( - "JobServiceClient", - "JobServiceAsyncClient", -) +__all__ = ("JobServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py deleted file mode 100644 index ba408ada85..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ /dev/null @@ -1,1803 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from collections import OrderedDict -import functools -import re -from typing import Dict, Sequence, Tuple, Type, Union -import pkg_resources - -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation # type: ignore -from google.api_core import operation_async # type: ignore -from google.cloud.aiplatform_v1beta1.services.job_service import pagers -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) -from google.cloud.aiplatform_v1beta1.types import completion_stats -from google.cloud.aiplatform_v1beta1.types import custom_job -from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) -from google.cloud.aiplatform_v1beta1.types import job_service -from google.cloud.aiplatform_v1beta1.types import job_state -from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.cloud.aiplatform_v1beta1.types import study -from google.protobuf import empty_pb2 as empty # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from google.rpc import status_pb2 as status # type: ignore -from google.type import money_pb2 as money # type: ignore - -from .transports.base import JobServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import JobServiceGrpcAsyncIOTransport -from .client import JobServiceClient - - -class JobServiceAsyncClient: - """A service for creating and managing AI Platform's jobs.""" - - _client: JobServiceClient - - DEFAULT_ENDPOINT = JobServiceClient.DEFAULT_ENDPOINT - DEFAULT_MTLS_ENDPOINT = JobServiceClient.DEFAULT_MTLS_ENDPOINT - - batch_prediction_job_path = staticmethod(JobServiceClient.batch_prediction_job_path) - parse_batch_prediction_job_path = staticmethod( - JobServiceClient.parse_batch_prediction_job_path - ) - custom_job_path = staticmethod(JobServiceClient.custom_job_path) - parse_custom_job_path = staticmethod(JobServiceClient.parse_custom_job_path) - data_labeling_job_path = staticmethod(JobServiceClient.data_labeling_job_path) - parse_data_labeling_job_path = staticmethod( - JobServiceClient.parse_data_labeling_job_path - ) - hyperparameter_tuning_job_path = staticmethod( - JobServiceClient.hyperparameter_tuning_job_path - ) - parse_hyperparameter_tuning_job_path = staticmethod( - JobServiceClient.parse_hyperparameter_tuning_job_path - ) - - from_service_account_file = JobServiceClient.from_service_account_file - from_service_account_json = from_service_account_file - - get_transport_class = functools.partial( - 
type(JobServiceClient).get_transport_class, type(JobServiceClient) - ) - - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, JobServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the job service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, ~.JobServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - """ - - self._client = JobServiceClient( - credentials=credentials, - transport=transport, - client_options=client_options, - client_info=client_info, - ) - - async def create_custom_job( - self, - request: job_service.CreateCustomJobRequest = None, - *, - parent: str = None, - custom_job: gca_custom_job.CustomJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_custom_job.CustomJob: - r"""Creates a CustomJob. A created CustomJob right away - will be attempted to be run. - - Args: - request (:class:`~.job_service.CreateCustomJobRequest`): - The request object. Request message for - [JobService.CreateCustomJob][google.cloud.aiplatform.v1beta1.JobService.CreateCustomJob]. - parent (:class:`str`): - Required. The resource name of the Location to create - the CustomJob in. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - custom_job (:class:`~.gca_custom_job.CustomJob`): - Required. The CustomJob to create. - This corresponds to the ``custom_job`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
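Each deleted coroutine opens with the same guard: a fully built request object and the flattened keyword fields are mutually exclusive. A minimal sketch of that check, with a dict standing in for the protobuf request:

def build_request(request=None, *, parent=None, custom_job=None) -> dict:
    # Mirrors the sanity check above: reject mixing `request` with
    # flattened field arguments.
    if request is not None and any([parent, custom_job]):
        raise ValueError(
            "If the `request` argument is set, then none of "
            "the individual field arguments should be set."
        )
    request = dict(request or {})
    if parent is not None:
        request["parent"] = parent
    if custom_job is not None:
        request["custom_job"] = custom_job
    return request

assert build_request(parent="projects/p/locations/l") == {
    "parent": "projects/p/locations/l"
}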
- - Returns: - ~.gca_custom_job.CustomJob: - Represents a job that runs custom - workloads such as a Docker container or - a Python package. A CustomJob can have - multiple worker pools and each worker - pool can have its own machine and input - spec. A CustomJob will be cleaned up - once the job enters terminal state - (failed or succeeded). - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent, custom_job]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.CreateCustomJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if custom_job is not None: - request.custom_job = custom_job - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_custom_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def get_custom_job( - self, - request: job_service.GetCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> custom_job.CustomJob: - r"""Gets a CustomJob. - - Args: - request (:class:`~.job_service.GetCustomJobRequest`): - The request object. Request message for - [JobService.GetCustomJob][google.cloud.aiplatform.v1beta1.JobService.GetCustomJob]. - name (:class:`str`): - Required. The name of the CustomJob resource. Format: - ``projects/{project}/locations/{location}/customJobs/{custom_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.custom_job.CustomJob: - Represents a job that runs custom - workloads such as a Docker container or - a Python package. A CustomJob can have - multiple worker pools and each worker - pool can have its own machine and input - spec. A CustomJob will be cleaned up - once the job enters terminal state - (failed or succeeded). - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.GetCustomJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
- - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_custom_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def list_custom_jobs( - self, - request: job_service.ListCustomJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListCustomJobsAsyncPager: - r"""Lists CustomJobs in a Location. - - Args: - request (:class:`~.job_service.ListCustomJobsRequest`): - The request object. Request message for - [JobService.ListCustomJobs][google.cloud.aiplatform.v1beta1.JobService.ListCustomJobs]. - parent (:class:`str`): - Required. The resource name of the Location to list the - CustomJobs from. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListCustomJobsAsyncPager: - Response message for - [JobService.ListCustomJobs][google.cloud.aiplatform.v1beta1.JobService.ListCustomJobs] - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.ListCustomJobsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_custom_jobs, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListCustomJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. 
- return response - - async def delete_custom_job( - self, - request: job_service.DeleteCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Deletes a CustomJob. - - Args: - request (:class:`~.job_service.DeleteCustomJobRequest`): - The request object. Request message for - [JobService.DeleteCustomJob][google.cloud.aiplatform.v1beta1.JobService.DeleteCustomJob]. - name (:class:`str`): - Required. The name of the CustomJob resource to be - deleted. Format: - ``projects/{project}/locations/{location}/customJobs/{custom_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.empty.Empty``: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: - - :: - - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } - - The JSON representation for ``Empty`` is empty JSON - object ``{}``. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.DeleteCustomJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_custom_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - empty.Empty, - metadata_type=gca_operation.DeleteOperationMetadata, - ) - - # Done; return the response. - return response - - async def cancel_custom_job( - self, - request: job_service.CancelCustomJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: - r"""Cancels a CustomJob. Starts asynchronous cancellation on the - CustomJob. The server makes a best effort to cancel the job, but - success is not guaranteed. 
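Before sending, every deleted method folds a request field into the gRPC metadata via gapic_v1.routing_header; a sketch that approximates that helper (x-goog-request-params is the header key the real helper emits):

from urllib.parse import quote

def routing_metadata(params) -> tuple:
    # Approximates gapic_v1.routing_header.to_grpc_metadata: request
    # fields become one x-goog-request-params header value.
    value = "&".join(f"{key}={quote(str(val))}" for key, val in params)
    return ("x-goog-request-params", value)

metadata = () + (routing_metadata((("name", "projects/p/locations/l/customJobs/1"),)),)
print(metadata)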
Clients can use - [JobService.GetCustomJob][google.cloud.aiplatform.v1beta1.JobService.GetCustomJob] - or other methods to check whether the cancellation succeeded or - whether the job completed despite cancellation. On successful - cancellation, the CustomJob is not deleted; instead it becomes a - job with a - [CustomJob.error][google.cloud.aiplatform.v1beta1.CustomJob.error] - value with a [google.rpc.Status.code][google.rpc.Status.code] of - 1, corresponding to ``Code.CANCELLED``, and - [CustomJob.state][google.cloud.aiplatform.v1beta1.CustomJob.state] - is set to ``CANCELLED``. - - Args: - request (:class:`~.job_service.CancelCustomJobRequest`): - The request object. Request message for - [JobService.CancelCustomJob][google.cloud.aiplatform.v1beta1.JobService.CancelCustomJob]. - name (:class:`str`): - Required. The name of the CustomJob to cancel. Format: - ``projects/{project}/locations/{location}/customJobs/{custom_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.CancelCustomJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_custom_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - async def create_data_labeling_job( - self, - request: job_service.CreateDataLabelingJobRequest = None, - *, - parent: str = None, - data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_data_labeling_job.DataLabelingJob: - r"""Creates a DataLabelingJob. - - Args: - request (:class:`~.job_service.CreateDataLabelingJobRequest`): - The request object. Request message for - [DataLabelingJobService.CreateDataLabelingJob][]. - parent (:class:`str`): - Required. The parent of the DataLabelingJob. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - data_labeling_job (:class:`~.gca_data_labeling_job.DataLabelingJob`): - Required. The DataLabelingJob to - create. - This corresponds to the ``data_labeling_job`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. 
- - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.gca_data_labeling_job.DataLabelingJob: - DataLabelingJob is used to trigger a - human labeling job on unlabeled data - from the following Dataset: - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent, data_labeling_job]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.CreateDataLabelingJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if data_labeling_job is not None: - request.data_labeling_job = data_labeling_job - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_data_labeling_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def get_data_labeling_job( - self, - request: job_service.GetDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> data_labeling_job.DataLabelingJob: - r"""Gets a DataLabelingJob. - - Args: - request (:class:`~.job_service.GetDataLabelingJobRequest`): - The request object. Request message for - [DataLabelingJobService.GetDataLabelingJob][]. - name (:class:`str`): - Required. The name of the DataLabelingJob. Format: - - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.data_labeling_job.DataLabelingJob: - DataLabelingJob is used to trigger a - human labeling job on unlabeled data - from the following Dataset: - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.GetDataLabelingJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_data_labeling_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def list_data_labeling_jobs( - self, - request: job_service.ListDataLabelingJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListDataLabelingJobsAsyncPager: - r"""Lists DataLabelingJobs in a Location. - - Args: - request (:class:`~.job_service.ListDataLabelingJobsRequest`): - The request object. Request message for - [DataLabelingJobService.ListDataLabelingJobs][]. - parent (:class:`str`): - Required. The parent of the DataLabelingJob. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListDataLabelingJobsAsyncPager: - Response message for - [JobService.ListDataLabelingJobs][google.cloud.aiplatform.v1beta1.JobService.ListDataLabelingJobs]. - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.ListDataLabelingJobsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_data_labeling_jobs, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListDataLabelingJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. - return response - - async def delete_data_labeling_job( - self, - request: job_service.DeleteDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Deletes a DataLabelingJob. 
- - Args: - request (:class:`~.job_service.DeleteDataLabelingJobRequest`): - The request object. Request message for - [JobService.DeleteDataLabelingJob][google.cloud.aiplatform.v1beta1.JobService.DeleteDataLabelingJob]. - name (:class:`str`): - Required. The name of the DataLabelingJob to be deleted. - Format: - - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.empty.Empty``: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: - - :: - - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } - - The JSON representation for ``Empty`` is empty JSON - object ``{}``. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.DeleteDataLabelingJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_data_labeling_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - empty.Empty, - metadata_type=gca_operation.DeleteOperationMetadata, - ) - - # Done; return the response. - return response - - async def cancel_data_labeling_job( - self, - request: job_service.CancelDataLabelingJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: - r"""Cancels a DataLabelingJob. Success of cancellation is - not guaranteed. - - Args: - request (:class:`~.job_service.CancelDataLabelingJobRequest`): - The request object. Request message for - [DataLabelingJobService.CancelDataLabelingJob][]. - name (:class:`str`): - Required. The name of the DataLabelingJob. Format: - - ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. 
- - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.CancelDataLabelingJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_data_labeling_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - async def create_hyperparameter_tuning_job( - self, - request: job_service.CreateHyperparameterTuningJobRequest = None, - *, - parent: str = None, - hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: - r"""Creates a HyperparameterTuningJob - - Args: - request (:class:`~.job_service.CreateHyperparameterTuningJobRequest`): - The request object. Request message for - [JobService.CreateHyperparameterTuningJob][google.cloud.aiplatform.v1beta1.JobService.CreateHyperparameterTuningJob]. - parent (:class:`str`): - Required. The resource name of the Location to create - the HyperparameterTuningJob in. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - hyperparameter_tuning_job (:class:`~.gca_hyperparameter_tuning_job.HyperparameterTuningJob`): - Required. The HyperparameterTuningJob - to create. - This corresponds to the ``hyperparameter_tuning_job`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.gca_hyperparameter_tuning_job.HyperparameterTuningJob: - Represents a HyperparameterTuningJob. - A HyperparameterTuningJob has a Study - specification and multiple CustomJobs - with identical CustomJob specification. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent, hyperparameter_tuning_job]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." 
- ) - - request = job_service.CreateHyperparameterTuningJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if hyperparameter_tuning_job is not None: - request.hyperparameter_tuning_job = hyperparameter_tuning_job - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def get_hyperparameter_tuning_job( - self, - request: job_service.GetHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> hyperparameter_tuning_job.HyperparameterTuningJob: - r"""Gets a HyperparameterTuningJob - - Args: - request (:class:`~.job_service.GetHyperparameterTuningJobRequest`): - The request object. Request message for - [JobService.GetHyperparameterTuningJob][google.cloud.aiplatform.v1beta1.JobService.GetHyperparameterTuningJob]. - name (:class:`str`): - Required. The name of the HyperparameterTuningJob - resource. Format: - - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.hyperparameter_tuning_job.HyperparameterTuningJob: - Represents a HyperparameterTuningJob. - A HyperparameterTuningJob has a Study - specification and multiple CustomJobs - with identical CustomJob specification. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.GetHyperparameterTuningJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. 
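# Caller-side, this method accepts either a request object or the flattened
# fields, never both; a hedged sketch where `client` (a JobServiceAsyncClient)
# and `my_tuning_job` are hypothetical:
async def _example(client, my_tuning_job):
    job = await client.create_hyperparameter_tuning_job(
        parent="projects/my-project/locations/us-central1",
        hyperparameter_tuning_job=my_tuning_job,
    )
    return job.name  # server-assigned resource name of the new job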
- return response - - async def list_hyperparameter_tuning_jobs( - self, - request: job_service.ListHyperparameterTuningJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListHyperparameterTuningJobsAsyncPager: - r"""Lists HyperparameterTuningJobs in a Location. - - Args: - request (:class:`~.job_service.ListHyperparameterTuningJobsRequest`): - The request object. Request message for - [JobService.ListHyperparameterTuningJobs][google.cloud.aiplatform.v1beta1.JobService.ListHyperparameterTuningJobs]. - parent (:class:`str`): - Required. The resource name of the Location to list the - HyperparameterTuningJobs from. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListHyperparameterTuningJobsAsyncPager: - Response message for - [JobService.ListHyperparameterTuningJobs][google.cloud.aiplatform.v1beta1.JobService.ListHyperparameterTuningJobs] - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.ListHyperparameterTuningJobsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_hyperparameter_tuning_jobs, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListHyperparameterTuningJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. - return response - - async def delete_hyperparameter_tuning_job( - self, - request: job_service.DeleteHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Deletes a HyperparameterTuningJob. - - Args: - request (:class:`~.job_service.DeleteHyperparameterTuningJobRequest`): - The request object. Request message for - [JobService.DeleteHyperparameterTuningJob][google.cloud.aiplatform.v1beta1.JobService.DeleteHyperparameterTuningJob]. - name (:class:`str`): - Required. 
The name of the HyperparameterTuningJob - resource to be deleted. Format: - - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.empty.Empty``: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: - - :: - - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } - - The JSON representation for ``Empty`` is empty JSON - object ``{}``. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.DeleteHyperparameterTuningJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - empty.Empty, - metadata_type=gca_operation.DeleteOperationMetadata, - ) - - # Done; return the response. - return response - - async def cancel_hyperparameter_tuning_job( - self, - request: job_service.CancelHyperparameterTuningJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: - r"""Cancels a HyperparameterTuningJob. Starts asynchronous - cancellation on the HyperparameterTuningJob. The server makes a - best effort to cancel the job, but success is not guaranteed. - Clients can use - [JobService.GetHyperparameterTuningJob][google.cloud.aiplatform.v1beta1.JobService.GetHyperparameterTuningJob] - or other methods to check whether the cancellation succeeded or - whether the job completed despite cancellation. 
On successful - cancellation, the HyperparameterTuningJob is not deleted; - instead it becomes a job with a - [HyperparameterTuningJob.error][google.cloud.aiplatform.v1beta1.HyperparameterTuningJob.error] - value with a [google.rpc.Status.code][google.rpc.Status.code] of - 1, corresponding to ``Code.CANCELLED``, and - [HyperparameterTuningJob.state][google.cloud.aiplatform.v1beta1.HyperparameterTuningJob.state] - is set to ``CANCELLED``. - - Args: - request (:class:`~.job_service.CancelHyperparameterTuningJobRequest`): - The request object. Request message for - [JobService.CancelHyperparameterTuningJob][google.cloud.aiplatform.v1beta1.JobService.CancelHyperparameterTuningJob]. - name (:class:`str`): - Required. The name of the HyperparameterTuningJob to - cancel. Format: - - ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.CancelHyperparameterTuningJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - async def create_batch_prediction_job( - self, - request: job_service.CreateBatchPredictionJobRequest = None, - *, - parent: str = None, - batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_batch_prediction_job.BatchPredictionJob: - r"""Creates a BatchPredictionJob. A BatchPredictionJob - once created will right away be attempted to start. - - Args: - request (:class:`~.job_service.CreateBatchPredictionJobRequest`): - The request object. Request message for - [JobService.CreateBatchPredictionJob][google.cloud.aiplatform.v1beta1.JobService.CreateBatchPredictionJob]. - parent (:class:`str`): - Required. The resource name of the Location to create - the BatchPredictionJob in. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - batch_prediction_job (:class:`~.gca_batch_prediction_job.BatchPredictionJob`): - Required. The BatchPredictionJob to - create. 
- This corresponds to the ``batch_prediction_job`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.gca_batch_prediction_job.BatchPredictionJob: - A job that uses a - [Model][google.cloud.aiplatform.v1beta1.BatchPredictionJob.model] - to produce predictions on multiple [input - instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config]. - If predictions for significant portion of the instances - fail, the job may finish without attempting predictions - for all remaining instances. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent, batch_prediction_job]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.CreateBatchPredictionJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if batch_prediction_job is not None: - request.batch_prediction_job = batch_prediction_job - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_batch_prediction_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def get_batch_prediction_job( - self, - request: job_service.GetBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> batch_prediction_job.BatchPredictionJob: - r"""Gets a BatchPredictionJob - - Args: - request (:class:`~.job_service.GetBatchPredictionJobRequest`): - The request object. Request message for - [JobService.GetBatchPredictionJob][google.cloud.aiplatform.v1beta1.JobService.GetBatchPredictionJob]. - name (:class:`str`): - Required. The name of the BatchPredictionJob resource. - Format: - - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.batch_prediction_job.BatchPredictionJob: - A job that uses a - [Model][google.cloud.aiplatform.v1beta1.BatchPredictionJob.model] - to produce predictions on multiple [input - instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config]. 
- If predictions for significant portion of the instances - fail, the job may finish without attempting predictions - for all remaining instances. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.GetBatchPredictionJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_batch_prediction_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def list_batch_prediction_jobs( - self, - request: job_service.ListBatchPredictionJobsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListBatchPredictionJobsAsyncPager: - r"""Lists BatchPredictionJobs in a Location. - - Args: - request (:class:`~.job_service.ListBatchPredictionJobsRequest`): - The request object. Request message for - [JobService.ListBatchPredictionJobs][google.cloud.aiplatform.v1beta1.JobService.ListBatchPredictionJobs]. - parent (:class:`str`): - Required. The resource name of the Location to list the - BatchPredictionJobs from. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListBatchPredictionJobsAsyncPager: - Response message for - [JobService.ListBatchPredictionJobs][google.cloud.aiplatform.v1beta1.JobService.ListBatchPredictionJobs] - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.ListBatchPredictionJobsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. 
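# The routing header attached a few lines below can be produced standalone;
# a small runnable sketch (the parent value is hypothetical):
from google.api_core.gapic_v1 import routing_header

_md = routing_header.to_grpc_metadata(
    (("parent", "projects/my-project/locations/us-central1"),)
)
# Typically a single ("x-goog-request-params", "parent=projects/...") tuple,
# which the server uses to route the request.
print(_md)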
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_batch_prediction_jobs, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListBatchPredictionJobsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. - return response - - async def delete_batch_prediction_job( - self, - request: job_service.DeleteBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Deletes a BatchPredictionJob. Can only be called on - jobs that already finished. - - Args: - request (:class:`~.job_service.DeleteBatchPredictionJobRequest`): - The request object. Request message for - [JobService.DeleteBatchPredictionJob][google.cloud.aiplatform.v1beta1.JobService.DeleteBatchPredictionJob]. - name (:class:`str`): - Required. The name of the BatchPredictionJob resource to - be deleted. Format: - - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.empty.Empty``: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: - - :: - - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } - - The JSON representation for ``Empty`` is empty JSON - object ``{}``. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.DeleteBatchPredictionJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_batch_prediction_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. 
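# A hedged caller-side sketch for this method: the AsyncOperation assembled
# just below resolves to empty.Empty once the server finishes the delete
# (`client` and `job_name` are hypothetical):
async def _example(client, job_name):
    op = await client.delete_batch_prediction_job(name=job_name)
    await op.result()  # waits (asynchronously) until the LRO completes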
- response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - empty.Empty, - metadata_type=gca_operation.DeleteOperationMetadata, - ) - - # Done; return the response. - return response - - async def cancel_batch_prediction_job( - self, - request: job_service.CancelBatchPredictionJobRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: - r"""Cancels a BatchPredictionJob. - - Starts asynchronous cancellation on the BatchPredictionJob. The - server makes the best effort to cancel the job, but success is - not guaranteed. Clients can use - [JobService.GetBatchPredictionJob][google.cloud.aiplatform.v1beta1.JobService.GetBatchPredictionJob] - or other methods to check whether the cancellation succeeded or - whether the job completed despite cancellation. On a successful - cancellation, the BatchPredictionJob is not deleted;instead its - [BatchPredictionJob.state][google.cloud.aiplatform.v1beta1.BatchPredictionJob.state] - is set to ``CANCELLED``. Any files already outputted by the job - are not deleted. - - Args: - request (:class:`~.job_service.CancelBatchPredictionJobRequest`): - The request object. Request message for - [JobService.CancelBatchPredictionJob][google.cloud.aiplatform.v1beta1.JobService.CancelBatchPredictionJob]. - name (:class:`str`): - Required. The name of the BatchPredictionJob to cancel. - Format: - - ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = job_service.CancelBatchPredictionJobRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_batch_prediction_job, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. 
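# Cancellation is best-effort, so a caller typically polls the job state
# afterwards; a hedged sketch (`client` and `job_name` are hypothetical):
async def _example(client, job_name):
    await client.cancel_batch_prediction_job(name=job_name)
    job = await client.get_batch_prediction_job(name=job_name)
    return job.state  # eventually JOB_STATE_CANCELLED if the cancel took effect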
-        await rpc(
-            request, retry=retry, timeout=timeout, metadata=metadata,
-        )
-
-
-try:
-    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
-        gapic_version=pkg_resources.get_distribution(
-            "google-cloud-aiplatform",
-        ).version,
-    )
-except pkg_resources.DistributionNotFound:
-    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
-
-
-__all__ = ("JobServiceAsyncClient",)
diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py
index 76d6c8a94a..9821512326 100644
--- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py
@@ -16,24 +16,17 @@
 #
 from collections import OrderedDict
-from distutils import util
-import os
-import re
-from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union
+from typing import Dict, Sequence, Tuple, Type, Union
 import pkg_resources
-from google.api_core import client_options as client_options_lib  # type: ignore
+import google.api_core.client_options as ClientOptions  # type: ignore
 from google.api_core import exceptions  # type: ignore
 from google.api_core import gapic_v1  # type: ignore
 from google.api_core import retry as retries  # type: ignore
 from google.auth import credentials  # type: ignore
-from google.auth.transport import mtls  # type: ignore
-from google.auth.transport.grpc import SslCredentials  # type: ignore
-from google.auth.exceptions import MutualTLSChannelError  # type: ignore
 from google.oauth2 import service_account  # type: ignore
-from google.api_core import operation as ga_operation  # type: ignore
-from google.api_core import operation_async  # type: ignore
+from google.api_core import operation as ga_operation
 from google.cloud.aiplatform_v1beta1.services.job_service import pagers
 from google.cloud.aiplatform_v1beta1.types import batch_prediction_job
 from google.cloud.aiplatform_v1beta1.types import (
@@ -62,9 +55,8 @@
 from google.rpc import status_pb2 as status  # type: ignore
 from google.type import money_pb2 as money  # type: ignore
-from .transports.base import JobServiceTransport, DEFAULT_CLIENT_INFO
+from .transports.base import JobServiceTransport
 from .transports.grpc import JobServiceGrpcTransport
-from .transports.grpc_asyncio import JobServiceGrpcAsyncIOTransport


 class JobServiceClientMeta(type):
@@ -77,7 +69,6 @@ class JobServiceClientMeta(type):
     _transport_registry = OrderedDict()  # type: Dict[str, Type[JobServiceTransport]]
     _transport_registry["grpc"] = JobServiceGrpcTransport
-    _transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport

     def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]:
         """Return an appropriate transport class.
@@ -101,38 +92,8 @@ def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]:
 class JobServiceClient(metaclass=JobServiceClientMeta):
     """A service for creating and managing AI Platform's jobs."""

-    @staticmethod
-    def _get_default_mtls_endpoint(api_endpoint):
-        """Convert api endpoint to mTLS endpoint.
-        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
-        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
-        Args:
-            api_endpoint (Optional[str]): the api endpoint to convert.
-        Returns:
-            str: converted mTLS api endpoint.
-        """
-        if not api_endpoint:
-            return api_endpoint
-
-        mtls_endpoint_re = re.compile(
-            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
-        )
-
-        m = mtls_endpoint_re.match(api_endpoint)
-        name, mtls, sandbox, googledomain = m.groups()
-        if mtls or not googledomain:
-            return api_endpoint
-
-        if sandbox:
-            return api_endpoint.replace(
-                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
-            )
-
-        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
-
-    DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
-    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
-        DEFAULT_ENDPOINT
+    DEFAULT_OPTIONS = ClientOptions.ClientOptions(
+        api_endpoint="aiplatform.googleapis.com"
     )

     @classmethod
@@ -156,24 +117,15 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
     from_service_account_json = from_service_account_file

     @staticmethod
-    def batch_prediction_job_path(
-        project: str, location: str, batch_prediction_job: str,
+    def hyperparameter_tuning_job_path(
+        project: str, location: str, hyperparameter_tuning_job: str,
     ) -> str:
-        """Return a fully-qualified batch_prediction_job string."""
-        return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(
+        """Return a fully-qualified hyperparameter_tuning_job string."""
+        return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(
             project=project,
             location=location,
-            batch_prediction_job=batch_prediction_job,
-        )
-
-    @staticmethod
-    def parse_batch_prediction_job_path(path: str) -> Dict[str, str]:
-        """Parse a batch_prediction_job path into its component segments."""
-        m = re.match(
-            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/batchPredictionJobs/(?P<batch_prediction_job>.+?)$",
-            path,
+            hyperparameter_tuning_job=hyperparameter_tuning_job,
         )
-        return m.groupdict() if m else {}

     @staticmethod
     def custom_job_path(project: str, location: str, custom_job: str,) -> str:
@@ -182,15 +134,6 @@ def custom_job_path(project: str, location: str, custom_job: str,) -> str:
             project=project, location=location, custom_job=custom_job,
         )

-    @staticmethod
-    def parse_custom_job_path(path: str) -> Dict[str, str]:
-        """Parse a custom_job path into its component segments."""
-        m = re.match(
-            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/customJobs/(?P<custom_job>.+?)$",
-            path,
-        )
-        return m.groupdict() if m else {}
-
     @staticmethod
     def data_labeling_job_path(
         project: str, location: str, data_labeling_job: str,
@@ -201,41 +144,22 @@ def data_labeling_job_path(
         )

     @staticmethod
-    def parse_data_labeling_job_path(path: str) -> Dict[str, str]:
-        """Parse a data_labeling_job path into its component segments."""
-        m = re.match(
-            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/dataLabelingJobs/(?P<data_labeling_job>.+?)$",
-            path,
-        )
-        return m.groupdict() if m else {}
-
-    @staticmethod
-    def hyperparameter_tuning_job_path(
-        project: str, location: str, hyperparameter_tuning_job: str,
+    def batch_prediction_job_path(
+        project: str, location: str, batch_prediction_job: str,
     ) -> str:
-        """Return a fully-qualified hyperparameter_tuning_job string."""
-        return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format(
+        """Return a fully-qualified batch_prediction_job string."""
+        return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(
             project=project,
             location=location,
-            hyperparameter_tuning_job=hyperparameter_tuning_job,
-        )
-
-    @staticmethod
-    def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str, str]:
-        """Parse a hyperparameter_tuning_job path into its component segments."""
-        m = re.match(
-            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/hyperparameterTuningJobs/(?P<hyperparameter_tuning_job>.+?)$",
-            path,
+            batch_prediction_job=batch_prediction_job,
         )
-        return m.groupdict() if m else {}

     def __init__(
         self,
         *,
-        credentials: Optional[credentials.Credentials] = None,
-        transport: Union[str, JobServiceTransport, None] = None,
-        client_options: Optional[client_options_lib.ClientOptions] = None,
-        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        credentials: credentials.Credentials = None,
+        transport: Union[str, JobServiceTransport] = None,
+        client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS,
     ) -> None:
         """Instantiate the job service client.

@@ -248,102 +172,26 @@ def __init__(
             transport (Union[str, ~.JobServiceTransport]): The
                 transport to use. If set to None, a transport is chosen
                 automatically.
-            client_options (client_options_lib.ClientOptions): Custom options for the
-                client. It won't take effect if a ``transport`` instance is provided.
-                (1) The ``api_endpoint`` property can be used to override the
-                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
-                environment variable can also be used to override the endpoint:
-                "always" (always use the default mTLS endpoint), "never" (always
-                use the default regular endpoint) and "auto" (auto switch to the
-                default mTLS endpoint if client certificate is present, this is
-                the default value). However, the ``api_endpoint`` property takes
-                precedence if provided.
-                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
-                is "true", then the ``client_cert_source`` property can be used
-                to provide client certificate for mutual TLS transport. If
-                not provided, the default SSL client certificate will be used if
-                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
-                set, no client certificate will be used.
-            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
-                The client info used to send a user-agent string along with
-                API requests. If ``None``, then default info will be used.
-                Generally, you only need to set this if you're developing
-                your own client library.
-
-        Raises:
-            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
-                creation failed for any reason.
+            client_options (ClientOptions): Custom options for the client.
         """
         if isinstance(client_options, dict):
-            client_options = client_options_lib.from_dict(client_options)
-        if client_options is None:
-            client_options = client_options_lib.ClientOptions()
-
-        # Create SSL credentials for mutual TLS if needed.
-        use_client_cert = bool(
-            util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
-        )
-
-        ssl_credentials = None
-        is_mtls = False
-        if use_client_cert:
-            if client_options.client_cert_source:
-                import grpc  # type: ignore
-
-                cert, key = client_options.client_cert_source()
-                ssl_credentials = grpc.ssl_channel_credentials(
-                    certificate_chain=cert, private_key=key
-                )
-                is_mtls = True
-            else:
-                creds = SslCredentials()
-                is_mtls = creds.is_mtls
-                ssl_credentials = creds.ssl_credentials if is_mtls else None
-
-        # Figure out which api endpoint to use.
- if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" - ) + client_options = ClientOptions.from_dict(client_options) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, JobServiceTransport): - # transport is a JobServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) - if client_options.scopes: - raise ValueError( - "When providing a transport instance, " - "provide its scopes directly." - ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, - quota_project_id=client_options.quota_project_id, - client_info=client_info, + host=client_options.api_endpoint or "aiplatform.googleapis.com", ) def create_custom_job( @@ -397,36 +245,28 @@ def create_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent, custom_job]) - if request is not None and has_flattened_params: + if request is not None and any([parent, custom_job]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateCustomJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.CreateCustomJobRequest): - request = job_service.CreateCustomJobRequest(request) + request = job_service.CreateCustomJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if custom_job is not None: - request.custom_job = custom_job + if parent is not None: + request.parent = parent + if custom_job is not None: + request.custom_job = custom_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_custom_job] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + rpc = gapic_v1.method.wrap_method( + self._transport.create_custom_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -478,29 +318,27 @@ def get_custom_job( # Create or coerce a protobuf request object. 
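# Either calling convention works here, never both; a hedged sketch where
# `client` is a hypothetical JobServiceClient:
def _example(client):
    name = "projects/my-project/locations/us-central1/customJobs/123"
    client.get_custom_job(name=name)  # flattened field
    client.get_custom_job(request=job_service.GetCustomJobRequest(name=name))
    # Passing both `request` and `name` raises ValueError, per the sanity
    # check that follows.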
# Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetCustomJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.GetCustomJobRequest): - request = job_service.GetCustomJobRequest(request) + request = job_service.GetCustomJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_custom_job] + rpc = gapic_v1.method.wrap_method( + self._transport.get_custom_job, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -555,29 +393,27 @@ def list_custom_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListCustomJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.ListCustomJobsRequest): - request = job_service.ListCustomJobsRequest(request) + request = job_service.ListCustomJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_custom_jobs] + rpc = gapic_v1.method.wrap_method( + self._transport.list_custom_jobs, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -591,7 +427,7 @@ def list_custom_jobs( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListCustomJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. @@ -650,34 +486,26 @@ def delete_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
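# The ListCustomJobsPager assembled a few hunks above hides page tokens from
# the caller; a hedged sketch (`client` and `parent` are hypothetical):
def _example(client, parent):
    pager = client.list_custom_jobs(parent=parent)
    for job in pager:  # transparently fetches additional pages as needed
        print(job.name)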
- has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteCustomJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.DeleteCustomJobRequest): - request = job_service.DeleteCustomJobRequest(request) + request = job_service.DeleteCustomJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_custom_job] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.delete_custom_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -737,34 +565,26 @@ def cancel_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelCustomJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.CancelCustomJobRequest): - request = job_service.CancelCustomJobRequest(request) + request = job_service.CancelCustomJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_custom_job] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.cancel_custom_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -817,36 +637,28 @@ def create_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- has_flattened_params = any([parent, data_labeling_job]) - if request is not None and has_flattened_params: + if request is not None and any([parent, data_labeling_job]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.CreateDataLabelingJobRequest): - request = job_service.CreateDataLabelingJobRequest(request) + request = job_service.CreateDataLabelingJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if data_labeling_job is not None: - request.data_labeling_job = data_labeling_job + if parent is not None: + request.parent = parent + if data_labeling_job is not None: + request.data_labeling_job = data_labeling_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_data_labeling_job] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + rpc = gapic_v1.method.wrap_method( + self._transport.create_data_labeling_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -894,29 +706,27 @@ def get_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.GetDataLabelingJobRequest): - request = job_service.GetDataLabelingJobRequest(request) + request = job_service.GetDataLabelingJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_data_labeling_job] + rpc = gapic_v1.method.wrap_method( + self._transport.get_data_labeling_job, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -970,29 +780,27 @@ def list_data_labeling_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
- has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListDataLabelingJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.ListDataLabelingJobsRequest): - request = job_service.ListDataLabelingJobsRequest(request) + request = job_service.ListDataLabelingJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_data_labeling_jobs] + rpc = gapic_v1.method.wrap_method( + self._transport.list_data_labeling_jobs, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -1006,7 +814,7 @@ def list_data_labeling_jobs( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataLabelingJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. @@ -1066,34 +874,26 @@ def delete_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.DeleteDataLabelingJobRequest): - request = job_service.DeleteDataLabelingJobRequest(request) + request = job_service.DeleteDataLabelingJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_data_labeling_job] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.delete_data_labeling_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -1143,34 +943,26 @@ def cancel_data_labeling_job( # Create or coerce a protobuf request object. 
# Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelDataLabelingJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.CancelDataLabelingJobRequest): - request = job_service.CancelDataLabelingJobRequest(request) + request = job_service.CancelDataLabelingJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_data_labeling_job] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.cancel_data_labeling_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -1225,38 +1017,28 @@ def create_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent, hyperparameter_tuning_job]) - if request is not None and has_flattened_params: + if request is not None and any([parent, hyperparameter_tuning_job]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.CreateHyperparameterTuningJobRequest): - request = job_service.CreateHyperparameterTuningJobRequest(request) + request = job_service.CreateHyperparameterTuningJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if hyperparameter_tuning_job is not None: - request.hyperparameter_tuning_job = hyperparameter_tuning_job + if parent is not None: + request.parent = parent + if hyperparameter_tuning_job is not None: + request.hyperparameter_tuning_job = hyperparameter_tuning_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.create_hyperparameter_tuning_job - ] - - # Certain fields should be provided within the metadata header; - # add these here. 
- metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + rpc = gapic_v1.method.wrap_method( + self._transport.create_hyperparameter_tuning_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -1306,31 +1088,27 @@ def get_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.GetHyperparameterTuningJobRequest): - request = job_service.GetHyperparameterTuningJobRequest(request) + request = job_service.GetHyperparameterTuningJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.get_hyperparameter_tuning_job - ] + rpc = gapic_v1.method.wrap_method( + self._transport.get_hyperparameter_tuning_job, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -1385,31 +1163,27 @@ def list_hyperparameter_tuning_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListHyperparameterTuningJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.ListHyperparameterTuningJobsRequest): - request = job_service.ListHyperparameterTuningJobsRequest(request) + request = job_service.ListHyperparameterTuningJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.list_hyperparameter_tuning_jobs - ] + rpc = gapic_v1.method.wrap_method( + self._transport.list_hyperparameter_tuning_jobs, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. 
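For context on the metadata lines deleted throughout this file: they attached the request-routing header to each outgoing call. Roughly, and only as an illustration of the removed behavior (the parent value is a placeholder):

    from google.api_core import gapic_v1

    # Produces an ("x-goog-request-params", "parent=<encoded value>") pair that
    # the old code appended to the call metadata; the regenerated code omits it.
    routing = gapic_v1.routing_header.to_grpc_metadata(
        (("parent", "projects/my-project/locations/us-central1"),)
    )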
@@ -1423,7 +1197,7 @@ def list_hyperparameter_tuning_jobs( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListHyperparameterTuningJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. @@ -1483,36 +1257,26 @@ def delete_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.DeleteHyperparameterTuningJobRequest): - request = job_service.DeleteHyperparameterTuningJobRequest(request) + request = job_service.DeleteHyperparameterTuningJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.delete_hyperparameter_tuning_job - ] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.delete_hyperparameter_tuning_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -1575,36 +1339,26 @@ def cancel_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelHyperparameterTuningJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.CancelHyperparameterTuningJobRequest): - request = job_service.CancelHyperparameterTuningJobRequest(request) + request = job_service.CancelHyperparameterTuningJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = self._transport._wrapped_methods[ - self._transport.cancel_hyperparameter_tuning_job - ] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.cancel_hyperparameter_tuning_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -1663,38 +1417,28 @@ def create_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent, batch_prediction_job]) - if request is not None and has_flattened_params: + if request is not None and any([parent, batch_prediction_job]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CreateBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.CreateBatchPredictionJobRequest): - request = job_service.CreateBatchPredictionJobRequest(request) + request = job_service.CreateBatchPredictionJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if batch_prediction_job is not None: - request.batch_prediction_job = batch_prediction_job + if parent is not None: + request.parent = parent + if batch_prediction_job is not None: + request.batch_prediction_job = batch_prediction_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.create_batch_prediction_job - ] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + rpc = gapic_v1.method.wrap_method( + self._transport.create_batch_prediction_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -1747,29 +1491,27 @@ def get_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.GetBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.GetBatchPredictionJobRequest): - request = job_service.GetBatchPredictionJobRequest(request) + request = job_service.GetBatchPredictionJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
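Note that wrapping each RPC at call time, as above, also drops the 5.0 second default timeout that the old transports/base.py baked into its precomputed _wrapped_methods table (see its deletion below); callers can still pass a timeout per call. A sketch, where the transport instance and request are placeholders:

    from google.api_core import gapic_v1

    rpc = gapic_v1.method.wrap_method(
        transport.get_batch_prediction_job,  # placeholder: any bound transport method
        default_timeout=None,                # no default; per-call values still apply
        client_info=_client_info,
    )
    response = rpc(request, timeout=30.0)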
- if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_batch_prediction_job] + rpc = gapic_v1.method.wrap_method( + self._transport.get_batch_prediction_job, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -1824,31 +1566,27 @@ def list_batch_prediction_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.ListBatchPredictionJobsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.ListBatchPredictionJobsRequest): - request = job_service.ListBatchPredictionJobsRequest(request) + request = job_service.ListBatchPredictionJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.list_batch_prediction_jobs - ] + rpc = gapic_v1.method.wrap_method( + self._transport.list_batch_prediction_jobs, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -1862,7 +1600,7 @@ def list_batch_prediction_jobs( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListBatchPredictionJobsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. @@ -1923,36 +1661,26 @@ def delete_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.DeleteBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.DeleteBatchPredictionJobRequest): - request = job_service.DeleteBatchPredictionJobRequest(request) + request = job_service.DeleteBatchPredictionJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.delete_batch_prediction_job - ] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.delete_batch_prediction_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -2013,36 +1741,26 @@ def cancel_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a job_service.CancelBatchPredictionJobRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, job_service.CancelBatchPredictionJobRequest): - request = job_service.CancelBatchPredictionJobRequest(request) + request = job_service.CancelBatchPredictionJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.cancel_batch_prediction_job - ] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.cancel_batch_prediction_job, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -2052,13 +1770,13 @@ def cancel_batch_prediction_job( try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + _client_info = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + _client_info = gapic_v1.client_info.ClientInfo() __all__ = ("JobServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py index 05e5be73ca..0dd74763fa 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. 
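The module-level _client_info defined above (renamed from DEFAULT_CLIENT_INFO) feeds the user-agent header sent with every request; when the distribution is not installed, the bare ClientInfo() fallback simply omits the gapic version token. A rough sketch of that behavior:

    import pkg_resources
    from google.api_core import gapic_v1

    try:
        version = pkg_resources.get_distribution("google-cloud-aiplatform").version
    except pkg_resources.DistributionNotFound:
        version = None  # the user-agent then carries no gapic/<version> token

    info = gapic_v1.client_info.ClientInfo(gapic_version=version)
    print(info.to_user_agent())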
# -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import Any, Callable, Iterable from google.cloud.aiplatform_v1beta1.types import batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job @@ -44,11 +44,11 @@ class ListCustomJobsPager: def __init__( self, - method: Callable[..., job_service.ListCustomJobsResponse], + method: Callable[ + [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse + ], request: job_service.ListCustomJobsRequest, response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -59,13 +59,10 @@ def __init__( The initial request object. response (:class:`~.job_service.ListCustomJobsResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. """ self._method = method self._request = job_service.ListCustomJobsRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -75,7 +72,7 @@ def pages(self) -> Iterable[job_service.ListCustomJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[custom_job.CustomJob]: @@ -86,72 +83,6 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) -class ListCustomJobsAsyncPager: - """A pager for iterating through ``list_custom_jobs`` requests. - - This class thinly wraps an initial - :class:`~.job_service.ListCustomJobsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``custom_jobs`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListCustomJobs`` requests and continue to iterate - through the ``custom_jobs`` field on the - corresponding responses. - - All the usual :class:`~.job_service.ListCustomJobsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - - def __init__( - self, - method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], - request: job_service.ListCustomJobsRequest, - response: job_service.ListCustomJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.job_service.ListCustomJobsRequest`): - The initial request object. - response (:class:`~.job_service.ListCustomJobsResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
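Consuming the surviving synchronous pager looks like this (a minimal sketch; the client and parent values are placeholders):

    pager = client.list_custom_jobs(parent="projects/my-project/locations/us-central1")

    # Item iteration; further pages are fetched lazily as next_page_token is set:
    for job in pager:
        print(job.name)

    # Alternatively, iterate page by page, one ListCustomJobsResponse per RPC:
    # for page in pager.pages:
    #     print(len(page.custom_jobs))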
- """ - self._method = method - self._request = job_service.ListCustomJobsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages(self) -> AsyncIterable[job_service.ListCustomJobsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[custom_job.CustomJob]: - async def async_generator(): - async for page in self.pages: - for response in page.custom_jobs: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - class ListDataLabelingJobsPager: """A pager for iterating through ``list_data_labeling_jobs`` requests. @@ -172,11 +103,12 @@ class ListDataLabelingJobsPager: def __init__( self, - method: Callable[..., job_service.ListDataLabelingJobsResponse], + method: Callable[ + [job_service.ListDataLabelingJobsRequest], + job_service.ListDataLabelingJobsResponse, + ], request: job_service.ListDataLabelingJobsRequest, response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -187,13 +119,10 @@ def __init__( The initial request object. response (:class:`~.job_service.ListDataLabelingJobsResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. """ self._method = method self._request = job_service.ListDataLabelingJobsRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -203,7 +132,7 @@ def pages(self) -> Iterable[job_service.ListDataLabelingJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[data_labeling_job.DataLabelingJob]: @@ -214,72 +143,6 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) -class ListDataLabelingJobsAsyncPager: - """A pager for iterating through ``list_data_labeling_jobs`` requests. - - This class thinly wraps an initial - :class:`~.job_service.ListDataLabelingJobsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``data_labeling_jobs`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListDataLabelingJobs`` requests and continue to iterate - through the ``data_labeling_jobs`` field on the - corresponding responses. - - All the usual :class:`~.job_service.ListDataLabelingJobsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - - def __init__( - self, - method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], - request: job_service.ListDataLabelingJobsRequest, - response: job_service.ListDataLabelingJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. 
- - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.job_service.ListDataLabelingJobsRequest`): - The initial request object. - response (:class:`~.job_service.ListDataLabelingJobsResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - self._method = method - self._request = job_service.ListDataLabelingJobsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages(self) -> AsyncIterable[job_service.ListDataLabelingJobsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[data_labeling_job.DataLabelingJob]: - async def async_generator(): - async for page in self.pages: - for response in page.data_labeling_jobs: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - class ListHyperparameterTuningJobsPager: """A pager for iterating through ``list_hyperparameter_tuning_jobs`` requests. @@ -300,11 +163,12 @@ class ListHyperparameterTuningJobsPager: def __init__( self, - method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], + method: Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + job_service.ListHyperparameterTuningJobsResponse, + ], request: job_service.ListHyperparameterTuningJobsRequest, response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -315,13 +179,10 @@ def __init__( The initial request object. response (:class:`~.job_service.ListHyperparameterTuningJobsResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. """ self._method = method self._request = job_service.ListHyperparameterTuningJobsRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -331,7 +192,7 @@ def pages(self) -> Iterable[job_service.ListHyperparameterTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[hyperparameter_tuning_job.HyperparameterTuningJob]: @@ -342,78 +203,6 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) -class ListHyperparameterTuningJobsAsyncPager: - """A pager for iterating through ``list_hyperparameter_tuning_jobs`` requests. - - This class thinly wraps an initial - :class:`~.job_service.ListHyperparameterTuningJobsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``hyperparameter_tuning_jobs`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListHyperparameterTuningJobs`` requests and continue to iterate - through the ``hyperparameter_tuning_jobs`` field on the - corresponding responses. 
- - All the usual :class:`~.job_service.ListHyperparameterTuningJobsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[job_service.ListHyperparameterTuningJobsResponse] - ], - request: job_service.ListHyperparameterTuningJobsRequest, - response: job_service.ListHyperparameterTuningJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.job_service.ListHyperparameterTuningJobsRequest`): - The initial request object. - response (:class:`~.job_service.ListHyperparameterTuningJobsResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - self._method = method - self._request = job_service.ListHyperparameterTuningJobsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages( - self, - ) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__( - self, - ) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: - async def async_generator(): - async for page in self.pages: - for response in page.hyperparameter_tuning_jobs: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - class ListBatchPredictionJobsPager: """A pager for iterating through ``list_batch_prediction_jobs`` requests. @@ -434,11 +223,12 @@ class ListBatchPredictionJobsPager: def __init__( self, - method: Callable[..., job_service.ListBatchPredictionJobsResponse], + method: Callable[ + [job_service.ListBatchPredictionJobsRequest], + job_service.ListBatchPredictionJobsResponse, + ], request: job_service.ListBatchPredictionJobsRequest, response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -449,13 +239,10 @@ def __init__( The initial request object. response (:class:`~.job_service.ListBatchPredictionJobsResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListBatchPredictionJobsRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -465,7 +252,7 @@ def pages(self) -> Iterable[job_service.ListBatchPredictionJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: @@ -474,69 +261,3 @@ def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - -class ListBatchPredictionJobsAsyncPager: - """A pager for iterating through ``list_batch_prediction_jobs`` requests. - - This class thinly wraps an initial - :class:`~.job_service.ListBatchPredictionJobsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``batch_prediction_jobs`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListBatchPredictionJobs`` requests and continue to iterate - through the ``batch_prediction_jobs`` field on the - corresponding responses. - - All the usual :class:`~.job_service.ListBatchPredictionJobsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - - def __init__( - self, - method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], - request: job_service.ListBatchPredictionJobsRequest, - response: job_service.ListBatchPredictionJobsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.job_service.ListBatchPredictionJobsRequest`): - The initial request object. - response (:class:`~.job_service.ListBatchPredictionJobsResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
- """ - self._method = method - self._request = job_service.ListBatchPredictionJobsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages(self) -> AsyncIterable[job_service.ListBatchPredictionJobsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[batch_prediction_job.BatchPredictionJob]: - async def async_generator(): - async for page in self.pages: - for response in page.batch_prediction_jobs: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py index ca4d929cb5..2f081266a0 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py @@ -20,17 +20,14 @@ from .base import JobServiceTransport from .grpc import JobServiceGrpcTransport -from .grpc_asyncio import JobServiceGrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] _transport_registry["grpc"] = JobServiceGrpcTransport -_transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport __all__ = ( "JobServiceTransport", "JobServiceGrpcTransport", - "JobServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py index 3d1f0be59b..6e11bb87ea 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py @@ -17,12 +17,8 @@ import abc import typing -import pkg_resources -from google import auth # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore +from google import auth from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -45,17 +41,7 @@ from google.protobuf import empty_pb2 as empty # type: ignore -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -class JobServiceTransport(abc.ABC): +class JobServiceTransport(metaclass=abc.ABCMeta): """Abstract transport class for JobService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -65,11 +51,6 @@ def __init__( *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, ) -> None: """Instantiate the transport. 
@@ -80,17 +61,6 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A list of scopes. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -99,338 +69,174 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. - if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) - - if credentials_file is not None: - credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) - - elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - - def _prep_wrapped_messages(self, client_info): - # Precompute the wrapped methods. - self._wrapped_methods = { - self.create_custom_job: gapic_v1.method.wrap_method( - self.create_custom_job, default_timeout=5.0, client_info=client_info, - ), - self.get_custom_job: gapic_v1.method.wrap_method( - self.get_custom_job, default_timeout=5.0, client_info=client_info, - ), - self.list_custom_jobs: gapic_v1.method.wrap_method( - self.list_custom_jobs, default_timeout=5.0, client_info=client_info, - ), - self.delete_custom_job: gapic_v1.method.wrap_method( - self.delete_custom_job, default_timeout=5.0, client_info=client_info, - ), - self.cancel_custom_job: gapic_v1.method.wrap_method( - self.cancel_custom_job, default_timeout=5.0, client_info=client_info, - ), - self.create_data_labeling_job: gapic_v1.method.wrap_method( - self.create_data_labeling_job, - default_timeout=5.0, - client_info=client_info, - ), - self.get_data_labeling_job: gapic_v1.method.wrap_method( - self.get_data_labeling_job, - default_timeout=5.0, - client_info=client_info, - ), - self.list_data_labeling_jobs: gapic_v1.method.wrap_method( - self.list_data_labeling_jobs, - default_timeout=5.0, - client_info=client_info, - ), - self.delete_data_labeling_job: gapic_v1.method.wrap_method( - self.delete_data_labeling_job, - default_timeout=5.0, - client_info=client_info, - ), - self.cancel_data_labeling_job: gapic_v1.method.wrap_method( - self.cancel_data_labeling_job, - default_timeout=5.0, - client_info=client_info, - ), - self.create_hyperparameter_tuning_job: gapic_v1.method.wrap_method( - self.create_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=client_info, - ), - self.get_hyperparameter_tuning_job: gapic_v1.method.wrap_method( - self.get_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=client_info, - ), -
self.list_hyperparameter_tuning_jobs: gapic_v1.method.wrap_method( - self.list_hyperparameter_tuning_jobs, - default_timeout=5.0, - client_info=client_info, - ), - self.delete_hyperparameter_tuning_job: gapic_v1.method.wrap_method( - self.delete_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=client_info, - ), - self.cancel_hyperparameter_tuning_job: gapic_v1.method.wrap_method( - self.cancel_hyperparameter_tuning_job, - default_timeout=5.0, - client_info=client_info, - ), - self.create_batch_prediction_job: gapic_v1.method.wrap_method( - self.create_batch_prediction_job, - default_timeout=5.0, - client_info=client_info, - ), - self.get_batch_prediction_job: gapic_v1.method.wrap_method( - self.get_batch_prediction_job, - default_timeout=5.0, - client_info=client_info, - ), - self.list_batch_prediction_jobs: gapic_v1.method.wrap_method( - self.list_batch_prediction_jobs, - default_timeout=5.0, - client_info=client_info, - ), - self.delete_batch_prediction_job: gapic_v1.method.wrap_method( - self.delete_batch_prediction_job, - default_timeout=5.0, - client_info=client_info, - ), - self.cancel_batch_prediction_job: gapic_v1.method.wrap_method( - self.cancel_batch_prediction_job, - default_timeout=5.0, - client_info=client_info, - ), - } - @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError() + raise NotImplementedError @property def create_custom_job( self, ) -> typing.Callable[ - [job_service.CreateCustomJobRequest], - typing.Union[ - gca_custom_job.CustomJob, typing.Awaitable[gca_custom_job.CustomJob] - ], + [job_service.CreateCustomJobRequest], gca_custom_job.CustomJob ]: - raise NotImplementedError() + raise NotImplementedError @property def get_custom_job( self, - ) -> typing.Callable[ - [job_service.GetCustomJobRequest], - typing.Union[custom_job.CustomJob, typing.Awaitable[custom_job.CustomJob]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[job_service.GetCustomJobRequest], custom_job.CustomJob]: + raise NotImplementedError @property def list_custom_jobs( self, ) -> typing.Callable[ - [job_service.ListCustomJobsRequest], - typing.Union[ - job_service.ListCustomJobsResponse, - typing.Awaitable[job_service.ListCustomJobsResponse], - ], + [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse ]: - raise NotImplementedError() + raise NotImplementedError @property def delete_custom_job( self, - ) -> typing.Callable[ - [job_service.DeleteCustomJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[job_service.DeleteCustomJobRequest], operations.Operation]: + raise NotImplementedError @property def cancel_custom_job( self, - ) -> typing.Callable[ - [job_service.CancelCustomJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[job_service.CancelCustomJobRequest], empty.Empty]: + raise NotImplementedError @property def create_data_labeling_job( self, ) -> typing.Callable[ [job_service.CreateDataLabelingJobRequest], - typing.Union[ - gca_data_labeling_job.DataLabelingJob, - typing.Awaitable[gca_data_labeling_job.DataLabelingJob], - ], + gca_data_labeling_job.DataLabelingJob, ]: - raise NotImplementedError() + raise NotImplementedError @property def get_data_labeling_job( self, ) -> typing.Callable[ - [job_service.GetDataLabelingJobRequest], - typing.Union[ - 
data_labeling_job.DataLabelingJob, - typing.Awaitable[data_labeling_job.DataLabelingJob], - ], + [job_service.GetDataLabelingJobRequest], data_labeling_job.DataLabelingJob ]: - raise NotImplementedError() + raise NotImplementedError @property def list_data_labeling_jobs( self, ) -> typing.Callable[ [job_service.ListDataLabelingJobsRequest], - typing.Union[ - job_service.ListDataLabelingJobsResponse, - typing.Awaitable[job_service.ListDataLabelingJobsResponse], - ], + job_service.ListDataLabelingJobsResponse, ]: - raise NotImplementedError() + raise NotImplementedError @property def delete_data_labeling_job( self, ) -> typing.Callable[ - [job_service.DeleteDataLabelingJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + [job_service.DeleteDataLabelingJobRequest], operations.Operation ]: - raise NotImplementedError() + raise NotImplementedError @property def cancel_data_labeling_job( self, - ) -> typing.Callable[ - [job_service.CancelDataLabelingJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[job_service.CancelDataLabelingJobRequest], empty.Empty]: + raise NotImplementedError @property def create_hyperparameter_tuning_job( self, ) -> typing.Callable[ [job_service.CreateHyperparameterTuningJobRequest], - typing.Union[ - gca_hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], - ], + gca_hyperparameter_tuning_job.HyperparameterTuningJob, ]: - raise NotImplementedError() + raise NotImplementedError @property def get_hyperparameter_tuning_job( self, ) -> typing.Callable[ [job_service.GetHyperparameterTuningJobRequest], - typing.Union[ - hyperparameter_tuning_job.HyperparameterTuningJob, - typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], - ], + hyperparameter_tuning_job.HyperparameterTuningJob, ]: - raise NotImplementedError() + raise NotImplementedError @property def list_hyperparameter_tuning_jobs( self, ) -> typing.Callable[ [job_service.ListHyperparameterTuningJobsRequest], - typing.Union[ - job_service.ListHyperparameterTuningJobsResponse, - typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse], - ], + job_service.ListHyperparameterTuningJobsResponse, ]: - raise NotImplementedError() + raise NotImplementedError @property def delete_hyperparameter_tuning_job( self, ) -> typing.Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + [job_service.DeleteHyperparameterTuningJobRequest], operations.Operation ]: - raise NotImplementedError() + raise NotImplementedError @property def cancel_hyperparameter_tuning_job( self, ) -> typing.Callable[ - [job_service.CancelHyperparameterTuningJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + [job_service.CancelHyperparameterTuningJobRequest], empty.Empty ]: - raise NotImplementedError() + raise NotImplementedError @property def create_batch_prediction_job( self, ) -> typing.Callable[ [job_service.CreateBatchPredictionJobRequest], - typing.Union[ - gca_batch_prediction_job.BatchPredictionJob, - typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob], - ], + gca_batch_prediction_job.BatchPredictionJob, ]: - raise NotImplementedError() + raise NotImplementedError @property def get_batch_prediction_job( self, ) -> typing.Callable[ [job_service.GetBatchPredictionJobRequest], - typing.Union[ - 
batch_prediction_job.BatchPredictionJob, - typing.Awaitable[batch_prediction_job.BatchPredictionJob], - ], + batch_prediction_job.BatchPredictionJob, ]: - raise NotImplementedError() + raise NotImplementedError @property def list_batch_prediction_jobs( self, ) -> typing.Callable[ [job_service.ListBatchPredictionJobsRequest], - typing.Union[ - job_service.ListBatchPredictionJobsResponse, - typing.Awaitable[job_service.ListBatchPredictionJobsResponse], - ], + job_service.ListBatchPredictionJobsResponse, ]: - raise NotImplementedError() + raise NotImplementedError @property def delete_batch_prediction_job( self, ) -> typing.Callable[ - [job_service.DeleteBatchPredictionJobRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + [job_service.DeleteBatchPredictionJobRequest], operations.Operation ]: - raise NotImplementedError() + raise NotImplementedError @property def cancel_batch_prediction_job( self, - ) -> typing.Callable[ - [job_service.CancelBatchPredictionJobRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[job_service.CancelBatchPredictionJobRequest], empty.Empty]: + raise NotImplementedError __all__ = ("JobServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index 99c2179ef7..cdb049c585 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -15,15 +15,11 @@ # limitations under the License. # -import warnings -from typing import Callable, Dict, Optional, Sequence, Tuple +from typing import Callable, Dict from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -45,7 +41,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import JobServiceTransport, DEFAULT_CLIENT_INFO +from .base import JobServiceTransport class JobServiceGrpcTransport(JobServiceTransport): @@ -61,21 +57,12 @@ class JobServiceGrpcTransport(JobServiceTransport): top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + channel: grpc.Channel = None ) -> None: """Instantiate the transport. @@ -87,119 +74,28 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. 
- scopes (Optional[Sequence[str]]): A list of scopes. This argument is - ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. """ + # Sanity check: Ensure that channel and credentials are not both + # provided. if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. credentials = False - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - + # Run the base constructor. + super().__init__(host=host, credentials=credentials) self._stubs = {} # type: Dict[str, Callable] - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # If a channel was explicitly provided, set it.
+ if channel: + self._grpc_channel = channel @classmethod def create_channel( cls, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, + **kwargs ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -209,31 +105,13 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. - - Raises: - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. """ - scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, + host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs ) @property @@ -243,6 +121,13 @@ def grpc_channel(self) -> grpc.Channel: This property caches on the instance; repeated calls return the same channel. """ + # Sanity check: Only create a new channel if we do not already + # have one. + if not hasattr(self, "_grpc_channel"): + self._grpc_channel = self.create_channel( + self._host, credentials=self._credentials, + ) + # Return the channel from cache. return self._grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py deleted file mode 100644 index eea7a67ae1..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py +++ /dev/null @@ -1,885 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
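The kept code above replaces eager channel construction with a lazy, memoized ``grpc_channel`` property: an explicitly supplied channel wins, and otherwise a channel is built on first access and cached for every later call. A minimal sketch of that pattern, assuming only ``grpcio`` and ``google-api-core`` are installed (``DemoTransport`` is an illustrative name, not library API):

    import grpc  # type: ignore
    from google.api_core import grpc_helpers  # type: ignore


    class DemoTransport:
        """Illustrative transport that creates its gRPC channel lazily."""

        def __init__(self, host, credentials=None, channel=None):
            self._host = host if ":" in host else host + ":443"
            self._credentials = credentials
            if channel:
                # An explicit channel takes precedence; credentials go unused.
                self._grpc_channel = channel

        @property
        def grpc_channel(self):
            # Only create a new channel if we do not already have one.
            if not hasattr(self, "_grpc_channel"):
                self._grpc_channel = grpc_helpers.create_channel(
                    self._host, credentials=self._credentials,
                )
            return self._grpc_channel

Because creation is deferred, constructing the transport stays cheap, and tests can inject a fake channel before the property is ever touched.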
-# - -import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple - -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) -from google.cloud.aiplatform_v1beta1.types import custom_job -from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) -from google.cloud.aiplatform_v1beta1.types import job_service -from google.longrunning import operations_pb2 as operations # type: ignore -from google.protobuf import empty_pb2 as empty # type: ignore - -from .base import JobServiceTransport, DEFAULT_CLIENT_INFO -from .grpc import JobServiceGrpcTransport - - -class JobServiceGrpcAsyncIOTransport(JobServiceTransport): - """gRPC AsyncIO backend transport for JobService. - - A service for creating and managing AI Platform's jobs. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - """ - - _grpc_channel: aio.Channel - _stubs: Dict[str, Callable] = {} - - @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: - """Create and return a gRPC AsyncIO channel object. - Args: - host (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): An optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - aio.Channel: A gRPC AsyncIO channel object.
- """ - scopes = scopes or cls.AUTH_SCOPES - return grpc_helpers_async.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, - ) - - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. - credentials = False - - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. 
- if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - - @property - def grpc_channel(self) -> aio.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. - """ - # Return the channel from cache. - return self._grpc_channel - - @property - def operations_client(self) -> operations_v1.OperationsAsyncClient: - """Create the client designed to process long-running operations. - - This property caches on the instance; repeated calls return the same - client. - """ - # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( - self.grpc_channel - ) - - # Return the client from cache. - return self.__dict__["operations_client"] - - @property - def create_custom_job( - self, - ) -> Callable[ - [job_service.CreateCustomJobRequest], Awaitable[gca_custom_job.CustomJob] - ]: - r"""Return a callable for the create custom job method over gRPC. - - Creates a CustomJob. A created CustomJob right away - will be attempted to be run. - - Returns: - Callable[[~.CreateCustomJobRequest], - Awaitable[~.CustomJob]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "create_custom_job" not in self._stubs: - self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob", - request_serializer=job_service.CreateCustomJobRequest.serialize, - response_deserializer=gca_custom_job.CustomJob.deserialize, - ) - return self._stubs["create_custom_job"] - - @property - def get_custom_job( - self, - ) -> Callable[[job_service.GetCustomJobRequest], Awaitable[custom_job.CustomJob]]: - r"""Return a callable for the get custom job method over gRPC. - - Gets a CustomJob. - - Returns: - Callable[[~.GetCustomJobRequest], - Awaitable[~.CustomJob]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. 
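The deprecated ``api_mtls_endpoint``/``client_cert_source`` branch deleted here turned a user callback returning a PEM certificate/key pair into channel-level SSL credentials. A short sketch of that mechanism, assuming ``load_pem_pair`` is a hypothetical callback and the two ``.pem`` paths are placeholders:

    import grpc  # type: ignore


    def load_pem_pair():
        # Hypothetical client_cert_source callback reading local PEM files.
        with open("client_cert.pem", "rb") as f:
            cert = f.read()
        with open("client_key.pem", "rb") as f:
            key = f.read()
        return cert, key


    cert, key = load_pem_pair()
    # Mirrors the deleted branch: the pair becomes channel SSL credentials.
    ssl_credentials = grpc.ssl_channel_credentials(
        certificate_chain=cert, private_key=key
    )

When no callback was supplied, the branch fell back to ``SslCredentials().ssl_credentials``, i.e. application default SSL credentials.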
- # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_custom_job" not in self._stubs: - self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob", - request_serializer=job_service.GetCustomJobRequest.serialize, - response_deserializer=custom_job.CustomJob.deserialize, - ) - return self._stubs["get_custom_job"] - - @property - def list_custom_jobs( - self, - ) -> Callable[ - [job_service.ListCustomJobsRequest], - Awaitable[job_service.ListCustomJobsResponse], - ]: - r"""Return a callable for the list custom jobs method over gRPC. - - Lists CustomJobs in a Location. - - Returns: - Callable[[~.ListCustomJobsRequest], - Awaitable[~.ListCustomJobsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "list_custom_jobs" not in self._stubs: - self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs", - request_serializer=job_service.ListCustomJobsRequest.serialize, - response_deserializer=job_service.ListCustomJobsResponse.deserialize, - ) - return self._stubs["list_custom_jobs"] - - @property - def delete_custom_job( - self, - ) -> Callable[ - [job_service.DeleteCustomJobRequest], Awaitable[operations.Operation] - ]: - r"""Return a callable for the delete custom job method over gRPC. - - Deletes a CustomJob. - - Returns: - Callable[[~.DeleteCustomJobRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "delete_custom_job" not in self._stubs: - self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob", - request_serializer=job_service.DeleteCustomJobRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["delete_custom_job"] - - @property - def cancel_custom_job( - self, - ) -> Callable[[job_service.CancelCustomJobRequest], Awaitable[empty.Empty]]: - r"""Return a callable for the cancel custom job method over gRPC. - - Cancels a CustomJob. Starts asynchronous cancellation on the - CustomJob. The server makes a best effort to cancel the job, but - success is not guaranteed. Clients can use - [JobService.GetCustomJob][google.cloud.aiplatform.v1beta1.JobService.GetCustomJob] - or other methods to check whether the cancellation succeeded or - whether the job completed despite cancellation. On successful - cancellation, the CustomJob is not deleted; instead it becomes a - job with a - [CustomJob.error][google.cloud.aiplatform.v1beta1.CustomJob.error] - value with a [google.rpc.Status.code][google.rpc.Status.code] of - 1, corresponding to ``Code.CANCELLED``, and - [CustomJob.state][google.cloud.aiplatform.v1beta1.CustomJob.state] - is set to ``CANCELLED``. - - Returns: - Callable[[~.CancelCustomJobRequest], - Awaitable[~.Empty]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. 
- # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "cancel_custom_job" not in self._stubs: - self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob", - request_serializer=job_service.CancelCustomJobRequest.serialize, - response_deserializer=empty.Empty.FromString, - ) - return self._stubs["cancel_custom_job"] - - @property - def create_data_labeling_job( - self, - ) -> Callable[ - [job_service.CreateDataLabelingJobRequest], - Awaitable[gca_data_labeling_job.DataLabelingJob], - ]: - r"""Return a callable for the create data labeling job method over gRPC. - - Creates a DataLabelingJob. - - Returns: - Callable[[~.CreateDataLabelingJobRequest], - Awaitable[~.DataLabelingJob]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "create_data_labeling_job" not in self._stubs: - self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob", - request_serializer=job_service.CreateDataLabelingJobRequest.serialize, - response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, - ) - return self._stubs["create_data_labeling_job"] - - @property - def get_data_labeling_job( - self, - ) -> Callable[ - [job_service.GetDataLabelingJobRequest], - Awaitable[data_labeling_job.DataLabelingJob], - ]: - r"""Return a callable for the get data labeling job method over gRPC. - - Gets a DataLabelingJob. - - Returns: - Callable[[~.GetDataLabelingJobRequest], - Awaitable[~.DataLabelingJob]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_data_labeling_job" not in self._stubs: - self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob", - request_serializer=job_service.GetDataLabelingJobRequest.serialize, - response_deserializer=data_labeling_job.DataLabelingJob.deserialize, - ) - return self._stubs["get_data_labeling_job"] - - @property - def list_data_labeling_jobs( - self, - ) -> Callable[ - [job_service.ListDataLabelingJobsRequest], - Awaitable[job_service.ListDataLabelingJobsResponse], - ]: - r"""Return a callable for the list data labeling jobs method over gRPC. - - Lists DataLabelingJobs in a Location. - - Returns: - Callable[[~.ListDataLabelingJobsRequest], - Awaitable[~.ListDataLabelingJobsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "list_data_labeling_jobs" not in self._stubs: - self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs", - request_serializer=job_service.ListDataLabelingJobsRequest.serialize, - response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, - ) - return self._stubs["list_data_labeling_jobs"] - - @property - def delete_data_labeling_job( - self, - ) -> Callable[ - [job_service.DeleteDataLabelingJobRequest], Awaitable[operations.Operation] - ]: - r"""Return a callable for the delete data labeling job method over gRPC. - - Deletes a DataLabelingJob. - - Returns: - Callable[[~.DeleteDataLabelingJobRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "delete_data_labeling_job" not in self._stubs: - self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob", - request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["delete_data_labeling_job"] - - @property - def cancel_data_labeling_job( - self, - ) -> Callable[[job_service.CancelDataLabelingJobRequest], Awaitable[empty.Empty]]: - r"""Return a callable for the cancel data labeling job method over gRPC. - - Cancels a DataLabelingJob. Success of cancellation is - not guaranteed. - - Returns: - Callable[[~.CancelDataLabelingJobRequest], - Awaitable[~.Empty]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "cancel_data_labeling_job" not in self._stubs: - self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob", - request_serializer=job_service.CancelDataLabelingJobRequest.serialize, - response_deserializer=empty.Empty.FromString, - ) - return self._stubs["cancel_data_labeling_job"] - - @property - def create_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.CreateHyperparameterTuningJobRequest], - Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], - ]: - r"""Return a callable for the create hyperparameter tuning - job method over gRPC. - - Creates a HyperparameterTuningJob - - Returns: - Callable[[~.CreateHyperparameterTuningJobRequest], - Awaitable[~.HyperparameterTuningJob]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "create_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "create_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob", - request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, - response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, - ) - return self._stubs["create_hyperparameter_tuning_job"] - - @property - def get_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.GetHyperparameterTuningJobRequest], - Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], - ]: - r"""Return a callable for the get hyperparameter tuning job method over gRPC. - - Gets a HyperparameterTuningJob - - Returns: - Callable[[~.GetHyperparameterTuningJobRequest], - Awaitable[~.HyperparameterTuningJob]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "get_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob", - request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, - response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, - ) - return self._stubs["get_hyperparameter_tuning_job"] - - @property - def list_hyperparameter_tuning_jobs( - self, - ) -> Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - Awaitable[job_service.ListHyperparameterTuningJobsResponse], - ]: - r"""Return a callable for the list hyperparameter tuning - jobs method over gRPC. - - Lists HyperparameterTuningJobs in a Location. - - Returns: - Callable[[~.ListHyperparameterTuningJobsRequest], - Awaitable[~.ListHyperparameterTuningJobsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "list_hyperparameter_tuning_jobs" not in self._stubs: - self._stubs[ - "list_hyperparameter_tuning_jobs" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs", - request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, - response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, - ) - return self._stubs["list_hyperparameter_tuning_jobs"] - - @property - def delete_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], - Awaitable[operations.Operation], - ]: - r"""Return a callable for the delete hyperparameter tuning - job method over gRPC. - - Deletes a HyperparameterTuningJob. - - Returns: - Callable[[~.DeleteHyperparameterTuningJobRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "delete_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "delete_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob", - request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["delete_hyperparameter_tuning_job"] - - @property - def cancel_hyperparameter_tuning_job( - self, - ) -> Callable[ - [job_service.CancelHyperparameterTuningJobRequest], Awaitable[empty.Empty] - ]: - r"""Return a callable for the cancel hyperparameter tuning - job method over gRPC. - - Cancels a HyperparameterTuningJob. Starts asynchronous - cancellation on the HyperparameterTuningJob. The server makes a - best effort to cancel the job, but success is not guaranteed. - Clients can use - [JobService.GetHyperparameterTuningJob][google.cloud.aiplatform.v1beta1.JobService.GetHyperparameterTuningJob] - or other methods to check whether the cancellation succeeded or - whether the job completed despite cancellation. On successful - cancellation, the HyperparameterTuningJob is not deleted; - instead it becomes a job with a - [HyperparameterTuningJob.error][google.cloud.aiplatform.v1beta1.HyperparameterTuningJob.error] - value with a [google.rpc.Status.code][google.rpc.Status.code] of - 1, corresponding to ``Code.CANCELLED``, and - [HyperparameterTuningJob.state][google.cloud.aiplatform.v1beta1.HyperparameterTuningJob.state] - is set to ``CANCELLED``. - - Returns: - Callable[[~.CancelHyperparameterTuningJobRequest], - Awaitable[~.Empty]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "cancel_hyperparameter_tuning_job" not in self._stubs: - self._stubs[ - "cancel_hyperparameter_tuning_job" - ] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob", - request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, - response_deserializer=empty.Empty.FromString, - ) - return self._stubs["cancel_hyperparameter_tuning_job"] - - @property - def create_batch_prediction_job( - self, - ) -> Callable[ - [job_service.CreateBatchPredictionJobRequest], - Awaitable[gca_batch_prediction_job.BatchPredictionJob], - ]: - r"""Return a callable for the create batch prediction job method over gRPC. - - Creates a BatchPredictionJob. A BatchPredictionJob - once created will right away be attempted to start. - - Returns: - Callable[[~.CreateBatchPredictionJobRequest], - Awaitable[~.BatchPredictionJob]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "create_batch_prediction_job" not in self._stubs: - self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob", - request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, - response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, - ) - return self._stubs["create_batch_prediction_job"] - - @property - def get_batch_prediction_job( - self, - ) -> Callable[ - [job_service.GetBatchPredictionJobRequest], - Awaitable[batch_prediction_job.BatchPredictionJob], - ]: - r"""Return a callable for the get batch prediction job method over gRPC. - - Gets a BatchPredictionJob - - Returns: - Callable[[~.GetBatchPredictionJobRequest], - Awaitable[~.BatchPredictionJob]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_batch_prediction_job" not in self._stubs: - self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob", - request_serializer=job_service.GetBatchPredictionJobRequest.serialize, - response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, - ) - return self._stubs["get_batch_prediction_job"] - - @property - def list_batch_prediction_jobs( - self, - ) -> Callable[ - [job_service.ListBatchPredictionJobsRequest], - Awaitable[job_service.ListBatchPredictionJobsResponse], - ]: - r"""Return a callable for the list batch prediction jobs method over gRPC. - - Lists BatchPredictionJobs in a Location. - - Returns: - Callable[[~.ListBatchPredictionJobsRequest], - Awaitable[~.ListBatchPredictionJobsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "list_batch_prediction_jobs" not in self._stubs: - self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs", - request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, - response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, - ) - return self._stubs["list_batch_prediction_jobs"] - - @property - def delete_batch_prediction_job( - self, - ) -> Callable[ - [job_service.DeleteBatchPredictionJobRequest], Awaitable[operations.Operation] - ]: - r"""Return a callable for the delete batch prediction job method over gRPC. - - Deletes a BatchPredictionJob. Can only be called on - jobs that already finished. - - Returns: - Callable[[~.DeleteBatchPredictionJobRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "delete_batch_prediction_job" not in self._stubs: - self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob", - request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["delete_batch_prediction_job"] - - @property - def cancel_batch_prediction_job( - self, - ) -> Callable[ - [job_service.CancelBatchPredictionJobRequest], Awaitable[empty.Empty] - ]: - r"""Return a callable for the cancel batch prediction job method over gRPC. - - Cancels a BatchPredictionJob. - - Starts asynchronous cancellation on the BatchPredictionJob. The - server makes the best effort to cancel the job, but success is - not guaranteed. Clients can use - [JobService.GetBatchPredictionJob][google.cloud.aiplatform.v1beta1.JobService.GetBatchPredictionJob] - or other methods to check whether the cancellation succeeded or - whether the job completed despite cancellation. On a successful - cancellation, the BatchPredictionJob is not deleted;instead its - [BatchPredictionJob.state][google.cloud.aiplatform.v1beta1.BatchPredictionJob.state] - is set to ``CANCELLED``. Any files already outputted by the job - are not deleted. - - Returns: - Callable[[~.CancelBatchPredictionJobRequest], - Awaitable[~.Empty]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "cancel_batch_prediction_job" not in self._stubs: - self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob", - request_serializer=job_service.CancelBatchPredictionJobRequest.serialize, - response_deserializer=empty.Empty.FromString, - ) - return self._stubs["cancel_batch_prediction_job"] - - -__all__ = ("JobServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py index b39295ebfe..b0d80fcc98 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py @@ -16,9 +16,5 @@ # from .client import ModelServiceClient -from .async_client import ModelServiceAsyncClient -__all__ = ( - "ModelServiceClient", - "ModelServiceAsyncClient", -) +__all__ = ("ModelServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py deleted file mode 100644 index 19e3a8d973..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py +++ /dev/null @@ -1,963 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -from collections import OrderedDict -import functools -import re -from typing import Dict, Sequence, Tuple, Type, Union -import pkg_resources - -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation # type: ignore -from google.api_core import operation_async # type: ignore -from google.cloud.aiplatform_v1beta1.services.model_service import pagers -from google.cloud.aiplatform_v1beta1.types import deployed_model_ref -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import model -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import model_evaluation -from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice -from google.cloud.aiplatform_v1beta1.types import model_service -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.protobuf import empty_pb2 as empty # type: ignore -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - -from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport -from .client import ModelServiceClient - - -class ModelServiceAsyncClient: - """A service for managing AI Platform's machine learning Models.""" - - _client: ModelServiceClient - - DEFAULT_ENDPOINT = ModelServiceClient.DEFAULT_ENDPOINT - DEFAULT_MTLS_ENDPOINT = ModelServiceClient.DEFAULT_MTLS_ENDPOINT - - model_path = staticmethod(ModelServiceClient.model_path) - parse_model_path = staticmethod(ModelServiceClient.parse_model_path) - - from_service_account_file = ModelServiceClient.from_service_account_file - from_service_account_json = from_service_account_file - - get_transport_class = functools.partial( - type(ModelServiceClient).get_transport_class, type(ModelServiceClient) - ) - - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, ModelServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the model service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, ~.ModelServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - """ - - self._client = ModelServiceClient( - credentials=credentials, - transport=transport, - client_options=client_options, - client_info=client_info, - ) - - async def upload_model( - self, - request: model_service.UploadModelRequest = None, - *, - parent: str = None, - model: gca_model.Model = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Uploads a Model artifact into AI Platform. - - Args: - request (:class:`~.model_service.UploadModelRequest`): - The request object. Request message for - [ModelService.UploadModel][google.cloud.aiplatform.v1beta1.ModelService.UploadModel]. - parent (:class:`str`): - Required. The resource name of the Location into which - to upload the Model. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - model (:class:`~.gca_model.Model`): - Required. The Model to create. - This corresponds to the ``model`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.model_service.UploadModelResponse``: Response - message of - [ModelService.UploadModel][google.cloud.aiplatform.v1beta1.ModelService.UploadModel] - operation. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent, model]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.UploadModelRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if model is not None: - request.model = model - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. 
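As the docstring above spells out, the endpoint is chosen from the mTLS-related environment variables unless ``client_options.api_endpoint`` overrides them. A minimal sketch of the explicit override using the surviving synchronous client (the regional endpoint is an example value):

    from google.api_core.client_options import ClientOptions
    from google.cloud.aiplatform_v1beta1 import ModelServiceClient

    # api_endpoint takes precedence over the environment-driven selection.
    options = ClientOptions(api_endpoint="us-central1-aiplatform.googleapis.com")
    client = ModelServiceClient(client_options=options)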
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.upload_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - model_service.UploadModelResponse, - metadata_type=model_service.UploadModelOperationMetadata, - ) - - # Done; return the response. - return response - - async def get_model( - self, - request: model_service.GetModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: - r"""Gets a Model. - - Args: - request (:class:`~.model_service.GetModelRequest`): - The request object. Request message for - [ModelService.GetModel][google.cloud.aiplatform.v1beta1.ModelService.GetModel]. - name (:class:`str`): - Required. The name of the Model resource. Format: - ``projects/{project}/locations/{location}/models/{model}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.model.Model: - A trained machine learning Model. - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.GetModelRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def list_models( - self, - request: model_service.ListModelsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsAsyncPager: - r"""Lists Models in a Location. - - Args: - request (:class:`~.model_service.ListModelsRequest`): - The request object. Request message for - [ModelService.ListModels][google.cloud.aiplatform.v1beta1.ModelService.ListModels]. - parent (:class:`str`): - Required. The resource name of the Location to list the - Models from. 
Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListModelsAsyncPager: - Response message for - [ModelService.ListModels][google.cloud.aiplatform.v1beta1.ModelService.ListModels] - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.ListModelsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_models, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListModelsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. - return response - - async def update_model( - self, - request: model_service.UpdateModelRequest = None, - *, - model: gca_model.Model = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_model.Model: - r"""Updates a Model. - - Args: - request (:class:`~.model_service.UpdateModelRequest`): - The request object. Request message for - [ModelService.UpdateModel][google.cloud.aiplatform.v1beta1.ModelService.UpdateModel]. - model (:class:`~.gca_model.Model`): - Required. The Model which replaces - the resource on the server. - This corresponds to the ``model`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - update_mask (:class:`~.field_mask.FieldMask`): - Required. The update mask applies to the resource. For - the ``FieldMask`` definition, see - - [FieldMask](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask). - This corresponds to the ``update_mask`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.gca_model.Model: - A trained machine learning Model.
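The pager built in ``list_models`` above hides page tokens entirely: iterating it yields ``Model`` resources and fetches further pages on demand. With the async surface removed by this change, the equivalent synchronous usage looks roughly like this (the parent value is a placeholder):

    from google.cloud.aiplatform_v1beta1 import ModelServiceClient

    client = ModelServiceClient()
    pager = client.list_models(parent="projects/my-project/locations/us-central1")

    # The pager transparently requests the next page when one is exhausted.
    for model in pager:
        print(model.display_name)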
- """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([model, update_mask]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.UpdateModelRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if model is not None: - request.model = model - if update_mask is not None: - request.update_mask = update_mask - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("model.name", request.model.name),) - ), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def delete_model( - self, - request: model_service.DeleteModelRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Deletes a Model. - Note: Model can only be deleted if there are no - DeployedModels created from it. - - Args: - request (:class:`~.model_service.DeleteModelRequest`): - The request object. Request message for - [ModelService.DeleteModel][google.cloud.aiplatform.v1beta1.ModelService.DeleteModel]. - name (:class:`str`): - Required. The name of the Model resource to be deleted. - Format: - ``projects/{project}/locations/{location}/models/{model}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.empty.Empty``: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: - - :: - - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } - - The JSON representation for ``Empty`` is empty JSON - object ``{}``. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.DeleteModelRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - empty.Empty, - metadata_type=gca_operation.DeleteOperationMetadata, - ) - - # Done; return the response. - return response - - async def export_model( - self, - request: model_service.ExportModelRequest = None, - *, - name: str = None, - output_config: model_service.ExportModelRequest.OutputConfig = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Exports a trained, exportable, Model to a location specified by - the user. A Model is considered to be exportable if it has at - least one [supported export - format][google.cloud.aiplatform.v1beta1.Model.supported_export_formats]. - - Args: - request (:class:`~.model_service.ExportModelRequest`): - The request object. Request message for - [ModelService.ExportModel][google.cloud.aiplatform.v1beta1.ModelService.ExportModel]. - name (:class:`str`): - Required. The resource name of the Model to export. - Format: - ``projects/{project}/locations/{location}/models/{model}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - output_config (:class:`~.model_service.ExportModelRequest.OutputConfig`): - Required. The desired output location - and configuration. - This corresponds to the ``output_config`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.model_service.ExportModelResponse``: Response - message of - [ModelService.ExportModel][google.cloud.aiplatform.v1beta1.ModelService.ExportModel] - operation. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name, output_config]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.ExportModelRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - if output_config is not None: - request.output_config = output_config - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. 
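``operation_async.from_gapic``, used above, upgrades the raw ``Operation`` proto into an awaitable future that knows which result and metadata types to unpack. Consuming such a future on the now-deleted async surface looked roughly like this sketch (``async_client`` and its arguments are hypothetical; the async client no longer exists after this change):

    async def export_and_wait(async_client, name, output_config):
        # export_model resolves to an operation_async.AsyncOperation future.
        operation = await async_client.export_model(
            name=name, output_config=output_config
        )
        # Awaiting result() yields a model_service.ExportModelResponse.
        return await operation.result()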
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.export_model, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - model_service.ExportModelResponse, - metadata_type=model_service.ExportModelOperationMetadata, - ) - - # Done; return the response. - return response - - async def get_model_evaluation( - self, - request: model_service.GetModelEvaluationRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation.ModelEvaluation: - r"""Gets a ModelEvaluation. - - Args: - request (:class:`~.model_service.GetModelEvaluationRequest`): - The request object. Request message for - [ModelService.GetModelEvaluation][google.cloud.aiplatform.v1beta1.ModelService.GetModelEvaluation]. - name (:class:`str`): - Required. The name of the ModelEvaluation resource. - Format: - - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.model_evaluation.ModelEvaluation: - A collection of metrics calculated by - comparing Model's predictions on all of - the test data against annotations from - the test data. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.GetModelEvaluationRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model_evaluation, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def list_model_evaluations( - self, - request: model_service.ListModelEvaluationsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationsAsyncPager: - r"""Lists ModelEvaluations in a Model. 
- - Args: - request (:class:`~.model_service.ListModelEvaluationsRequest`): - The request object. Request message for - [ModelService.ListModelEvaluations][google.cloud.aiplatform.v1beta1.ModelService.ListModelEvaluations]. - parent (:class:`str`): - Required. The resource name of the Model to list the - ModelEvaluations from. Format: - ``projects/{project}/locations/{location}/models/{model}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListModelEvaluationsAsyncPager: - Response message for - [ModelService.ListModelEvaluations][google.cloud.aiplatform.v1beta1.ModelService.ListModelEvaluations]. - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.ListModelEvaluationsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_evaluations, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListModelEvaluationsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. - return response - - async def get_model_evaluation_slice( - self, - request: model_service.GetModelEvaluationSliceRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model_evaluation_slice.ModelEvaluationSlice: - r"""Gets a ModelEvaluationSlice. - - Args: - request (:class:`~.model_service.GetModelEvaluationSliceRequest`): - The request object. Request message for - [ModelService.GetModelEvaluationSlice][google.cloud.aiplatform.v1beta1.ModelService.GetModelEvaluationSlice]. - name (:class:`str`): - Required. The name of the ModelEvaluationSlice resource. - Format: - - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. 
- metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.model_evaluation_slice.ModelEvaluationSlice: - A collection of metrics calculated by - comparing Model's predictions on a slice - of the test data against ground truth - annotations. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.GetModelEvaluationSliceRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model_evaluation_slice, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def list_model_evaluation_slices( - self, - request: model_service.ListModelEvaluationSlicesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelEvaluationSlicesAsyncPager: - r"""Lists ModelEvaluationSlices in a ModelEvaluation. - - Args: - request (:class:`~.model_service.ListModelEvaluationSlicesRequest`): - The request object. Request message for - [ModelService.ListModelEvaluationSlices][google.cloud.aiplatform.v1beta1.ModelService.ListModelEvaluationSlices]. - parent (:class:`str`): - Required. The resource name of the ModelEvaluation to - list the ModelEvaluationSlices from. Format: - - ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListModelEvaluationSlicesAsyncPager: - Response message for - [ModelService.ListModelEvaluationSlices][google.cloud.aiplatform.v1beta1.ModelService.ListModelEvaluationSlices]. - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = model_service.ListModelEvaluationSlicesRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
- - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_model_evaluation_slices, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListModelEvaluationSlicesAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. - return response - - -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -__all__ = ("ModelServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index 8c8fe7e16e..801456a7c5 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -16,24 +16,17 @@ # from collections import OrderedDict -from distutils import util -import os -import re -from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -from google.api_core import client_options as client_options_lib # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore -from google.api_core import operation_async # type: ignore +from google.api_core import operation as ga_operation from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.types import deployed_model_ref from google.cloud.aiplatform_v1beta1.types import explanation @@ -48,9 +41,8 @@ from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .transports.base import ModelServiceTransport from .transports.grpc import ModelServiceGrpcTransport -from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport class ModelServiceClientMeta(type): @@ -63,7 +55,6 @@ class ModelServiceClientMeta(type): _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] _transport_registry["grpc"] = ModelServiceGrpcTransport - _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport def get_transport_class(cls, 
label: str = None,) -> Type[ModelServiceTransport]: """Return an appropriate transport class. @@ -87,38 +78,8 @@ def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: class ModelServiceClient(metaclass=ModelServiceClientMeta): """A service for managing AI Platform's machine learning Models.""" - @staticmethod - def _get_default_mtls_endpoint(api_endpoint): - """Convert api endpoint to mTLS endpoint. - Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to - "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. - Args: - api_endpoint (Optional[str]): the api endpoint to convert. - Returns: - str: converted mTLS api endpoint. - """ - if not api_endpoint: - return api_endpoint - - mtls_endpoint_re = re.compile( - r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" - ) - - m = mtls_endpoint_re.match(api_endpoint) - name, mtls, sandbox, googledomain = m.groups() - if mtls or not googledomain: - return api_endpoint - - if sandbox: - return api_endpoint.replace( - "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" - ) - - return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" - DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore - DEFAULT_ENDPOINT + DEFAULT_OPTIONS = ClientOptions.ClientOptions( + api_endpoint="aiplatform.googleapis.com" ) @classmethod @@ -148,22 +109,12 @@ def model_path(project: str, location: str, model: str,) -> str: project=project, location=location, model=model, ) - @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: - """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) - return m.groupdict() if m else {} - def __init__( self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + credentials: credentials.Credentials = None, + transport: Union[str, ModelServiceTransport] = None, + client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS, ) -> None: """Instantiate the model service client. @@ -176,102 +127,26 @@ def __init__( transport (Union[str, ~.ModelServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. 
- client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. + client_options (ClientOptions): Custom options for the client. """ if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - - # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) - - ssl_credentials = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - is_mtls = True - else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" - ) + client_options = ClientOptions.from_dict(client_options) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, ModelServiceTransport): - # transport is a ModelServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) - if client_options.scopes: - raise ValueError( - "When providing a transport instance, " - "provide its scopes directly." - ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, - quota_project_id=client_options.quota_project_id, - client_info=client_info, + host=client_options.api_endpoint or "aiplatform.googleapis.com", ) def upload_model( @@ -323,36 +198,28 @@ def upload_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent, model]) - if request is not None and has_flattened_params: + if request is not None and any([parent, model]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.UploadModelRequest. 
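A short sketch of the simplified constructor path above, assuming application default credentials are available in the environment; the regional endpoint is a placeholder value.

from google.api_core.client_options import ClientOptions
from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient

# api_endpoint overrides the default host baked into the transport.
options = ClientOptions(api_endpoint="us-central1-aiplatform.googleapis.com")
client = ModelServiceClient(client_options=options)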
- # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, model_service.UploadModelRequest): - request = model_service.UploadModelRequest(request) + request = model_service.UploadModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if model is not None: - request.model = model + if parent is not None: + request.parent = parent + if model is not None: + request.model = model # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.upload_model] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + rpc = gapic_v1.method.wrap_method( + self._transport.upload_model, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -404,29 +271,25 @@ def get_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.GetModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, model_service.GetModelRequest): - request = model_service.GetModelRequest(request) + request = model_service.GetModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_model] + rpc = gapic_v1.method.wrap_method( + self._transport.get_model, default_timeout=None, client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -481,29 +344,25 @@ def list_models( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, model_service.ListModelsRequest): - request = model_service.ListModelsRequest(request) + request = model_service.ListModelsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
+ # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_models] + rpc = gapic_v1.method.wrap_method( + self._transport.list_models, default_timeout=None, client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -517,7 +376,7 @@ def list_models( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. @@ -569,38 +428,28 @@ def update_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, update_mask]) - if request is not None and has_flattened_params: + if request is not None and any([model, update_mask]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.UpdateModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, model_service.UpdateModelRequest): - request = model_service.UpdateModelRequest(request) + request = model_service.UpdateModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if model is not None: - request.model = model - if update_mask is not None: - request.update_mask = update_mask + if model is not None: + request.model = model + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.update_model] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("model.name", request.model.name),) - ), + rpc = gapic_v1.method.wrap_method( + self._transport.update_model, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -664,34 +513,26 @@ def delete_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.DeleteModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
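A sketch of consuming the pager that wraps list_models; the parent value is a placeholder and credentials are assumed to come from the environment.

from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient

client = ModelServiceClient()
pager = client.list_models(
    parent="projects/my-project/locations/us-central1",  # placeholder
)

# Iterating the pager transparently follows next_page_token across pages.
for model in pager:
    print(model.display_name)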
- if not isinstance(request, model_service.DeleteModelRequest): - request = model_service.DeleteModelRequest(request) + request = model_service.DeleteModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_model] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.delete_model, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -761,36 +602,28 @@ def export_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name, output_config]) - if request is not None and has_flattened_params: + if request is not None and any([name, output_config]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ExportModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, model_service.ExportModelRequest): - request = model_service.ExportModelRequest(request) + request = model_service.ExportModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name - if output_config is not None: - request.output_config = output_config + if name is not None: + request.name = name + if output_config is not None: + request.output_config = output_config # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.export_model] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.export_model, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -848,29 +681,27 @@ def get_model_evaluation( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.GetModelEvaluationRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
- if not isinstance(request, model_service.GetModelEvaluationRequest): - request = model_service.GetModelEvaluationRequest(request) + request = model_service.GetModelEvaluationRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation] + rpc = gapic_v1.method.wrap_method( + self._transport.get_model_evaluation, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -925,29 +756,27 @@ def list_model_evaluations( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelEvaluationsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, model_service.ListModelEvaluationsRequest): - request = model_service.ListModelEvaluationsRequest(request) + request = model_service.ListModelEvaluationsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_model_evaluations] + rpc = gapic_v1.method.wrap_method( + self._transport.list_model_evaluations, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -961,7 +790,7 @@ def list_model_evaluations( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. @@ -1008,31 +837,27 @@ def get_model_evaluation_slice( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.GetModelEvaluationSliceRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
- if not isinstance(request, model_service.GetModelEvaluationSliceRequest): - request = model_service.GetModelEvaluationSliceRequest(request) + request = model_service.GetModelEvaluationSliceRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.get_model_evaluation_slice - ] + rpc = gapic_v1.method.wrap_method( + self._transport.get_model_evaluation_slice, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -1088,31 +913,27 @@ def list_model_evaluation_slices( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelEvaluationSlicesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, model_service.ListModelEvaluationSlicesRequest): - request = model_service.ListModelEvaluationSlicesRequest(request) + request = model_service.ListModelEvaluationSlicesRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[ - self._transport.list_model_evaluation_slices - ] + rpc = gapic_v1.method.wrap_method( + self._transport.list_model_evaluation_slices, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -1126,7 +947,7 @@ def list_model_evaluation_slices( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationSlicesPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. 
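The guard repeated in every method above enforces two mutually exclusive calling conventions; a minimal sketch, with a placeholder evaluation name and default credentials assumed:

from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient
from google.cloud.aiplatform_v1beta1.types import model_service

client = ModelServiceClient()
evaluation = (
    "projects/my-project/locations/us-central1/models/123/evaluations/456"
)  # placeholder

# Either pass a fully built request object...
request = model_service.ListModelEvaluationSlicesRequest(parent=evaluation)
pager = client.list_model_evaluation_slices(request=request)

# ...or the flattened field, but never both; mixing them raises ValueError.
pager = client.list_model_evaluation_slices(parent=evaluation)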
@@ -1134,13 +955,13 @@ def list_model_evaluation_slices( try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + _client_info = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + _client_info = gapic_v1.client_info.ClientInfo() __all__ = ("ModelServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py index 1ab3aacb91..4169c27b85 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import Any, Callable, Iterable from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import model_evaluation @@ -43,11 +43,11 @@ class ListModelsPager: def __init__( self, - method: Callable[..., model_service.ListModelsResponse], + method: Callable[ + [model_service.ListModelsRequest], model_service.ListModelsResponse + ], request: model_service.ListModelsRequest, response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -58,13 +58,10 @@ def __init__( The initial request object. response (:class:`~.model_service.ListModelsResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. """ self._method = method self._request = model_service.ListModelsRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -74,7 +71,7 @@ def pages(self) -> Iterable[model_service.ListModelsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[model.Model]: @@ -85,72 +82,6 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) -class ListModelsAsyncPager: - """A pager for iterating through ``list_models`` requests. - - This class thinly wraps an initial - :class:`~.model_service.ListModelsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``models`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListModels`` requests and continue to iterate - through the ``models`` field on the - corresponding responses. - - All the usual :class:`~.model_service.ListModelsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - - def __init__( - self, - method: Callable[..., Awaitable[model_service.ListModelsResponse]], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.model_service.ListModelsRequest`): - The initial request object. 
- response (:class:`~.model_service.ListModelsResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - self._method = method - self._request = model_service.ListModelsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages(self) -> AsyncIterable[model_service.ListModelsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[model.Model]: - async def async_generator(): - async for page in self.pages: - for response in page.models: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - class ListModelEvaluationsPager: """A pager for iterating through ``list_model_evaluations`` requests. @@ -171,11 +102,12 @@ class ListModelEvaluationsPager: def __init__( self, - method: Callable[..., model_service.ListModelEvaluationsResponse], + method: Callable[ + [model_service.ListModelEvaluationsRequest], + model_service.ListModelEvaluationsResponse, + ], request: model_service.ListModelEvaluationsRequest, response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -186,13 +118,10 @@ def __init__( The initial request object. response (:class:`~.model_service.ListModelEvaluationsResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. """ self._method = method self._request = model_service.ListModelEvaluationsRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -202,7 +131,7 @@ def pages(self) -> Iterable[model_service.ListModelEvaluationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[model_evaluation.ModelEvaluation]: @@ -213,72 +142,6 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) -class ListModelEvaluationsAsyncPager: - """A pager for iterating through ``list_model_evaluations`` requests. - - This class thinly wraps an initial - :class:`~.model_service.ListModelEvaluationsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``model_evaluations`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListModelEvaluations`` requests and continue to iterate - through the ``model_evaluations`` field on the - corresponding responses. - - All the usual :class:`~.model_service.ListModelEvaluationsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. 
- """ - - def __init__( - self, - method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], - request: model_service.ListModelEvaluationsRequest, - response: model_service.ListModelEvaluationsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.model_service.ListModelEvaluationsRequest`): - The initial request object. - response (:class:`~.model_service.ListModelEvaluationsResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - self._method = method - self._request = model_service.ListModelEvaluationsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages(self) -> AsyncIterable[model_service.ListModelEvaluationsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[model_evaluation.ModelEvaluation]: - async def async_generator(): - async for page in self.pages: - for response in page.model_evaluations: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - class ListModelEvaluationSlicesPager: """A pager for iterating through ``list_model_evaluation_slices`` requests. @@ -299,11 +162,12 @@ class ListModelEvaluationSlicesPager: def __init__( self, - method: Callable[..., model_service.ListModelEvaluationSlicesResponse], + method: Callable[ + [model_service.ListModelEvaluationSlicesRequest], + model_service.ListModelEvaluationSlicesResponse, + ], request: model_service.ListModelEvaluationSlicesRequest, response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -314,13 +178,10 @@ def __init__( The initial request object. response (:class:`~.model_service.ListModelEvaluationSlicesResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. """ self._method = method self._request = model_service.ListModelEvaluationSlicesRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -330,7 +191,7 @@ def pages(self) -> Iterable[model_service.ListModelEvaluationSlicesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: @@ -339,73 +200,3 @@ def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - -class ListModelEvaluationSlicesAsyncPager: - """A pager for iterating through ``list_model_evaluation_slices`` requests. 
- - This class thinly wraps an initial - :class:`~.model_service.ListModelEvaluationSlicesResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``model_evaluation_slices`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListModelEvaluationSlices`` requests and continue to iterate - through the ``model_evaluation_slices`` field on the - corresponding responses. - - All the usual :class:`~.model_service.ListModelEvaluationSlicesResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[model_service.ListModelEvaluationSlicesResponse] - ], - request: model_service.ListModelEvaluationSlicesRequest, - response: model_service.ListModelEvaluationSlicesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.model_service.ListModelEvaluationSlicesRequest`): - The initial request object. - response (:class:`~.model_service.ListModelEvaluationSlicesResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - self._method = method - self._request = model_service.ListModelEvaluationSlicesRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages( - self, - ) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[model_evaluation_slice.ModelEvaluationSlice]: - async def async_generator(): - async for page in self.pages: - for response in page.model_evaluation_slices: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py index a521df9229..7bbcc75582 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py @@ -20,17 +20,14 @@ from .base import ModelServiceTransport from .grpc import ModelServiceGrpcTransport -from .grpc_asyncio import ModelServiceGrpcAsyncIOTransport # Compile a registry of transports. 
_transport_registry = OrderedDict()  # type: Dict[str, Type[ModelServiceTransport]]
_transport_registry["grpc"] = ModelServiceGrpcTransport
-_transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport

__all__ = (
    "ModelServiceTransport",
    "ModelServiceGrpcTransport",
-    "ModelServiceGrpcAsyncIOTransport",
)
diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py
index 2f87fc98dd..53f94ea393 100644
--- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py
+++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py
@@ -17,12 +17,8 @@
 import abc
 import typing
-import pkg_resources

-from google import auth  # type: ignore
-from google.api_core import exceptions  # type: ignore
-from google.api_core import gapic_v1  # type: ignore
-from google.api_core import retry as retries  # type: ignore
+from google import auth
 from google.api_core import operations_v1  # type: ignore
 from google.auth import credentials  # type: ignore

@@ -34,17 +30,7 @@
 from google.longrunning import operations_pb2 as operations  # type: ignore

-try:
-    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
-        gapic_version=pkg_resources.get_distribution(
-            "google-cloud-aiplatform",
-        ).version,
-    )
-except pkg_resources.DistributionNotFound:
-    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
-
-
-class ModelServiceTransport(abc.ABC):
+class ModelServiceTransport(metaclass=abc.ABCMeta):
     """Abstract transport class for ModelService."""

     AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)

@@ -54,11 +40,6 @@ def __init__(
         self,
         *,
         host: str = "aiplatform.googleapis.com",
         credentials: credentials.Credentials = None,
-        credentials_file: typing.Optional[str] = None,
-        scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
-        quota_project_id: typing.Optional[str] = None,
-        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
-        **kwargs,
     ) -> None:
         """Instantiate the transport.
@@ -69,17 +50,6 @@ def __init__(
             credentials identify the application to the service; if none
             are specified, the client will attempt to ascertain the
             credentials from the environment.
-            credentials_file (Optional[str]): A file with credentials that can
-                be loaded with :func:`google.auth.load_credentials_from_file`.
-                This argument is mutually exclusive with credentials.
-            scope (Optional[Sequence[str]]): A list of scopes.
-            quota_project_id (Optional[str]): An optional project to use for billing
-                and quota.
-            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
-                The client info used to send a user-agent string along with
-                API requests. If ``None``, then default info will be used.
-                Generally, you only need to set this if you're developing
-                your own client library.
         """
         # Save the hostname. Default to port 443 (HTTPS) if none is specified.
         if ":" not in host:
@@ -88,177 +58,89 @@ def __init__(
         # If no credentials are provided, then determine the appropriate
         # defaults.
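The credentials fallback above amounts to the standard application-default lookup; a sketch of the equivalent direct call, assuming the environment can supply default credentials:

import google.auth

# Resolve application default credentials with the transport's scope;
# the returned project id may be None depending on the environment.
credentials, project_id = google.auth.default(
    scopes=("https://www.googleapis.com/auth/cloud-platform",)
)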
- if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) - - if credentials_file is not None: - credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) - - elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - - def _prep_wrapped_messages(self, client_info): - # Precompute the wrapped methods. - self._wrapped_methods = { - self.upload_model: gapic_v1.method.wrap_method( - self.upload_model, default_timeout=5.0, client_info=client_info, - ), - self.get_model: gapic_v1.method.wrap_method( - self.get_model, default_timeout=5.0, client_info=client_info, - ), - self.list_models: gapic_v1.method.wrap_method( - self.list_models, default_timeout=5.0, client_info=client_info, - ), - self.update_model: gapic_v1.method.wrap_method( - self.update_model, default_timeout=5.0, client_info=client_info, - ), - self.delete_model: gapic_v1.method.wrap_method( - self.delete_model, default_timeout=5.0, client_info=client_info, - ), - self.export_model: gapic_v1.method.wrap_method( - self.export_model, default_timeout=5.0, client_info=client_info, - ), - self.get_model_evaluation: gapic_v1.method.wrap_method( - self.get_model_evaluation, default_timeout=5.0, client_info=client_info, - ), - self.list_model_evaluations: gapic_v1.method.wrap_method( - self.list_model_evaluations, - default_timeout=5.0, - client_info=client_info, - ), - self.get_model_evaluation_slice: gapic_v1.method.wrap_method( - self.get_model_evaluation_slice, - default_timeout=5.0, - client_info=client_info, - ), - self.list_model_evaluation_slices: gapic_v1.method.wrap_method( - self.list_model_evaluation_slices, - default_timeout=5.0, - client_info=client_info, - ), - } - @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError() + raise NotImplementedError @property def upload_model( self, - ) -> typing.Callable[ - [model_service.UploadModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[model_service.UploadModelRequest], operations.Operation]: + raise NotImplementedError @property def get_model( self, - ) -> typing.Callable[ - [model_service.GetModelRequest], - typing.Union[model.Model, typing.Awaitable[model.Model]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[model_service.GetModelRequest], model.Model]: + raise NotImplementedError @property def list_models( self, ) -> typing.Callable[ - [model_service.ListModelsRequest], - typing.Union[ - model_service.ListModelsResponse, - typing.Awaitable[model_service.ListModelsResponse], - ], + [model_service.ListModelsRequest], model_service.ListModelsResponse ]: - raise NotImplementedError() + raise NotImplementedError @property def update_model( self, - ) -> typing.Callable[ - [model_service.UpdateModelRequest], - typing.Union[gca_model.Model, typing.Awaitable[gca_model.Model]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[model_service.UpdateModelRequest], gca_model.Model]: + raise 
NotImplementedError @property def delete_model( self, - ) -> typing.Callable[ - [model_service.DeleteModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[model_service.DeleteModelRequest], operations.Operation]: + raise NotImplementedError @property def export_model( self, - ) -> typing.Callable[ - [model_service.ExportModelRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[model_service.ExportModelRequest], operations.Operation]: + raise NotImplementedError @property def get_model_evaluation( self, ) -> typing.Callable[ - [model_service.GetModelEvaluationRequest], - typing.Union[ - model_evaluation.ModelEvaluation, - typing.Awaitable[model_evaluation.ModelEvaluation], - ], + [model_service.GetModelEvaluationRequest], model_evaluation.ModelEvaluation ]: - raise NotImplementedError() + raise NotImplementedError @property def list_model_evaluations( self, ) -> typing.Callable[ [model_service.ListModelEvaluationsRequest], - typing.Union[ - model_service.ListModelEvaluationsResponse, - typing.Awaitable[model_service.ListModelEvaluationsResponse], - ], + model_service.ListModelEvaluationsResponse, ]: - raise NotImplementedError() + raise NotImplementedError @property def get_model_evaluation_slice( self, ) -> typing.Callable[ [model_service.GetModelEvaluationSliceRequest], - typing.Union[ - model_evaluation_slice.ModelEvaluationSlice, - typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice], - ], + model_evaluation_slice.ModelEvaluationSlice, ]: - raise NotImplementedError() + raise NotImplementedError @property def list_model_evaluation_slices( self, ) -> typing.Callable[ [model_service.ListModelEvaluationSlicesRequest], - typing.Union[ - model_service.ListModelEvaluationSlicesResponse, - typing.Awaitable[model_service.ListModelEvaluationSlicesResponse], - ], + model_service.ListModelEvaluationSlicesResponse, ]: - raise NotImplementedError() + raise NotImplementedError __all__ = ("ModelServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index fa794283bb..f83c41e879 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -15,15 +15,11 @@ # limitations under the License. # -import warnings -from typing import Callable, Dict, Optional, Sequence, Tuple +from typing import Callable, Dict from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -34,7 +30,7 @@ from google.cloud.aiplatform_v1beta1.types import model_service from google.longrunning import operations_pb2 as operations # type: ignore -from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .base import ModelServiceTransport class ModelServiceGrpcTransport(ModelServiceTransport): @@ -50,21 +46,12 @@ class ModelServiceGrpcTransport(ModelServiceTransport): top of HTTP/2); the ``grpcio`` package must be installed. 
""" - _stubs: Dict[str, Callable] - def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + channel: grpc.Channel = None ) -> None: """Instantiate the transport. @@ -76,119 +63,28 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. """ + # Sanity check: Ensure that channel and credentials are not both + # provided. if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. credentials = False - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. 
- self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - + # Run the base constructor. + super().__init__(host=host, credentials=credentials) self._stubs = {} # type: Dict[str, Callable] - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # If a channel was explicitly provided, set it. + if channel: + self._grpc_channel = channel @classmethod def create_channel( cls, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, + **kwargs ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -198,31 +94,13 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. - - Raises: - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. """ - scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, + host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs ) @property @@ -232,6 +110,13 @@ def grpc_channel(self) -> grpc.Channel: This property caches on the instance; repeated calls return the same channel. """ + # Sanity check: Only create a new channel if we do not already + # have one. + if not hasattr(self, "_grpc_channel"): + self._grpc_channel = self.create_channel( + self._host, credentials=self._credentials, + ) + # Return the channel from cache. 
return self._grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py deleted file mode 100644 index ffe89774ef..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py +++ /dev/null @@ -1,534 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple - -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.cloud.aiplatform_v1beta1.types import model -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import model_evaluation -from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice -from google.cloud.aiplatform_v1beta1.types import model_service -from google.longrunning import operations_pb2 as operations # type: ignore - -from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO -from .grpc import ModelServiceGrpcTransport - - -class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): - """gRPC AsyncIO backend transport for ModelService. - - A service for managing AI Platform's machine learning Models. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - """ - - _grpc_channel: aio.Channel - _stubs: Dict[str, Callable] = {} - - @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: - """Create and return a gRPC AsyncIO channel object. - Args: - address (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. 
These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - aio.Channel: A gRPC AsyncIO channel object. - """ - scopes = scopes or cls.AUTH_SCOPES - return grpc_helpers_async.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, - ) - - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. - credentials = False - - # If a channel was explicitly provided, set it. 
- self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - - @property - def grpc_channel(self) -> aio.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. - """ - # Return the channel from cache. - return self._grpc_channel - - @property - def operations_client(self) -> operations_v1.OperationsAsyncClient: - """Create the client designed to process long-running operations. - - This property caches on the instance; repeated calls return the same - client. - """ - # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( - self.grpc_channel - ) - - # Return the client from cache. - return self.__dict__["operations_client"] - - @property - def upload_model( - self, - ) -> Callable[[model_service.UploadModelRequest], Awaitable[operations.Operation]]: - r"""Return a callable for the upload model method over gRPC. - - Uploads a Model artifact into AI Platform. - - Returns: - Callable[[~.UploadModelRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
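Both the surviving sync transport and the deleted async one lean on the same memoization idiom: the gRPC channel is created once behind a ``hasattr`` check, and each RPC property caches its stub in ``self._stubs`` on first access. A compact sketch of the idiom with stand-ins (nothing below is from the patch) in place of a real channel and stub::

    class LazyTransport:
        """Illustrative only; not one of the generated transports."""

        def __init__(self, channel_factory):
            self._channel_factory = channel_factory
            self._stubs = {}

        @property
        def grpc_channel(self):
            # Only create a new channel if we do not already have one.
            if not hasattr(self, "_grpc_channel"):
                self._grpc_channel = self._channel_factory()
            return self._grpc_channel

        @property
        def get_model(self):
            # Generate the callable once; repeated access hits the cache.
            if "get_model" not in self._stubs:
                self._stubs["get_model"] = lambda request: ("GetModel", request)
            return self._stubs["get_model"]

    t = LazyTransport(channel_factory=object)
    assert t.grpc_channel is t.grpc_channel
    assert t.get_model is t.get_model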
- if "upload_model" not in self._stubs: - self._stubs["upload_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/UploadModel", - request_serializer=model_service.UploadModelRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["upload_model"] - - @property - def get_model( - self, - ) -> Callable[[model_service.GetModelRequest], Awaitable[model.Model]]: - r"""Return a callable for the get model method over gRPC. - - Gets a Model. - - Returns: - Callable[[~.GetModelRequest], - Awaitable[~.Model]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_model" not in self._stubs: - self._stubs["get_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModel", - request_serializer=model_service.GetModelRequest.serialize, - response_deserializer=model.Model.deserialize, - ) - return self._stubs["get_model"] - - @property - def list_models( - self, - ) -> Callable[ - [model_service.ListModelsRequest], Awaitable[model_service.ListModelsResponse] - ]: - r"""Return a callable for the list models method over gRPC. - - Lists Models in a Location. - - Returns: - Callable[[~.ListModelsRequest], - Awaitable[~.ListModelsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "list_models" not in self._stubs: - self._stubs["list_models"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModels", - request_serializer=model_service.ListModelsRequest.serialize, - response_deserializer=model_service.ListModelsResponse.deserialize, - ) - return self._stubs["list_models"] - - @property - def update_model( - self, - ) -> Callable[[model_service.UpdateModelRequest], Awaitable[gca_model.Model]]: - r"""Return a callable for the update model method over gRPC. - - Updates a Model. - - Returns: - Callable[[~.UpdateModelRequest], - Awaitable[~.Model]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "update_model" not in self._stubs: - self._stubs["update_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel", - request_serializer=model_service.UpdateModelRequest.serialize, - response_deserializer=gca_model.Model.deserialize, - ) - return self._stubs["update_model"] - - @property - def delete_model( - self, - ) -> Callable[[model_service.DeleteModelRequest], Awaitable[operations.Operation]]: - r"""Return a callable for the delete model method over gRPC. - - Deletes a Model. - Note: Model can only be deleted if there are no - DeployedModels created from it. - - Returns: - Callable[[~.DeleteModelRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. 
- # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "delete_model" not in self._stubs: - self._stubs["delete_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel", - request_serializer=model_service.DeleteModelRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["delete_model"] - - @property - def export_model( - self, - ) -> Callable[[model_service.ExportModelRequest], Awaitable[operations.Operation]]: - r"""Return a callable for the export model method over gRPC. - - Exports a trained, exportable, Model to a location specified by - the user. A Model is considered to be exportable if it has at - least one [supported export - format][google.cloud.aiplatform.v1beta1.Model.supported_export_formats]. - - Returns: - Callable[[~.ExportModelRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "export_model" not in self._stubs: - self._stubs["export_model"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ExportModel", - request_serializer=model_service.ExportModelRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["export_model"] - - @property - def get_model_evaluation( - self, - ) -> Callable[ - [model_service.GetModelEvaluationRequest], - Awaitable[model_evaluation.ModelEvaluation], - ]: - r"""Return a callable for the get model evaluation method over gRPC. - - Gets a ModelEvaluation. - - Returns: - Callable[[~.GetModelEvaluationRequest], - Awaitable[~.ModelEvaluation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_model_evaluation" not in self._stubs: - self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation", - request_serializer=model_service.GetModelEvaluationRequest.serialize, - response_deserializer=model_evaluation.ModelEvaluation.deserialize, - ) - return self._stubs["get_model_evaluation"] - - @property - def list_model_evaluations( - self, - ) -> Callable[ - [model_service.ListModelEvaluationsRequest], - Awaitable[model_service.ListModelEvaluationsResponse], - ]: - r"""Return a callable for the list model evaluations method over gRPC. - - Lists ModelEvaluations in a Model. - - Returns: - Callable[[~.ListModelEvaluationsRequest], - Awaitable[~.ListModelEvaluationsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "list_model_evaluations" not in self._stubs: - self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations", - request_serializer=model_service.ListModelEvaluationsRequest.serialize, - response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, - ) - return self._stubs["list_model_evaluations"] - - @property - def get_model_evaluation_slice( - self, - ) -> Callable[ - [model_service.GetModelEvaluationSliceRequest], - Awaitable[model_evaluation_slice.ModelEvaluationSlice], - ]: - r"""Return a callable for the get model evaluation slice method over gRPC. - - Gets a ModelEvaluationSlice. - - Returns: - Callable[[~.GetModelEvaluationSliceRequest], - Awaitable[~.ModelEvaluationSlice]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_model_evaluation_slice" not in self._stubs: - self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice", - request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, - response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, - ) - return self._stubs["get_model_evaluation_slice"] - - @property - def list_model_evaluation_slices( - self, - ) -> Callable[ - [model_service.ListModelEvaluationSlicesRequest], - Awaitable[model_service.ListModelEvaluationSlicesResponse], - ]: - r"""Return a callable for the list model evaluation slices method over gRPC. - - Lists ModelEvaluationSlices in a ModelEvaluation. - - Returns: - Callable[[~.ListModelEvaluationSlicesRequest], - Awaitable[~.ListModelEvaluationSlicesResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "list_model_evaluation_slices" not in self._stubs: - self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices", - request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, - response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, - ) - return self._stubs["list_model_evaluation_slices"] - - -__all__ = ("ModelServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py index 7f02b47358..29ba95e15f 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py @@ -16,9 +16,5 @@ # from .client import PipelineServiceClient -from .async_client import PipelineServiceAsyncClient -__all__ = ( - "PipelineServiceClient", - "PipelineServiceAsyncClient", -) +__all__ = ("PipelineServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py deleted file mode 100644 index 88912e928f..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ /dev/null @@ -1,551 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from collections import OrderedDict -import functools -import re -from typing import Dict, Sequence, Tuple, Type, Union -import pkg_resources - -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation # type: ignore -from google.api_core import operation_async # type: ignore -from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers -from google.cloud.aiplatform_v1beta1.types import model -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.cloud.aiplatform_v1beta1.types import pipeline_service -from google.cloud.aiplatform_v1beta1.types import pipeline_state -from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) -from google.protobuf import empty_pb2 as empty # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from google.rpc import status_pb2 as status # type: ignore - -from .transports.base import PipelineServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import PipelineServiceGrpcAsyncIOTransport -from .client import PipelineServiceClient - - -class PipelineServiceAsyncClient: - """A service for creating and managing AI Platform's pipelines.""" - - _client: PipelineServiceClient - - DEFAULT_ENDPOINT = PipelineServiceClient.DEFAULT_ENDPOINT - DEFAULT_MTLS_ENDPOINT = PipelineServiceClient.DEFAULT_MTLS_ENDPOINT - - model_path = staticmethod(PipelineServiceClient.model_path) - parse_model_path = staticmethod(PipelineServiceClient.parse_model_path) - training_pipeline_path = staticmethod(PipelineServiceClient.training_pipeline_path) - parse_training_pipeline_path = staticmethod( - PipelineServiceClient.parse_training_pipeline_path - ) - - from_service_account_file = PipelineServiceClient.from_service_account_file - from_service_account_json = from_service_account_file - - get_transport_class = functools.partial( - type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient) - ) - - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, PipelineServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the pipeline service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, ~.PipelineServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - """ - - self._client = PipelineServiceClient( - credentials=credentials, - transport=transport, - client_options=client_options, - client_info=client_info, - ) - - async def create_training_pipeline( - self, - request: pipeline_service.CreateTrainingPipelineRequest = None, - *, - parent: str = None, - training_pipeline: gca_training_pipeline.TrainingPipeline = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gca_training_pipeline.TrainingPipeline: - r"""Creates a TrainingPipeline. A created - TrainingPipeline right away will be attempted to be run. - - Args: - request (:class:`~.pipeline_service.CreateTrainingPipelineRequest`): - The request object. Request message for - [PipelineService.CreateTrainingPipeline][google.cloud.aiplatform.v1beta1.PipelineService.CreateTrainingPipeline]. - parent (:class:`str`): - Required. The resource name of the Location to create - the TrainingPipeline in. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - training_pipeline (:class:`~.gca_training_pipeline.TrainingPipeline`): - Required. The TrainingPipeline to - create. - This corresponds to the ``training_pipeline`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.gca_training_pipeline.TrainingPipeline: - The TrainingPipeline orchestrates tasks associated with - training a Model. It always executes the training task, - and optionally may also export data from AI Platform's - Dataset which becomes the training input, - [upload][google.cloud.aiplatform.v1beta1.ModelService.UploadModel] - the Model to AI Platform, and evaluate the Model. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent, training_pipeline]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = pipeline_service.CreateTrainingPipelineRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
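Each generated method repeats the same coercion steps shown here: reject a mix of ``request`` and flattened fields, coerce the argument into the proto request type, then copy any flattened fields in. A reduced sketch of that guard, using a plain dict as a stand-in for the proto message::

    def coerce_request(request=None, *, parent=None, training_pipeline=None):
        # A full request object and individual fields are mutually exclusive.
        if request is not None and any([parent, training_pipeline]):
            raise ValueError(
                "If the `request` argument is set, then none of "
                "the individual field arguments should be set."
            )
        request = dict(request or {})  # stand-in for the proto constructor
        if parent is not None:
            request["parent"] = parent
        if training_pipeline is not None:
            request["training_pipeline"] = training_pipeline
        return request

    assert coerce_request(parent="projects/p/locations/l")["parent"] == (
        "projects/p/locations/l"
    )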
- - if parent is not None: - request.parent = parent - if training_pipeline is not None: - request.training_pipeline = training_pipeline - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_training_pipeline, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def get_training_pipeline( - self, - request: pipeline_service.GetTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> training_pipeline.TrainingPipeline: - r"""Gets a TrainingPipeline. - - Args: - request (:class:`~.pipeline_service.GetTrainingPipelineRequest`): - The request object. Request message for - [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1beta1.PipelineService.GetTrainingPipeline]. - name (:class:`str`): - Required. The name of the TrainingPipeline resource. - Format: - - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.training_pipeline.TrainingPipeline: - The TrainingPipeline orchestrates tasks associated with - training a Model. It always executes the training task, - and optionally may also export data from AI Platform's - Dataset which becomes the training input, - [upload][google.cloud.aiplatform.v1beta1.ModelService.UploadModel] - the Model to AI Platform, and evaluate the Model. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = pipeline_service.GetTrainingPipelineRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_training_pipeline, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. 
- return response - - async def list_training_pipelines( - self, - request: pipeline_service.ListTrainingPipelinesRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTrainingPipelinesAsyncPager: - r"""Lists TrainingPipelines in a Location. - - Args: - request (:class:`~.pipeline_service.ListTrainingPipelinesRequest`): - The request object. Request message for - [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1beta1.PipelineService.ListTrainingPipelines]. - parent (:class:`str`): - Required. The resource name of the Location to list the - TrainingPipelines from. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListTrainingPipelinesAsyncPager: - Response message for - [PipelineService.ListTrainingPipelines][google.cloud.aiplatform.v1beta1.PipelineService.ListTrainingPipelines] - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = pipeline_service.ListTrainingPipelinesRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_training_pipelines, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListTrainingPipelinesAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. - return response - - async def delete_training_pipeline( - self, - request: pipeline_service.DeleteTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Deletes a TrainingPipeline. - - Args: - request (:class:`~.pipeline_service.DeleteTrainingPipelineRequest`): - The request object. Request message for - [PipelineService.DeleteTrainingPipeline][google.cloud.aiplatform.v1beta1.PipelineService.DeleteTrainingPipeline]. - name (:class:`str`): - Required. The name of the TrainingPipeline resource to - be deleted. 
Format: - - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.empty.Empty``: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: - - :: - - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } - - The JSON representation for ``Empty`` is empty JSON - object ``{}``. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = pipeline_service.DeleteTrainingPipelineRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_training_pipeline, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - empty.Empty, - metadata_type=gca_operation.DeleteOperationMetadata, - ) - - # Done; return the response. - return response - - async def cancel_training_pipeline( - self, - request: pipeline_service.CancelTrainingPipelineRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: - r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on - the TrainingPipeline. The server makes a best effort to cancel - the pipeline, but success is not guaranteed. Clients can use - [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1beta1.PipelineService.GetTrainingPipeline] - or other methods to check whether the cancellation succeeded or - whether the pipeline completed despite cancellation. On - successful cancellation, the TrainingPipeline is not deleted; - instead it becomes a pipeline with a - [TrainingPipeline.error][google.cloud.aiplatform.v1beta1.TrainingPipeline.error] - value with a [google.rpc.Status.code][google.rpc.Status.code] of - 1, corresponding to ``Code.CANCELLED``, and - [TrainingPipeline.state][google.cloud.aiplatform.v1beta1.TrainingPipeline.state] - is set to ``CANCELLED``. 
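Note that ``delete_training_pipeline`` above does not hand back the raw ``Operation`` proto; ``operation_async.from_gapic`` wraps it in a future that resolves to ``Empty`` and exposes ``DeleteOperationMetadata``. A hedged sketch of how a caller consumes such a future via the synchronous client, assuming ADC is configured, that the sync method wraps its response in the matching ``ga_operation`` future, and a placeholder resource name::

    from google.cloud.aiplatform_v1beta1 import PipelineServiceClient

    client = PipelineServiceClient()
    lro = client.delete_training_pipeline(
        name="projects/p/locations/us-central1/trainingPipelines/123"  # placeholder
    )
    lro.result()  # blocks until the server-side deletion finishes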
- - Args: - request (:class:`~.pipeline_service.CancelTrainingPipelineRequest`): - The request object. Request message for - [PipelineService.CancelTrainingPipeline][google.cloud.aiplatform.v1beta1.PipelineService.CancelTrainingPipeline]. - name (:class:`str`): - Required. The name of the TrainingPipeline to cancel. - Format: - - ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = pipeline_service.CancelTrainingPipelineRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.cancel_training_pipeline, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. 
- await rpc( - request, retry=retry, timeout=timeout, metadata=metadata, - ) - - -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -__all__ = ("PipelineServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index 739c1bb861..7da23adbc3 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -16,24 +16,17 @@ # from collections import OrderedDict -from distutils import util -import os -import re -from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -from google.api_core import client_options as client_options_lib # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore -from google.api_core import operation_async # type: ignore +from google.api_core import operation as ga_operation from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import operation as gca_operation @@ -48,9 +41,8 @@ from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as status # type: ignore -from .transports.base import PipelineServiceTransport, DEFAULT_CLIENT_INFO +from .transports.base import PipelineServiceTransport from .transports.grpc import PipelineServiceGrpcTransport -from .transports.grpc_asyncio import PipelineServiceGrpcAsyncIOTransport class PipelineServiceClientMeta(type): @@ -65,7 +57,6 @@ class PipelineServiceClientMeta(type): OrderedDict() ) # type: Dict[str, Type[PipelineServiceTransport]] _transport_registry["grpc"] = PipelineServiceGrpcTransport - _transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTransport]: """Return an appropriate transport class. @@ -89,38 +80,8 @@ def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTranspor class PipelineServiceClient(metaclass=PipelineServiceClientMeta): """A service for creating and managing AI Platform's pipelines.""" - @staticmethod - def _get_default_mtls_endpoint(api_endpoint): - """Convert api endpoint to mTLS endpoint. - Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to - "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. - Args: - api_endpoint (Optional[str]): the api endpoint to convert. - Returns: - str: converted mTLS api endpoint. 
- """ - if not api_endpoint: - return api_endpoint - - mtls_endpoint_re = re.compile( - r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" - ) - - m = mtls_endpoint_re.match(api_endpoint) - name, mtls, sandbox, googledomain = m.groups() - if mtls or not googledomain: - return api_endpoint - - if sandbox: - return api_endpoint.replace( - "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" - ) - - return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" - DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore - DEFAULT_ENDPOINT + DEFAULT_OPTIONS = ClientOptions.ClientOptions( + api_endpoint="aiplatform.googleapis.com" ) @classmethod @@ -143,22 +104,6 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file - @staticmethod - def model_path(project: str, location: str, model: str,) -> str: - """Return a fully-qualified model string.""" - return "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) - - @staticmethod - def parse_model_path(path: str) -> Dict[str, str]: - """Parse a model path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", - path, - ) - return m.groupdict() if m else {} - @staticmethod def training_pipeline_path( project: str, location: str, training_pipeline: str, @@ -169,21 +114,18 @@ def training_pipeline_path( ) @staticmethod - def parse_training_pipeline_path(path: str) -> Dict[str, str]: - """Parse a training_pipeline path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", - path, + def model_path(project: str, location: str, model: str,) -> str: + """Return a fully-qualified model string.""" + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, ) - return m.groupdict() if m else {} def __init__( self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PipelineServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + credentials: credentials.Credentials = None, + transport: Union[str, PipelineServiceTransport] = None, + client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS, ) -> None: """Instantiate the pipeline service client. @@ -196,102 +138,26 @@ def __init__( transport (Union[str, ~.PipelineServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. 
- (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. + client_options (ClientOptions): Custom options for the client. """ if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - - # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) - - ssl_credentials = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - is_mtls = True - else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" - ) + client_options = ClientOptions.from_dict(client_options) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, PipelineServiceTransport): - # transport is a PipelineServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) - if client_options.scopes: - raise ValueError( - "When providing a transport instance, " - "provide its scopes directly." - ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, - quota_project_id=client_options.quota_project_id, - client_info=client_info, + host=client_options.api_endpoint or "aiplatform.googleapis.com", ) def create_training_pipeline( @@ -344,36 +210,28 @@ def create_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
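With the mTLS and environment-variable plumbing removed above, endpoint selection reduces to the ``api_endpoint`` carried by ``client_options``. A minimal construction sketch under that assumption (the regional endpoint shown is illustrative, not taken from this change):

    from google.api_core.client_options import ClientOptions
    from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
        PipelineServiceClient,
    )

    # No arguments: credentials come from google.auth.default() and the host
    # from DEFAULT_OPTIONS ("aiplatform.googleapis.com").
    client = PipelineServiceClient()

    # Overriding the endpoint via client_options (hypothetical regional host).
    regional_client = PipelineServiceClient(
        client_options=ClientOptions(
            api_endpoint="us-central1-aiplatform.googleapis.com"
        )
    )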
- has_flattened_params = any([parent, training_pipeline]) - if request is not None and has_flattened_params: + if request is not None and any([parent, training_pipeline]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.CreateTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, pipeline_service.CreateTrainingPipelineRequest): - request = pipeline_service.CreateTrainingPipelineRequest(request) + request = pipeline_service.CreateTrainingPipelineRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if training_pipeline is not None: - request.training_pipeline = training_pipeline + if parent is not None: + request.parent = parent + if training_pipeline is not None: + request.training_pipeline = training_pipeline # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_training_pipeline] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + rpc = gapic_v1.method.wrap_method( + self._transport.create_training_pipeline, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -425,29 +283,27 @@ def get_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.GetTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, pipeline_service.GetTrainingPipelineRequest): - request = pipeline_service.GetTrainingPipelineRequest(request) + request = pipeline_service.GetTrainingPipelineRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_training_pipeline] + rpc = gapic_v1.method.wrap_method( + self._transport.get_training_pipeline, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -502,29 +358,27 @@ def list_training_pipelines( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
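The request-coercion pattern above (here and in ``get_training_pipeline``) means callers pick exactly one of two styles. A usage sketch, reusing ``client`` from the previous sketch; the resource names and pipeline are placeholders:

    from google.cloud.aiplatform_v1beta1.types import pipeline_service, training_pipeline

    parent = "projects/my-project/locations/us-central1"  # placeholder
    pipeline = training_pipeline.TrainingPipeline(display_name="demo")  # placeholder

    # Style 1: a fully-formed request object, no flattened fields.
    request = pipeline_service.CreateTrainingPipelineRequest(
        parent=parent, training_pipeline=pipeline,
    )
    created = client.create_training_pipeline(request=request)

    # Style 2: flattened fields only. Mixing both styles raises ValueError,
    # per the sanity check above.
    created = client.create_training_pipeline(parent=parent, training_pipeline=pipeline)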
- has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.ListTrainingPipelinesRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, pipeline_service.ListTrainingPipelinesRequest): - request = pipeline_service.ListTrainingPipelinesRequest(request) + request = pipeline_service.ListTrainingPipelinesRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_training_pipelines] + rpc = gapic_v1.method.wrap_method( + self._transport.list_training_pipelines, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -538,7 +392,7 @@ def list_training_pipelines( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrainingPipelinesPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. @@ -598,34 +452,26 @@ def delete_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.DeleteTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, pipeline_service.DeleteTrainingPipelineRequest): - request = pipeline_service.DeleteTrainingPipelineRequest(request) + request = pipeline_service.DeleteTrainingPipelineRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_training_pipeline] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.delete_training_pipeline, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -687,34 +533,26 @@ def cancel_training_pipeline( # Create or coerce a protobuf request object. 
# Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a pipeline_service.CancelTrainingPipelineRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, pipeline_service.CancelTrainingPipelineRequest): - request = pipeline_service.CancelTrainingPipelineRequest(request) + request = pipeline_service.CancelTrainingPipelineRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.cancel_training_pipeline] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.cancel_training_pipeline, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -724,13 +562,13 @@ def cancel_training_pipeline( try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + _client_info = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + _client_info = gapic_v1.client_info.ClientInfo() __all__ = ("PipelineServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py index 98e5a51a17..0db54250ef 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import Any, Callable, Iterable from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline @@ -41,11 +41,12 @@ class ListTrainingPipelinesPager: def __init__( self, - method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], + method: Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + pipeline_service.ListTrainingPipelinesResponse, + ], request: pipeline_service.ListTrainingPipelinesRequest, response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -56,13 +57,10 @@ def __init__( The initial request object. response (:class:`~.pipeline_service.ListTrainingPipelinesResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
""" self._method = method self._request = pipeline_service.ListTrainingPipelinesRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -72,7 +70,7 @@ def pages(self) -> Iterable[pipeline_service.ListTrainingPipelinesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: @@ -81,73 +79,3 @@ def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - -class ListTrainingPipelinesAsyncPager: - """A pager for iterating through ``list_training_pipelines`` requests. - - This class thinly wraps an initial - :class:`~.pipeline_service.ListTrainingPipelinesResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``training_pipelines`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListTrainingPipelines`` requests and continue to iterate - through the ``training_pipelines`` field on the - corresponding responses. - - All the usual :class:`~.pipeline_service.ListTrainingPipelinesResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[pipeline_service.ListTrainingPipelinesResponse] - ], - request: pipeline_service.ListTrainingPipelinesRequest, - response: pipeline_service.ListTrainingPipelinesResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.pipeline_service.ListTrainingPipelinesRequest`): - The initial request object. - response (:class:`~.pipeline_service.ListTrainingPipelinesResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
- """ - self._method = method - self._request = pipeline_service.ListTrainingPipelinesRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages( - self, - ) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[training_pipeline.TrainingPipeline]: - async def async_generator(): - async for page in self.pages: - for response in page.training_pipelines: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py index d9d71a892b..615b2c1025 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py @@ -20,17 +20,14 @@ from .base import PipelineServiceTransport from .grpc import PipelineServiceGrpcTransport -from .grpc_asyncio import PipelineServiceGrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] _transport_registry["grpc"] = PipelineServiceGrpcTransport -_transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport __all__ = ( "PipelineServiceTransport", "PipelineServiceGrpcTransport", - "PipelineServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py index 41123b8615..5696ede4d7 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py @@ -17,12 +17,8 @@ import abc import typing -import pkg_resources -from google import auth # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore +from google import auth from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -35,17 +31,7 @@ from google.protobuf import empty_pb2 as empty # type: ignore -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -class PipelineServiceTransport(abc.ABC): +class PipelineServiceTransport(metaclass=abc.ABCMeta): """Abstract transport class for PipelineService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -55,11 +41,6 @@ def __init__( *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, ) -> None: 
"""Instantiate the transport. @@ -70,17 +51,6 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scope (Optional[Sequence[str]]): A list of scopes. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -89,115 +59,57 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. - if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) - - if credentials_file is not None: - credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) - - elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - - def _prep_wrapped_messages(self, client_info): - # Precompute the wrapped methods. 
- self._wrapped_methods = { - self.create_training_pipeline: gapic_v1.method.wrap_method( - self.create_training_pipeline, - default_timeout=5.0, - client_info=client_info, - ), - self.get_training_pipeline: gapic_v1.method.wrap_method( - self.get_training_pipeline, - default_timeout=5.0, - client_info=client_info, - ), - self.list_training_pipelines: gapic_v1.method.wrap_method( - self.list_training_pipelines, - default_timeout=5.0, - client_info=client_info, - ), - self.delete_training_pipeline: gapic_v1.method.wrap_method( - self.delete_training_pipeline, - default_timeout=5.0, - client_info=client_info, - ), - self.cancel_training_pipeline: gapic_v1.method.wrap_method( - self.cancel_training_pipeline, - default_timeout=5.0, - client_info=client_info, - ), - } - @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError() + raise NotImplementedError @property def create_training_pipeline( self, ) -> typing.Callable[ [pipeline_service.CreateTrainingPipelineRequest], - typing.Union[ - gca_training_pipeline.TrainingPipeline, - typing.Awaitable[gca_training_pipeline.TrainingPipeline], - ], + gca_training_pipeline.TrainingPipeline, ]: - raise NotImplementedError() + raise NotImplementedError @property def get_training_pipeline( self, ) -> typing.Callable[ [pipeline_service.GetTrainingPipelineRequest], - typing.Union[ - training_pipeline.TrainingPipeline, - typing.Awaitable[training_pipeline.TrainingPipeline], - ], + training_pipeline.TrainingPipeline, ]: - raise NotImplementedError() + raise NotImplementedError @property def list_training_pipelines( self, ) -> typing.Callable[ [pipeline_service.ListTrainingPipelinesRequest], - typing.Union[ - pipeline_service.ListTrainingPipelinesResponse, - typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse], - ], + pipeline_service.ListTrainingPipelinesResponse, ]: - raise NotImplementedError() + raise NotImplementedError @property def delete_training_pipeline( self, ) -> typing.Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + [pipeline_service.DeleteTrainingPipelineRequest], operations.Operation ]: - raise NotImplementedError() + raise NotImplementedError @property def cancel_training_pipeline( self, - ) -> typing.Callable[ - [pipeline_service.CancelTrainingPipelineRequest], - typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], - ]: - raise NotImplementedError() + ) -> typing.Callable[[pipeline_service.CancelTrainingPipelineRequest], empty.Empty]: + raise NotImplementedError __all__ = ("PipelineServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index 1f2ead2426..5c79d9870d 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -15,15 +15,11 @@ # limitations under the License. 
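Because every RPC property on the regenerated base class simply raises NotImplementedError, a concrete transport only has to expose plain callables with matching signatures. A hypothetical in-memory test double, not part of this change, might look like:

    from google.auth.credentials import AnonymousCredentials
    from google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.base import (
        PipelineServiceTransport,
    )
    from google.cloud.aiplatform_v1beta1.types import training_pipeline


    class InMemoryPipelineServiceTransport(PipelineServiceTransport):
        """Hypothetical test double implementing a single RPC."""

        @property
        def get_training_pipeline(self):
            def _get(request, timeout=None, metadata=(), **kwargs):
                # Echo a pipeline carrying the requested resource name.
                return training_pipeline.TrainingPipeline(name=request.name)

            return _get


    # AnonymousCredentials avoids the auth.default() lookup in the base
    # constructor; the instance can then be handed to a client as transport=...
    transport = InMemoryPipelineServiceTransport(credentials=AnonymousCredentials())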
# -import warnings -from typing import Callable, Dict, Optional, Sequence, Tuple +from typing import Callable, Dict from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -35,7 +31,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import PipelineServiceTransport, DEFAULT_CLIENT_INFO +from .base import PipelineServiceTransport class PipelineServiceGrpcTransport(PipelineServiceTransport): @@ -51,21 +47,12 @@ class PipelineServiceGrpcTransport(PipelineServiceTransport): top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + channel: grpc.Channel = None ) -> None: """Instantiate the transport. @@ -77,119 +64,28 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. """ + # Sanity check: Ensure that channel and credentials are not both + # provided. if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. credentials = False - # If a channel was explicitly provided, set it. 
- self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - + # Run the base constructor. + super().__init__(host=host, credentials=credentials) self._stubs = {} # type: Dict[str, Callable] - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # If a channel was explicitly provided, set it. + if channel: + self._grpc_channel = channel @classmethod def create_channel( cls, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, + **kwargs ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -199,31 +95,13 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. - - Raises: - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. 
""" - scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, + host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs ) @property @@ -233,6 +111,13 @@ def grpc_channel(self) -> grpc.Channel: This property caches on the instance; repeated calls return the same channel. """ + # Sanity check: Only create a new channel if we do not already + # have one. + if not hasattr(self, "_grpc_channel"): + self._grpc_channel = self.create_channel( + self._host, credentials=self._credentials, + ) + # Return the channel from cache. return self._grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py deleted file mode 100644 index 56fa96d91f..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py +++ /dev/null @@ -1,413 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple - -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.cloud.aiplatform_v1beta1.types import pipeline_service -from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) -from google.longrunning import operations_pb2 as operations # type: ignore -from google.protobuf import empty_pb2 as empty # type: ignore - -from .base import PipelineServiceTransport, DEFAULT_CLIENT_INFO -from .grpc import PipelineServiceGrpcTransport - - -class PipelineServiceGrpcAsyncIOTransport(PipelineServiceTransport): - """gRPC AsyncIO backend transport for PipelineService. - - A service for creating and managing AI Platform's pipelines. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. 
- """ - - _grpc_channel: aio.Channel - _stubs: Dict[str, Callable] = {} - - @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: - """Create and return a gRPC AsyncIO channel object. - Args: - address (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - aio.Channel: A gRPC AsyncIO channel object. - """ - scopes = scopes or cls.AUTH_SCOPES - return grpc_helpers_async.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, - ) - - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. 
- ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. - credentials = False - - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - - @property - def grpc_channel(self) -> aio.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. - """ - # Return the channel from cache. - return self._grpc_channel - - @property - def operations_client(self) -> operations_v1.OperationsAsyncClient: - """Create the client designed to process long-running operations. - - This property caches on the instance; repeated calls return the same - client. - """ - # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( - self.grpc_channel - ) - - # Return the client from cache. 
- return self.__dict__["operations_client"] - - @property - def create_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.CreateTrainingPipelineRequest], - Awaitable[gca_training_pipeline.TrainingPipeline], - ]: - r"""Return a callable for the create training pipeline method over gRPC. - - Creates a TrainingPipeline. A created - TrainingPipeline right away will be attempted to be run. - - Returns: - Callable[[~.CreateTrainingPipelineRequest], - Awaitable[~.TrainingPipeline]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "create_training_pipeline" not in self._stubs: - self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline", - request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, - response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, - ) - return self._stubs["create_training_pipeline"] - - @property - def get_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.GetTrainingPipelineRequest], - Awaitable[training_pipeline.TrainingPipeline], - ]: - r"""Return a callable for the get training pipeline method over gRPC. - - Gets a TrainingPipeline. - - Returns: - Callable[[~.GetTrainingPipelineRequest], - Awaitable[~.TrainingPipeline]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_training_pipeline" not in self._stubs: - self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline", - request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, - response_deserializer=training_pipeline.TrainingPipeline.deserialize, - ) - return self._stubs["get_training_pipeline"] - - @property - def list_training_pipelines( - self, - ) -> Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - Awaitable[pipeline_service.ListTrainingPipelinesResponse], - ]: - r"""Return a callable for the list training pipelines method over gRPC. - - Lists TrainingPipelines in a Location. - - Returns: - Callable[[~.ListTrainingPipelinesRequest], - Awaitable[~.ListTrainingPipelinesResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "list_training_pipelines" not in self._stubs: - self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines", - request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, - response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, - ) - return self._stubs["list_training_pipelines"] - - @property - def delete_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], - Awaitable[operations.Operation], - ]: - r"""Return a callable for the delete training pipeline method over gRPC. 
- - Deletes a TrainingPipeline. - - Returns: - Callable[[~.DeleteTrainingPipelineRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "delete_training_pipeline" not in self._stubs: - self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline", - request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["delete_training_pipeline"] - - @property - def cancel_training_pipeline( - self, - ) -> Callable[ - [pipeline_service.CancelTrainingPipelineRequest], Awaitable[empty.Empty] - ]: - r"""Return a callable for the cancel training pipeline method over gRPC. - - Cancels a TrainingPipeline. Starts asynchronous cancellation on - the TrainingPipeline. The server makes a best effort to cancel - the pipeline, but success is not guaranteed. Clients can use - [PipelineService.GetTrainingPipeline][google.cloud.aiplatform.v1beta1.PipelineService.GetTrainingPipeline] - or other methods to check whether the cancellation succeeded or - whether the pipeline completed despite cancellation. On - successful cancellation, the TrainingPipeline is not deleted; - instead it becomes a pipeline with a - [TrainingPipeline.error][google.cloud.aiplatform.v1beta1.TrainingPipeline.error] - value with a [google.rpc.Status.code][google.rpc.Status.code] of - 1, corresponding to ``Code.CANCELLED``, and - [TrainingPipeline.state][google.cloud.aiplatform.v1beta1.TrainingPipeline.state] - is set to ``CANCELLED``. - - Returns: - Callable[[~.CancelTrainingPipelineRequest], - Awaitable[~.Empty]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if "cancel_training_pipeline" not in self._stubs: - self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline", - request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, - response_deserializer=empty.Empty.FromString, - ) - return self._stubs["cancel_training_pipeline"] - - -__all__ = ("PipelineServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py index 0c847693e0..9e3af89360 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py @@ -16,9 +16,5 @@ # from .client import PredictionServiceClient -from .async_client import PredictionServiceAsyncClient -__all__ = ( - "PredictionServiceClient", - "PredictionServiceAsyncClient", -) +__all__ = ("PredictionServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py deleted file mode 100644 index 4f9ab1350c..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py +++ /dev/null @@ -1,341 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from collections import OrderedDict -import functools -import re -from typing import Dict, Sequence, Tuple, Type, Union -import pkg_resources - -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import prediction_service -from google.protobuf import struct_pb2 as struct # type: ignore - -from .transports.base import PredictionServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import PredictionServiceGrpcAsyncIOTransport -from .client import PredictionServiceClient - - -class PredictionServiceAsyncClient: - """A service for online predictions and explanations.""" - - _client: PredictionServiceClient - - DEFAULT_ENDPOINT = PredictionServiceClient.DEFAULT_ENDPOINT - DEFAULT_MTLS_ENDPOINT = PredictionServiceClient.DEFAULT_MTLS_ENDPOINT - - from_service_account_file = PredictionServiceClient.from_service_account_file - from_service_account_json = from_service_account_file - - get_transport_class = functools.partial( - type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient) - ) - - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, PredictionServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the prediction service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, ~.PredictionServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. 
- """ - - self._client = PredictionServiceClient( - credentials=credentials, - transport=transport, - client_options=client_options, - client_info=client_info, - ) - - async def predict( - self, - request: prediction_service.PredictRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.PredictResponse: - r"""Perform an online prediction. - - Args: - request (:class:`~.prediction_service.PredictRequest`): - The request object. Request message for - [PredictionService.Predict][google.cloud.aiplatform.v1beta1.PredictionService.Predict]. - endpoint (:class:`str`): - Required. The name of the Endpoint requested to serve - the prediction. Format: - ``projects/{project}/locations/{location}/endpoints/{endpoint}`` - This corresponds to the ``endpoint`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - instances (:class:`Sequence[~.struct.Value]`): - Required. The instances that are the input to the - prediction call. A DeployedModel may have an upper limit - on the number of instances it supports per request, and - when it is exceeded the prediction call errors in case - of AutoML Models, or, in case of customer created - Models, the behaviour is as documented by that Model. - The schema of any single instance may be specified via - Endpoint's DeployedModels' - [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] - [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] - [instance_schema_uri][google.cloud.aiplatform.v1beta1.PredictSchemata.instance_schema_uri]. - This corresponds to the ``instances`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - parameters (:class:`~.struct.Value`): - The parameters that govern the prediction. The schema of - the parameters may be specified via Endpoint's - DeployedModels' [Model's - ][google.cloud.aiplatform.v1beta1.DeployedModel.model] - [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] - [parameters_schema_uri][google.cloud.aiplatform.v1beta1.PredictSchemata.parameters_schema_uri]. - This corresponds to the ``parameters`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.prediction_service.PredictResponse: - Response message for - [PredictionService.Predict][google.cloud.aiplatform.v1beta1.PredictionService.Predict]. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, instances, parameters]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = prediction_service.PredictRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
- - if endpoint is not None: - request.endpoint = endpoint - if instances is not None: - request.instances = instances - if parameters is not None: - request.parameters = parameters - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.predict, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def explain( - self, - request: prediction_service.ExplainRequest = None, - *, - endpoint: str = None, - instances: Sequence[struct.Value] = None, - parameters: struct.Value = None, - deployed_model_id: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> prediction_service.ExplainResponse: - r"""Perform an online explanation. - - If [ExplainRequest.deployed_model_id] is specified, the - corresponding DeployModel must have - [explanation_spec][google.cloud.aiplatform.v1beta1.DeployedModel.explanation_spec] - populated. If [ExplainRequest.deployed_model_id] is not - specified, all DeployedModels must have - [explanation_spec][google.cloud.aiplatform.v1beta1.DeployedModel.explanation_spec] - populated. Only deployed AutoML tabular Models have - explanation_spec. - - Args: - request (:class:`~.prediction_service.ExplainRequest`): - The request object. Request message for - [PredictionService.Explain][google.cloud.aiplatform.v1beta1.PredictionService.Explain]. - endpoint (:class:`str`): - Required. The name of the Endpoint requested to serve - the explanation. Format: - ``projects/{project}/locations/{location}/endpoints/{endpoint}`` - This corresponds to the ``endpoint`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - instances (:class:`Sequence[~.struct.Value]`): - Required. The instances that are the input to the - explanation call. A DeployedModel may have an upper - limit on the number of instances it supports per - request, and when it is exceeded the explanation call - errors in case of AutoML Models, or, in case of customer - created Models, the behaviour is as documented by that - Model. The schema of any single instance may be - specified via Endpoint's DeployedModels' - [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] - [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] - [instance_schema_uri][google.cloud.aiplatform.v1beta1.PredictSchemata.instance_schema_uri]. - This corresponds to the ``instances`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - parameters (:class:`~.struct.Value`): - The parameters that govern the prediction. The schema of - the parameters may be specified via Endpoint's - DeployedModels' [Model's - ][google.cloud.aiplatform.v1beta1.DeployedModel.model] - [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] - [parameters_schema_uri][google.cloud.aiplatform.v1beta1.PredictSchemata.parameters_schema_uri]. - This corresponds to the ``parameters`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. 
- deployed_model_id (:class:`str`): - If specified, this ExplainRequest will be served by the - chosen DeployedModel, overriding - [Endpoint.traffic_split][google.cloud.aiplatform.v1beta1.Endpoint.traffic_split]. - This corresponds to the ``deployed_model_id`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.prediction_service.ExplainResponse: - Response message for - [PredictionService.Explain][google.cloud.aiplatform.v1beta1.PredictionService.Explain]. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any( - [endpoint, instances, parameters, deployed_model_id] - ): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = prediction_service.ExplainRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if endpoint is not None: - request.endpoint = endpoint - if instances is not None: - request.instances = instances - if parameters is not None: - request.parameters = parameters - if deployed_model_id is not None: - request.deployed_model_id = deployed_model_id - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.explain, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. 
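For reference, the asynchronous surface removed by this patch was driven through an event loop; a minimal sketch of the retired call pattern, with a hypothetical project/endpoint ID and payload:

    import asyncio

    from google.cloud.aiplatform_v1beta1.services.prediction_service import (
        PredictionServiceAsyncClient,
    )
    from google.protobuf import json_format, struct_pb2

    async def main():
        # The async client wrapped the sync client and awaited each RPC.
        client = PredictionServiceAsyncClient()
        instance = json_format.ParseDict({"feature_a": 1.0}, struct_pb2.Value())
        response = await client.predict(
            endpoint="projects/my-project/locations/us-central1/endpoints/1234",  # hypothetical
            instances=[instance],
        )
        print(response.predictions)

    asyncio.run(main())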
- return response - - -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -__all__ = ("PredictionServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index d96acc576b..c255e0619d 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -16,29 +16,22 @@ # from collections import OrderedDict -from distutils import util -import os -import re -from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -from google.api_core import client_options as client_options_lib # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service from google.protobuf import struct_pb2 as struct # type: ignore -from .transports.base import PredictionServiceTransport, DEFAULT_CLIENT_INFO +from .transports.base import PredictionServiceTransport from .transports.grpc import PredictionServiceGrpcTransport -from .transports.grpc_asyncio import PredictionServiceGrpcAsyncIOTransport class PredictionServiceClientMeta(type): @@ -53,7 +46,6 @@ class PredictionServiceClientMeta(type): OrderedDict() ) # type: Dict[str, Type[PredictionServiceTransport]] _transport_registry["grpc"] = PredictionServiceGrpcTransport - _transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport def get_transport_class( cls, label: str = None, @@ -79,38 +71,8 @@ def get_transport_class( class PredictionServiceClient(metaclass=PredictionServiceClientMeta): """A service for online predictions and explanations.""" - @staticmethod - def _get_default_mtls_endpoint(api_endpoint): - """Convert api endpoint to mTLS endpoint. - Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to - "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. - Args: - api_endpoint (Optional[str]): the api endpoint to convert. - Returns: - str: converted mTLS api endpoint. - """ - if not api_endpoint: - return api_endpoint - - mtls_endpoint_re = re.compile( - r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" 
- ) - - m = mtls_endpoint_re.match(api_endpoint) - name, mtls, sandbox, googledomain = m.groups() - if mtls or not googledomain: - return api_endpoint - - if sandbox: - return api_endpoint.replace( - "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" - ) - - return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - - DEFAULT_ENDPOINT = "aiplatform.googleapis.com" - DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore - DEFAULT_ENDPOINT + DEFAULT_OPTIONS = ClientOptions.ClientOptions( + api_endpoint="aiplatform.googleapis.com" ) @classmethod @@ -136,10 +98,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): def __init__( self, *, - credentials: Optional[credentials.Credentials] = None, - transport: Union[str, PredictionServiceTransport, None] = None, - client_options: Optional[client_options_lib.ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + credentials: credentials.Credentials = None, + transport: Union[str, PredictionServiceTransport] = None, + client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS, ) -> None: """Instantiate the prediction service client. @@ -152,102 +113,26 @@ def __init__( transport (Union[str, ~.PredictionServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. + client_options (ClientOptions): Custom options for the client. """ if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - - # Create SSL credentials for mutual TLS if needed. 
- use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) - - ssl_credentials = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - is_mtls = True - else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" - ) + client_options = ClientOptions.from_dict(client_options) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, PredictionServiceTransport): - # transport is a PredictionServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) - if client_options.scopes: - raise ValueError( - "When providing a transport instance, " - "provide its scopes directly." - ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, - quota_project_id=client_options.quota_project_id, - client_info=client_info, + host=client_options.api_endpoint or "aiplatform.googleapis.com", ) def predict( @@ -315,38 +200,28 @@ def predict( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([endpoint, instances, parameters]) - if request is not None and has_flattened_params: + if request is not None and any([endpoint, instances, parameters]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.PredictRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, prediction_service.PredictRequest): - request = prediction_service.PredictRequest(request) + request = prediction_service.PredictRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if endpoint is not None: - request.endpoint = endpoint - if instances is not None: - request.instances = instances - if parameters is not None: - request.parameters = parameters + if endpoint is not None: + request.endpoint = endpoint + if instances is not None: + request.instances = instances + if parameters is not None: + request.parameters = parameters # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.predict] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + rpc = gapic_v1.method.wrap_method( + self._transport.predict, default_timeout=None, client_info=_client_info, ) # Send the request. @@ -437,40 +312,32 @@ def explain( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) - if request is not None and has_flattened_params: + if request is not None and any( + [endpoint, instances, parameters, deployed_model_id] + ): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a prediction_service.ExplainRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, prediction_service.ExplainRequest): - request = prediction_service.ExplainRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if endpoint is not None: - request.endpoint = endpoint - if instances is not None: - request.instances = instances - if parameters is not None: - request.parameters = parameters - if deployed_model_id is not None: - request.deployed_model_id = deployed_model_id + request = prediction_service.ExplainRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if instances is not None: + request.instances = instances + if parameters is not None: + request.parameters = parameters + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.explain] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + rpc = gapic_v1.method.wrap_method( + self._transport.explain, default_timeout=None, client_info=_client_info, ) # Send the request. 
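With the regenerated synchronous client above, a flattened predict call looks roughly like this; the project, location, and endpoint IDs are hypothetical placeholders, and the payload assumes a model that accepts a JSON-style struct:

    from google.cloud.aiplatform_v1beta1.services.prediction_service import (
        PredictionServiceClient,
    )
    from google.protobuf import json_format, struct_pb2

    client = PredictionServiceClient()

    # google.protobuf.struct_pb2.Value instances are built from plain dicts.
    instance = json_format.ParseDict({"feature_a": 1.0}, struct_pb2.Value())

    response = client.predict(
        endpoint="projects/my-project/locations/us-central1/endpoints/1234",  # hypothetical
        instances=[instance],
    )
    print(response.predictions)

Passing a fully populated PredictRequest instead of the flattened fields is also supported, but mixing the two raises ValueError, as the sanity check above enforces.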
@@ -481,13 +348,13 @@ def explain( try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + _client_info = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + _client_info = gapic_v1.client_info.ClientInfo() __all__ = ("PredictionServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py index 7eb32ea86d..33eefca757 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py @@ -20,17 +20,14 @@ from .base import PredictionServiceTransport from .grpc import PredictionServiceGrpcTransport -from .grpc_asyncio import PredictionServiceGrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] _transport_registry["grpc"] = PredictionServiceGrpcTransport -_transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport __all__ = ( "PredictionServiceTransport", "PredictionServiceGrpcTransport", - "PredictionServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py index 0c82f7d83c..58d508474a 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py @@ -17,28 +17,14 @@ import abc import typing -import pkg_resources -from google import auth # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore +from google import auth from google.auth import credentials # type: ignore from google.cloud.aiplatform_v1beta1.types import prediction_service -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -class PredictionServiceTransport(abc.ABC): +class PredictionServiceTransport(metaclass=abc.ABCMeta): """Abstract transport class for PredictionService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -48,11 +34,6 @@ def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, - quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, ) -> None: """Instantiate the transport. @@ -63,17 +44,6 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A list of scopes.
- quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -82,61 +52,27 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. - if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) - - if credentials_file is not None: - credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) - - elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - - def _prep_wrapped_messages(self, client_info): - # Precompute the wrapped methods. - self._wrapped_methods = { - self.predict: gapic_v1.method.wrap_method( - self.predict, default_timeout=5.0, client_info=client_info, - ), - self.explain: gapic_v1.method.wrap_method( - self.explain, default_timeout=5.0, client_info=client_info, - ), - } - @property def predict( self, ) -> typing.Callable[ - [prediction_service.PredictRequest], - typing.Union[ - prediction_service.PredictResponse, - typing.Awaitable[prediction_service.PredictResponse], - ], + [prediction_service.PredictRequest], prediction_service.PredictResponse ]: - raise NotImplementedError() + raise NotImplementedError @property def explain( self, ) -> typing.Callable[ - [prediction_service.ExplainRequest], - typing.Union[ - prediction_service.ExplainResponse, - typing.Awaitable[prediction_service.ExplainResponse], - ], + [prediction_service.ExplainRequest], prediction_service.ExplainResponse ]: - raise NotImplementedError() + raise NotImplementedError __all__ = ("PredictionServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index 5212c81f5a..b657bcaa16 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -15,20 +15,16 @@ # limitations under the License. 
# -import warnings -from typing import Callable, Dict, Optional, Sequence, Tuple +from typing import Callable, Dict from google.api_core import grpc_helpers # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google import auth # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore from google.cloud.aiplatform_v1beta1.types import prediction_service -from .base import PredictionServiceTransport, DEFAULT_CLIENT_INFO +from .base import PredictionServiceTransport class PredictionServiceGrpcTransport(PredictionServiceTransport): @@ -44,21 +40,12 @@ class PredictionServiceGrpcTransport(PredictionServiceTransport): top of HTTP/2); the ``grpcio`` package must be installed. """ - _stubs: Dict[str, Callable] - def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Sequence[str] = None, - channel: grpc.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + channel: grpc.Channel = None ) -> None: """Instantiate the transport. @@ -70,119 +57,28 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A list of scopes. This argument is - ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. """ + # Sanity check: Ensure that channel and credentials are not both + # provided. if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. credentials = False - # If a channel was explicitly provided, set it.
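Per the simplified constructor above, a caller who needs a customized channel can still build one and hand it to the transport; a minimal sketch, assuming default credentials are available in the environment:

    from google.cloud.aiplatform_v1beta1.services.prediction_service import (
        PredictionServiceClient,
    )
    from google.cloud.aiplatform_v1beta1.services.prediction_service.transports.grpc import (
        PredictionServiceGrpcTransport,
    )

    # Build the channel explicitly (e.g. to tune gRPC options), then wrap it.
    channel = PredictionServiceGrpcTransport.create_channel(
        "aiplatform.googleapis.com"
    )
    transport = PredictionServiceGrpcTransport(channel=channel)

    # A transport instance carries its own channel, so no credentials are
    # passed to the client here.
    client = PredictionServiceClient(transport=transport)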
- self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - + # Run the base constructor. + super().__init__(host=host, credentials=credentials) self._stubs = {} # type: Dict[str, Callable] - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) + # If a channel was explicitly provided, set it. + if channel: + self._grpc_channel = channel @classmethod def create_channel( cls, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: str = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, + **kwargs ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -192,31 +88,13 @@ def create_channel( credentials (Optional[~.Credentials]): The credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): An optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. - - Raises: - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed.
""" - scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, + host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs ) @property @@ -226,6 +104,13 @@ def grpc_channel(self) -> grpc.Channel: This property caches on the instance; repeated calls return the same channel. """ + # Sanity check: Only create a new channel if we do not already + # have one. + if not hasattr(self, "_grpc_channel"): + self._grpc_channel = self.create_channel( + self._host, credentials=self._credentials, + ) + # Return the channel from cache. return self._grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py deleted file mode 100644 index 4107899bed..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py +++ /dev/null @@ -1,300 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple - -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.cloud.aiplatform_v1beta1.types import prediction_service - -from .base import PredictionServiceTransport, DEFAULT_CLIENT_INFO -from .grpc import PredictionServiceGrpcTransport - - -class PredictionServiceGrpcAsyncIOTransport(PredictionServiceTransport): - """gRPC AsyncIO backend transport for PredictionService. - - A service for online predictions and explanations. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - """ - - _grpc_channel: aio.Channel - _stubs: Dict[str, Callable] = {} - - @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: - """Create and return a gRPC AsyncIO channel object. - Args: - address (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. 
If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): An optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - aio.Channel: A gRPC AsyncIO channel object. - """ - scopes = scopes or cls.AUTH_SCOPES - return grpc_helpers_async.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, - ) - - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): An optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason.
- google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. - credentials = False - - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - - @property - def grpc_channel(self) -> aio.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. - """ - # Return the channel from cache. - return self._grpc_channel - - @property - def predict( - self, - ) -> Callable[ - [prediction_service.PredictRequest], - Awaitable[prediction_service.PredictResponse], - ]: - r"""Return a callable for the predict method over gRPC. - - Perform an online prediction. - - Returns: - Callable[[~.PredictRequest], - Awaitable[~.PredictResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "predict" not in self._stubs: - self._stubs["predict"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PredictionService/Predict", - request_serializer=prediction_service.PredictRequest.serialize, - response_deserializer=prediction_service.PredictResponse.deserialize, - ) - return self._stubs["predict"] - - @property - def explain( - self, - ) -> Callable[ - [prediction_service.ExplainRequest], - Awaitable[prediction_service.ExplainResponse], - ]: - r"""Return a callable for the explain method over gRPC. - - Perform an online explanation. 
- - If [ExplainRequest.deployed_model_id] is specified, the - corresponding DeployedModel must have - [explanation_spec][google.cloud.aiplatform.v1beta1.DeployedModel.explanation_spec] - populated. If [ExplainRequest.deployed_model_id] is not - specified, all DeployedModels must have - [explanation_spec][google.cloud.aiplatform.v1beta1.DeployedModel.explanation_spec] - populated. Only deployed AutoML tabular Models have - explanation_spec. - - Returns: - Callable[[~.ExplainRequest], - Awaitable[~.ExplainResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "explain" not in self._stubs: - self._stubs["explain"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.PredictionService/Explain", - request_serializer=prediction_service.ExplainRequest.serialize, - response_deserializer=prediction_service.ExplainResponse.deserialize, - ) - return self._stubs["explain"] - - -__all__ = ("PredictionServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py index 49e9cdf0a0..8f429cd5eb 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py @@ -16,9 +16,5 @@ # from .client import SpecialistPoolServiceClient -from .async_client import SpecialistPoolServiceAsyncClient -__all__ = ( - "SpecialistPoolServiceClient", - "SpecialistPoolServiceAsyncClient", -) +__all__ = ("SpecialistPoolServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py deleted file mode 100644 index bce2179917..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py +++ /dev/null @@ -1,592 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
-# - -from collections import OrderedDict -import functools -import re -from typing import Dict, Sequence, Tuple, Type, Union -import pkg_resources - -import google.api_core.client_options as ClientOptions # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore -from google.auth import credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.api_core import operation as ga_operation # type: ignore -from google.api_core import operation_async # type: ignore -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.cloud.aiplatform_v1beta1.types import specialist_pool -from google.cloud.aiplatform_v1beta1.types import specialist_pool as gca_specialist_pool -from google.cloud.aiplatform_v1beta1.types import specialist_pool_service -from google.protobuf import empty_pb2 as empty # type: ignore -from google.protobuf import field_mask_pb2 as field_mask # type: ignore - -from .transports.base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport -from .client import SpecialistPoolServiceClient - - -class SpecialistPoolServiceAsyncClient: - """A service for creating and managing Customer SpecialistPools. - When customers start Data Labeling jobs, they can reuse/create - Specialist Pools to bring their own Specialists to label the - data. Customers can add/remove Managers for the Specialist Pool - on Cloud console, then Managers will get email notifications to - manage Specialists and tasks on CrowdCompute console. - """ - - _client: SpecialistPoolServiceClient - - DEFAULT_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_ENDPOINT - DEFAULT_MTLS_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_MTLS_ENDPOINT - - specialist_pool_path = staticmethod( - SpecialistPoolServiceClient.specialist_pool_path - ) - parse_specialist_pool_path = staticmethod( - SpecialistPoolServiceClient.parse_specialist_pool_path - ) - - from_service_account_file = SpecialistPoolServiceClient.from_service_account_file - from_service_account_json = from_service_account_file - - get_transport_class = functools.partial( - type(SpecialistPoolServiceClient).get_transport_class, - type(SpecialistPoolServiceClient), - ) - - def __init__( - self, - *, - credentials: credentials.Credentials = None, - transport: Union[str, SpecialistPoolServiceTransport] = "grpc_asyncio", - client_options: ClientOptions = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the specialist pool service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, ~.SpecialistPoolServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. 
GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - """ - - self._client = SpecialistPoolServiceClient( - credentials=credentials, - transport=transport, - client_options=client_options, - client_info=client_info, - ) - - async def create_specialist_pool( - self, - request: specialist_pool_service.CreateSpecialistPoolRequest = None, - *, - parent: str = None, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Creates a SpecialistPool. - - Args: - request (:class:`~.specialist_pool_service.CreateSpecialistPoolRequest`): - The request object. Request message for - [SpecialistPoolService.CreateSpecialistPool][google.cloud.aiplatform.v1beta1.SpecialistPoolService.CreateSpecialistPool]. - parent (:class:`str`): - Required. The parent Project name for the new - SpecialistPool. The form is - ``projects/{project}/locations/{location}``. - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - specialist_pool (:class:`~.gca_specialist_pool.SpecialistPool`): - Required. The SpecialistPool to - create. - This corresponds to the ``specialist_pool`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:`~.gca_specialist_pool.SpecialistPool`: - SpecialistPool represents customers' own workforce to - work on their data labeling jobs. It includes a group of - specialist managers who are responsible for managing the - labelers in this pool as well as customers' data - labeling jobs associated with this pool. Customers - create specialist pool as well as start data labeling - jobs on Cloud, managers and labelers work with the jobs - using CrowdCompute console. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent, specialist_pool]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set."
- ) - - request = specialist_pool_service.CreateSpecialistPoolRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - if specialist_pool is not None: - request.specialist_pool = specialist_pool - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.create_specialist_pool, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - gca_specialist_pool.SpecialistPool, - metadata_type=specialist_pool_service.CreateSpecialistPoolOperationMetadata, - ) - - # Done; return the response. - return response - - async def get_specialist_pool( - self, - request: specialist_pool_service.GetSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> specialist_pool.SpecialistPool: - r"""Gets a SpecialistPool. - - Args: - request (:class:`~.specialist_pool_service.GetSpecialistPoolRequest`): - The request object. Request message for - [SpecialistPoolService.GetSpecialistPool][google.cloud.aiplatform.v1beta1.SpecialistPoolService.GetSpecialistPool]. - name (:class:`str`): - Required. The name of the SpecialistPool resource. The - form is - - ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}``. - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.specialist_pool.SpecialistPool: - SpecialistPool represents customers' - own workforce to work on their data - labeling jobs. It includes a group of - specialist managers who are responsible - for managing the labelers in this pool - as well as customers' data labeling jobs - associated with this pool. - Customers create specialist pool as well - as start data labeling jobs on Cloud, - managers and labelers work with the jobs - using CrowdCompute console. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = specialist_pool_service.GetSpecialistPoolRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. 
- rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_specialist_pool, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Done; return the response. - return response - - async def list_specialist_pools( - self, - request: specialist_pool_service.ListSpecialistPoolsRequest = None, - *, - parent: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListSpecialistPoolsAsyncPager: - r"""Lists SpecialistPools in a Location. - - Args: - request (:class:`~.specialist_pool_service.ListSpecialistPoolsRequest`): - The request object. Request message for - [SpecialistPoolService.ListSpecialistPools][google.cloud.aiplatform.v1beta1.SpecialistPoolService.ListSpecialistPools]. - parent (:class:`str`): - Required. The name of the SpecialistPool's parent - resource. Format: - ``projects/{project}/locations/{location}`` - This corresponds to the ``parent`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.pagers.ListSpecialistPoolsAsyncPager: - Response message for - [SpecialistPoolService.ListSpecialistPools][google.cloud.aiplatform.v1beta1.SpecialistPoolService.ListSpecialistPools]. - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = specialist_pool_service.ListSpecialistPoolsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if parent is not None: - request.parent = parent - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_specialist_pools, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListSpecialistPoolsAsyncPager( - method=rpc, request=request, response=response, metadata=metadata, - ) - - # Done; return the response. 
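The synchronous counterpart of the pager wrapped above exposes plain iteration, fetching follow-up pages lazily; a sketch with a hypothetical parent resource:

    from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
        SpecialistPoolServiceClient,
    )

    client = SpecialistPoolServiceClient()

    # Iterating the pager transparently issues additional
    # ListSpecialistPools calls as each page is exhausted.
    for pool in client.list_specialist_pools(
        parent="projects/my-project/locations/us-central1"  # hypothetical
    ):
        print(pool.name)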
- return response - - async def delete_specialist_pool( - self, - request: specialist_pool_service.DeleteSpecialistPoolRequest = None, - *, - name: str = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Deletes a SpecialistPool as well as all Specialists - in the pool. - - Args: - request (:class:`~.specialist_pool_service.DeleteSpecialistPoolRequest`): - The request object. Request message for - [SpecialistPoolService.DeleteSpecialistPool][google.cloud.aiplatform.v1beta1.SpecialistPoolService.DeleteSpecialistPool]. - name (:class:`str`): - Required. The resource name of the SpecialistPool to - delete. Format: - ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}`` - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: - - :: - - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } - - The JSON representation for ``Empty`` is empty JSON - object ``{}``. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([name]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = specialist_pool_service.DeleteSpecialistPoolRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.delete_specialist_pool, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. - response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - empty.Empty, - metadata_type=gca_operation.DeleteOperationMetadata, - ) - - # Done; return the response.
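Like the other mutating RPCs in this service, deletion returns a long-running operation; with the synchronous client, the returned google.api_core operation future can simply be blocked on. A sketch, resource name hypothetical:

    from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
        SpecialistPoolServiceClient,
    )

    client = SpecialistPoolServiceClient()

    operation = client.delete_specialist_pool(
        name="projects/my-project/locations/us-central1/specialistPools/5678"  # hypothetical
    )

    # Blocks until the server finishes; the result is google.protobuf.Empty.
    operation.result()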
- return response - - async def update_specialist_pool( - self, - request: specialist_pool_service.UpdateSpecialistPoolRequest = None, - *, - specialist_pool: gca_specialist_pool.SpecialistPool = None, - update_mask: field_mask.FieldMask = None, - retry: retries.Retry = gapic_v1.method.DEFAULT, - timeout: float = None, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: - r"""Updates a SpecialistPool. - - Args: - request (:class:`~.specialist_pool_service.UpdateSpecialistPoolRequest`): - The request object. Request message for - [SpecialistPoolService.UpdateSpecialistPool][google.cloud.aiplatform.v1beta1.SpecialistPoolService.UpdateSpecialistPool]. - specialist_pool (:class:`~.gca_specialist_pool.SpecialistPool`): - Required. The SpecialistPool which - replaces the resource on the server. - This corresponds to the ``specialist_pool`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - update_mask (:class:`~.field_mask.FieldMask`): - Required. The update mask applies to - the resource. - This corresponds to the ``update_mask`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.operation_async.AsyncOperation: - An object representing a long-running operation. - - The result type for the operation will be - :class:``~.gca_specialist_pool.SpecialistPool``: - SpecialistPool represents customers' own workforce to - work on their data labeling jobs. It includes a group of - specialist managers who are responsible for managing the - labelers in this pool as well as customers' data - labeling jobs associated with this pool. Customers - create specialist pool as well as start data labeling - jobs on Cloud, managers and labelers work with the jobs - using CrowdCompute console. - - """ - # Create or coerce a protobuf request object. - # Sanity check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - if request is not None and any([specialist_pool, update_mask]): - raise ValueError( - "If the `request` argument is set, then none of " - "the individual field arguments should be set." - ) - - request = specialist_pool_service.UpdateSpecialistPoolRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - - if specialist_pool is not None: - request.specialist_pool = specialist_pool - if update_mask is not None: - request.update_mask = update_mask - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.update_specialist_pool, - default_timeout=5.0, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("specialist_pool.name", request.specialist_pool.name),) - ), - ) - - # Send the request. - response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) - - # Wrap the response in an operation future. 
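# (Editor's note: `operation_async.from_gapic` below wraps the raw
# `google.longrunning.Operation` proto in an `AsyncOperation` future;
# awaiting its `result()` yields the result type given as the third
# argument -- here a SpecialistPool -- while `metadata_type` controls how
# the operation's metadata field is deserialized.)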
- response = operation_async.from_gapic( - response, - self._client._transport.operations_client, - gca_specialist_pool.SpecialistPool, - metadata_type=specialist_pool_service.UpdateSpecialistPoolOperationMetadata, - ) - - # Done; return the response. - return response - - -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -__all__ = ("SpecialistPoolServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index de5e846a34..b0f7bb38c8 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -16,24 +16,17 @@ # from collections import OrderedDict -from distutils import util -import os -import re -from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +from typing import Dict, Sequence, Tuple, Type, Union import pkg_resources -from google.api_core import client_options as client_options_lib # type: ignore +import google.api_core.client_options as ClientOptions # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation # type: ignore -from google.api_core import operation_async # type: ignore +from google.api_core import operation as ga_operation from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import specialist_pool @@ -42,9 +35,8 @@ from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from .transports.base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO +from .transports.base import SpecialistPoolServiceTransport from .transports.grpc import SpecialistPoolServiceGrpcTransport -from .transports.grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport class SpecialistPoolServiceClientMeta(type): @@ -59,7 +51,6 @@ class SpecialistPoolServiceClientMeta(type): OrderedDict() ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport - _transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport def get_transport_class( cls, label: str = None, @@ -91,38 +82,8 @@ class SpecialistPoolServiceClient(metaclass=SpecialistPoolServiceClientMeta): manage Specialists and tasks on CrowdCompute console. """ - @staticmethod - def _get_default_mtls_endpoint(api_endpoint): - """Convert api endpoint to mTLS endpoint. - Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to - "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. - Args: - api_endpoint (Optional[str]): the api endpoint to convert. 
-        Returns:
-            str: converted mTLS api endpoint.
-        """
-        if not api_endpoint:
-            return api_endpoint
-
-        mtls_endpoint_re = re.compile(
-            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
-        )
-
-        m = mtls_endpoint_re.match(api_endpoint)
-        name, mtls, sandbox, googledomain = m.groups()
-        if mtls or not googledomain:
-            return api_endpoint
-
-        if sandbox:
-            return api_endpoint.replace(
-                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
-            )
-
-        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
-
-    DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
-    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
-        DEFAULT_ENDPOINT
+    DEFAULT_OPTIONS = ClientOptions.ClientOptions(
+        api_endpoint="aiplatform.googleapis.com"
     )

     @classmethod
@@ -152,22 +113,12 @@ def specialist_pool_path(project: str, location: str, specialist_pool: str,) ->
             project=project, location=location, specialist_pool=specialist_pool,
         )

-    @staticmethod
-    def parse_specialist_pool_path(path: str) -> Dict[str, str]:
-        """Parse a specialist_pool path into its component segments."""
-        m = re.match(
-            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/specialistPools/(?P<specialist_pool>.+?)$",
-            path,
-        )
-        return m.groupdict() if m else {}
-
     def __init__(
         self,
         *,
-        credentials: Optional[credentials.Credentials] = None,
-        transport: Union[str, SpecialistPoolServiceTransport, None] = None,
-        client_options: Optional[client_options_lib.ClientOptions] = None,
-        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        credentials: credentials.Credentials = None,
+        transport: Union[str, SpecialistPoolServiceTransport] = None,
+        client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS,
     ) -> None:
         """Instantiate the specialist pool service client.

@@ -180,102 +131,26 @@ def __init__(
             transport (Union[str, ~.SpecialistPoolServiceTransport]): The
                 transport to use. If set to None, a transport is
                 chosen automatically.
-            client_options (client_options_lib.ClientOptions): Custom options for the
-                client. It won't take effect if a ``transport`` instance is provided.
-                (1) The ``api_endpoint`` property can be used to override the
-                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
-                environment variable can also be used to override the endpoint:
-                "always" (always use the default mTLS endpoint), "never" (always
-                use the default regular endpoint) and "auto" (auto switch to the
-                default mTLS endpoint if client certificate is present, this is
-                the default value). However, the ``api_endpoint`` property takes
-                precedence if provided.
-                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
-                is "true", then the ``client_cert_source`` property can be used
-                to provide client certificate for mutual TLS transport. If
-                not provided, the default SSL client certificate will be used if
-                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
-                set, no client certificate will be used.
-            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
-                The client info used to send a user-agent string along with
-                API requests. If ``None``, then default info will be used.
-                Generally, you only need to set this if you're developing
-                your own client library.
-
-        Raises:
-            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
-                creation failed for any reason.
+            client_options (ClientOptions): Custom options for the client.
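# (Editor's note, not part of the patch: the `_get_default_mtls_endpoint`
# helper deleted above rewrote Google API hosts to their mutual-TLS variants
# and left everything else untouched, e.g.
#     "aiplatform.googleapis.com"         -> "aiplatform.mtls.googleapis.com"
#     "aiplatform.sandbox.googleapis.com" -> "aiplatform.mtls.sandbox.googleapis.com"
#     "example.com"                       -> "example.com"  (unchanged)
# This commit drops that conversion along with DEFAULT_MTLS_ENDPOINT.)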
""" if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - - # Create SSL credentials for mutual TLS if needed. - use_client_cert = bool( - util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) - ) - - ssl_credentials = None - is_mtls = False - if use_client_cert: - if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - is_mtls = True - else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - else: - use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_mtls_env == "never": - api_endpoint = self.DEFAULT_ENDPOINT - elif use_mtls_env == "always": - api_endpoint = self.DEFAULT_MTLS_ENDPOINT - elif use_mtls_env == "auto": - api_endpoint = ( - self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT - ) - else: - raise MutualTLSChannelError( - "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" - ) + client_options = ClientOptions.from_dict(client_options) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, SpecialistPoolServiceTransport): - # transport is a SpecialistPoolServiceTransport instance. - if credentials or client_options.credentials_file: + if credentials: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) - if client_options.scopes: - raise ValueError( - "When providing a transport instance, " - "provide its scopes directly." - ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, - quota_project_id=client_options.quota_project_id, - client_info=client_info, + host=client_options.api_endpoint or "aiplatform.googleapis.com", ) def create_specialist_pool( @@ -333,36 +208,28 @@ def create_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent, specialist_pool]) - if request is not None and has_flattened_params: + if request is not None and any([parent, specialist_pool]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.CreateSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, specialist_pool_service.CreateSpecialistPoolRequest): - request = specialist_pool_service.CreateSpecialistPoolRequest(request) + request = specialist_pool_service.CreateSpecialistPoolRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
+ # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if specialist_pool is not None: - request.specialist_pool = specialist_pool + if parent is not None: + request.parent = parent + if specialist_pool is not None: + request.specialist_pool = specialist_pool # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.create_specialist_pool] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + rpc = gapic_v1.method.wrap_method( + self._transport.create_specialist_pool, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -427,29 +294,27 @@ def get_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.GetSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, specialist_pool_service.GetSpecialistPoolRequest): - request = specialist_pool_service.GetSpecialistPoolRequest(request) + request = specialist_pool_service.GetSpecialistPoolRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_specialist_pool] + rpc = gapic_v1.method.wrap_method( + self._transport.get_specialist_pool, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -504,29 +369,27 @@ def list_specialist_pools( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([parent]) - if request is not None and has_flattened_params: + if request is not None and any([parent]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.ListSpecialistPoolsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, specialist_pool_service.ListSpecialistPoolsRequest): - request = specialist_pool_service.ListSpecialistPoolsRequest(request) + request = specialist_pool_service.ListSpecialistPoolsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
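# (Editor's sketch, not part of the patch: what the inline
# `gapic_v1.method.wrap_method` calls introduced by this commit do. The
# wrapper applies the default retry/timeout policy and attaches the
# x-goog-api-client user agent derived from `client_info` before invoking
# the bare transport method. Names below are illustrative:)
#
# wrapped = gapic_v1.method.wrap_method(
#     self._transport.list_specialist_pools,  # bare gRPC stub call
#     default_timeout=None,                   # no default deadline after this change
#     client_info=_client_info,
# )
# response = wrapped(request, retry=retry, timeout=timeout, metadata=metadata)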
- if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_specialist_pools] + rpc = gapic_v1.method.wrap_method( + self._transport.list_specialist_pools, + default_timeout=None, + client_info=_client_info, + ) # Certain fields should be provided within the metadata header; # add these here. @@ -540,7 +403,7 @@ def list_specialist_pools( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListSpecialistPoolsPager( - method=rpc, request=request, response=response, metadata=metadata, + method=rpc, request=request, response=response, ) # Done; return the response. @@ -600,34 +463,26 @@ def delete_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: + if request is not None and any([name]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.DeleteSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, specialist_pool_service.DeleteSpecialistPoolRequest): - request = specialist_pool_service.DeleteSpecialistPoolRequest(request) + request = specialist_pool_service.DeleteSpecialistPoolRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.delete_specialist_pool] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + rpc = gapic_v1.method.wrap_method( + self._transport.delete_specialist_pool, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -698,38 +553,28 @@ def update_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([specialist_pool, update_mask]) - if request is not None and has_flattened_params: + if request is not None and any([specialist_pool, update_mask]): raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - # Minor optimization to avoid making a copy if the user passes - # in a specialist_pool_service.UpdateSpecialistPoolRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. 
- if not isinstance(request, specialist_pool_service.UpdateSpecialistPoolRequest): - request = specialist_pool_service.UpdateSpecialistPoolRequest(request) + request = specialist_pool_service.UpdateSpecialistPoolRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if specialist_pool is not None: - request.specialist_pool = specialist_pool - if update_mask is not None: - request.update_mask = update_mask + if specialist_pool is not None: + request.specialist_pool = specialist_pool + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.update_specialist_pool] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata( - (("specialist_pool.name", request.specialist_pool.name),) - ), + rpc = gapic_v1.method.wrap_method( + self._transport.update_specialist_pool, + default_timeout=None, + client_info=_client_info, ) # Send the request. @@ -748,13 +593,13 @@ def update_specialist_pool( try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + _client_info = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + _client_info = gapic_v1.client_info.ClientInfo() __all__ = ("SpecialistPoolServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py index ff2d84ac74..012b76479b 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple +from typing import Any, Callable, Iterable from google.cloud.aiplatform_v1beta1.types import specialist_pool from google.cloud.aiplatform_v1beta1.types import specialist_pool_service @@ -41,11 +41,12 @@ class ListSpecialistPoolsPager: def __init__( self, - method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], + method: Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + specialist_pool_service.ListSpecialistPoolsResponse, + ], request: specialist_pool_service.ListSpecialistPoolsRequest, response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -56,13 +57,10 @@ def __init__( The initial request object. response (:class:`~.specialist_pool_service.ListSpecialistPoolsResponse`): The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
""" self._method = method self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) self._response = response - self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -72,7 +70,7 @@ def pages(self) -> Iterable[specialist_pool_service.ListSpecialistPoolsResponse] yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) + self._response = self._method(self._request) yield self._response def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: @@ -81,73 +79,3 @@ def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) - - -class ListSpecialistPoolsAsyncPager: - """A pager for iterating through ``list_specialist_pools`` requests. - - This class thinly wraps an initial - :class:`~.specialist_pool_service.ListSpecialistPoolsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``specialist_pools`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListSpecialistPools`` requests and continue to iterate - through the ``specialist_pools`` field on the - corresponding responses. - - All the usual :class:`~.specialist_pool_service.ListSpecialistPoolsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - - def __init__( - self, - method: Callable[ - ..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] - ], - request: specialist_pool_service.ListSpecialistPoolsRequest, - response: specialist_pool_service.ListSpecialistPoolsResponse, - *, - metadata: Sequence[Tuple[str, str]] = () - ): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (:class:`~.specialist_pool_service.ListSpecialistPoolsRequest`): - The initial request object. - response (:class:`~.specialist_pool_service.ListSpecialistPoolsResponse`): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
- """ - self._method = method - self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages( - self, - ) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - - def __aiter__(self) -> AsyncIterable[specialist_pool.SpecialistPool]: - async def async_generator(): - async for page in self.pages: - for response in page.specialist_pools: - yield response - - return async_generator() - - def __repr__(self) -> str: - return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py index 711f7fd1cc..c77d2d31a3 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py @@ -20,7 +20,6 @@ from .base import SpecialistPoolServiceTransport from .grpc import SpecialistPoolServiceGrpcTransport -from .grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport # Compile a registry of transports. @@ -28,11 +27,9 @@ OrderedDict() ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport -_transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport __all__ = ( "SpecialistPoolServiceTransport", "SpecialistPoolServiceGrpcTransport", - "SpecialistPoolServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py index f1af058030..effe36767e 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py @@ -17,12 +17,8 @@ import abc import typing -import pkg_resources -from google import auth # type: ignore -from google.api_core import exceptions # type: ignore -from google.api_core import gapic_v1 # type: ignore -from google.api_core import retry as retries # type: ignore +from google import auth from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -31,17 +27,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore -try: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=pkg_resources.get_distribution( - "google-cloud-aiplatform", - ).version, - ) -except pkg_resources.DistributionNotFound: - DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() - - -class SpecialistPoolServiceTransport(abc.ABC): +class SpecialistPoolServiceTransport(metaclass=abc.ABCMeta): """Abstract transport class for SpecialistPoolService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -51,11 +37,6 @@ def __init__( *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - credentials_file: typing.Optional[str] = None, - scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, 
- quota_project_id: typing.Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - **kwargs, ) -> None: """Instantiate the transport. @@ -66,17 +47,6 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scope (Optional[Sequence[str]]): A list of scopes. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -85,110 +55,58 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. - if credentials and credentials_file: - raise exceptions.DuplicateCredentialArgs( - "'credentials_file' and 'credentials' are mutually exclusive" - ) - - if credentials_file is not None: - credentials, _ = auth.load_credentials_from_file( - credentials_file, scopes=scopes, quota_project_id=quota_project_id - ) - - elif credentials is None: - credentials, _ = auth.default( - scopes=scopes, quota_project_id=quota_project_id - ) + if credentials is None: + credentials, _ = auth.default(scopes=self.AUTH_SCOPES) # Save the credentials. self._credentials = credentials - # Lifted into its own function so it can be stubbed out during tests. - self._prep_wrapped_messages(client_info) - - def _prep_wrapped_messages(self, client_info): - # Precompute the wrapped methods. 
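# (Editor's note: the `_wrapped_methods` table deleted below pre-wrapped
# every RPC once at construction time with a 5-second default timeout.
# After this commit, wrapping happens inline at each call site in client.py
# with `default_timeout=None`, so calls carry no default deadline.)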
- self._wrapped_methods = { - self.create_specialist_pool: gapic_v1.method.wrap_method( - self.create_specialist_pool, - default_timeout=5.0, - client_info=client_info, - ), - self.get_specialist_pool: gapic_v1.method.wrap_method( - self.get_specialist_pool, default_timeout=5.0, client_info=client_info, - ), - self.list_specialist_pools: gapic_v1.method.wrap_method( - self.list_specialist_pools, - default_timeout=5.0, - client_info=client_info, - ), - self.delete_specialist_pool: gapic_v1.method.wrap_method( - self.delete_specialist_pool, - default_timeout=5.0, - client_info=client_info, - ), - self.update_specialist_pool: gapic_v1.method.wrap_method( - self.update_specialist_pool, - default_timeout=5.0, - client_info=client_info, - ), - } - @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError() + raise NotImplementedError @property def create_specialist_pool( self, ) -> typing.Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + [specialist_pool_service.CreateSpecialistPoolRequest], operations.Operation ]: - raise NotImplementedError() + raise NotImplementedError @property def get_specialist_pool( self, ) -> typing.Callable[ [specialist_pool_service.GetSpecialistPoolRequest], - typing.Union[ - specialist_pool.SpecialistPool, - typing.Awaitable[specialist_pool.SpecialistPool], - ], + specialist_pool.SpecialistPool, ]: - raise NotImplementedError() + raise NotImplementedError @property def list_specialist_pools( self, ) -> typing.Callable[ [specialist_pool_service.ListSpecialistPoolsRequest], - typing.Union[ - specialist_pool_service.ListSpecialistPoolsResponse, - typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], - ], + specialist_pool_service.ListSpecialistPoolsResponse, ]: - raise NotImplementedError() + raise NotImplementedError @property def delete_specialist_pool( self, ) -> typing.Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + [specialist_pool_service.DeleteSpecialistPoolRequest], operations.Operation ]: - raise NotImplementedError() + raise NotImplementedError @property def update_specialist_pool( self, ) -> typing.Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + [specialist_pool_service.UpdateSpecialistPoolRequest], operations.Operation ]: - raise NotImplementedError() + raise NotImplementedError __all__ = ("SpecialistPoolServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index 0e7c862e81..92cff5699c 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -15,15 +15,11 @@ # limitations under the License. 
 #
-import warnings
-from typing import Callable, Dict, Optional, Sequence, Tuple
+from typing import Callable, Dict

 from google.api_core import grpc_helpers  # type: ignore
 from google.api_core import operations_v1  # type: ignore
-from google.api_core import gapic_v1  # type: ignore
-from google import auth  # type: ignore
 from google.auth import credentials  # type: ignore
-from google.auth.transport.grpc import SslCredentials  # type: ignore

 import grpc  # type: ignore

@@ -31,7 +27,7 @@
 from google.cloud.aiplatform_v1beta1.types import specialist_pool_service
 from google.longrunning import operations_pb2 as operations  # type: ignore

-from .base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO
+from .base import SpecialistPoolServiceTransport


 class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport):
@@ -52,21 +48,12 @@ class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport):
     top of HTTP/2); the ``grpcio`` package must be installed.
     """

-    _stubs: Dict[str, Callable]
-
     def __init__(
         self,
         *,
         host: str = "aiplatform.googleapis.com",
         credentials: credentials.Credentials = None,
-        credentials_file: str = None,
-        scopes: Sequence[str] = None,
-        channel: grpc.Channel = None,
-        api_mtls_endpoint: str = None,
-        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
-        ssl_channel_credentials: grpc.ChannelCredentials = None,
-        quota_project_id: Optional[str] = None,
-        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        channel: grpc.Channel = None
     ) -> None:
         """Instantiate the transport.

@@ -78,119 +65,28 @@ def __init__(
             are specified, the client will attempt to ascertain the
             credentials from the environment.
             This argument is ignored if ``channel`` is provided.
-            credentials_file (Optional[str]): A file with credentials that can
-                be loaded with :func:`google.auth.load_credentials_from_file`.
-                This argument is ignored if ``channel`` is provided.
-            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
-                ignored if ``channel`` is provided.
             channel (Optional[grpc.Channel]): A ``Channel`` instance through
                 which to make calls.
-            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
-                If provided, it overrides the ``host`` argument and tries to create
-                a mutual TLS channel with client SSL credentials from
-                ``client_cert_source`` or application default SSL credentials.
-            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
-                Deprecated. A callback to provide client SSL certificate bytes and
-                private key bytes, both in PEM format. It is ignored if
-                ``api_mtls_endpoint`` is None.
-            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
-                for grpc channel. It is ignored if ``channel`` is provided.
-            quota_project_id (Optional[str]): An optional project to use for billing
-                and quota.
-            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
-                The client info used to send a user-agent string along with
-                API requests. If ``None``, then default info will be used.
-                Generally, you only need to set this if you're developing
-                your own client library.
-
-        Raises:
-            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
-                creation failed for any reason.
-            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
-                and ``credentials_file`` are passed.
         """
+        # Sanity check: Ensure that channel and credentials are not both
+        # provided.
         if channel:
-            # Sanity check: Ensure that channel and credentials are not both
-            # provided.
             credentials = False

-            # If a channel was explicitly provided, set it.
-            self._grpc_channel = channel
-        elif api_mtls_endpoint:
-            warnings.warn(
-                "api_mtls_endpoint and client_cert_source are deprecated",
-                DeprecationWarning,
-            )
-
-            host = (
-                api_mtls_endpoint
-                if ":" in api_mtls_endpoint
-                else api_mtls_endpoint + ":443"
-            )
-
-            if credentials is None:
-                credentials, _ = auth.default(
-                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
-                )
-
-            # Create SSL credentials with client_cert_source or application
-            # default SSL credentials.
-            if client_cert_source:
-                cert, key = client_cert_source()
-                ssl_credentials = grpc.ssl_channel_credentials(
-                    certificate_chain=cert, private_key=key
-                )
-            else:
-                ssl_credentials = SslCredentials().ssl_credentials
-
-            # create a new channel. The provided one is ignored.
-            self._grpc_channel = type(self).create_channel(
-                host,
-                credentials=credentials,
-                credentials_file=credentials_file,
-                ssl_credentials=ssl_credentials,
-                scopes=scopes or self.AUTH_SCOPES,
-                quota_project_id=quota_project_id,
-            )
-        else:
-            host = host if ":" in host else host + ":443"
-
-            if credentials is None:
-                credentials, _ = auth.default(
-                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
-                )
-
-            # create a new channel. The provided one is ignored.
-            self._grpc_channel = type(self).create_channel(
-                host,
-                credentials=credentials,
-                credentials_file=credentials_file,
-                ssl_credentials=ssl_channel_credentials,
-                scopes=scopes or self.AUTH_SCOPES,
-                quota_project_id=quota_project_id,
-            )
-
+        # Run the base constructor.
+        super().__init__(host=host, credentials=credentials)
         self._stubs = {}  # type: Dict[str, Callable]

-        # Run the base constructor.
-        super().__init__(
-            host=host,
-            credentials=credentials,
-            credentials_file=credentials_file,
-            scopes=scopes or self.AUTH_SCOPES,
-            quota_project_id=quota_project_id,
-            client_info=client_info,
-        )
+        # If a channel was explicitly provided, set it.
+        if channel:
+            self._grpc_channel = channel

     @classmethod
     def create_channel(
         cls,
         host: str = "aiplatform.googleapis.com",
         credentials: credentials.Credentials = None,
-        credentials_file: str = None,
-        scopes: Optional[Sequence[str]] = None,
-        quota_project_id: Optional[str] = None,
-        **kwargs,
+        **kwargs
     ) -> grpc.Channel:
         """Create and return a gRPC channel object.
         Args:
@@ -200,31 +96,13 @@ def create_channel(
             credentials identify this application to the service. If
             none are specified, the client will attempt to ascertain
             the credentials from the environment.
-            credentials_file (Optional[str]): A file with credentials that can
-                be loaded with :func:`google.auth.load_credentials_from_file`.
-                This argument is mutually exclusive with credentials.
-            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
-                service. These are only used when credentials are not specified and
-                are passed to :func:`google.auth.default`.
-            quota_project_id (Optional[str]): An optional project to use for billing
-                and quota.
             kwargs (Optional[dict]): Keyword arguments, which are passed to the
                 channel creation.
         Returns:
             grpc.Channel: A gRPC channel object.
-
-        Raises:
-            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
-                and ``credentials_file`` are passed.
""" - scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, + host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs ) @property @@ -234,6 +112,13 @@ def grpc_channel(self) -> grpc.Channel: This property caches on the instance; repeated calls return the same channel. """ + # Sanity check: Only create a new channel if we do not already + # have one. + if not hasattr(self, "_grpc_channel"): + self._grpc_channel = self.create_channel( + self._host, credentials=self._credentials, + ) + # Return the channel from cache. return self._grpc_channel diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py deleted file mode 100644 index e2763c647f..0000000000 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py +++ /dev/null @@ -1,403 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple - -from google.api_core import gapic_v1 # type: ignore -from google.api_core import grpc_helpers_async # type: ignore -from google.api_core import operations_v1 # type: ignore -from google import auth # type: ignore -from google.auth import credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.cloud.aiplatform_v1beta1.types import specialist_pool -from google.cloud.aiplatform_v1beta1.types import specialist_pool_service -from google.longrunning import operations_pb2 as operations # type: ignore - -from .base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO -from .grpc import SpecialistPoolServiceGrpcTransport - - -class SpecialistPoolServiceGrpcAsyncIOTransport(SpecialistPoolServiceTransport): - """gRPC AsyncIO backend transport for SpecialistPoolService. - - A service for creating and managing Customer SpecialistPools. - When customers start Data Labeling jobs, they can reuse/create - Specialist Pools to bring their own Specialists to label the - data. Customers can add/remove Managers for the Specialist Pool - on Cloud console, then Managers will get email notifications to - manage Specialists and tasks on CrowdCompute console. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. 
- """ - - _grpc_channel: aio.Channel - _stubs: Dict[str, Callable] = {} - - @classmethod - def create_channel( - cls, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs, - ) -> aio.Channel: - """Create and return a gRPC AsyncIO channel object. - Args: - address (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - aio.Channel: A gRPC AsyncIO channel object. - """ - scopes = scopes or cls.AUTH_SCOPES - return grpc_helpers_async.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - **kwargs, - ) - - def __init__( - self, - *, - host: str = "aiplatform.googleapis.com", - credentials: credentials.Credentials = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: aio.Channel = None, - api_mtls_endpoint: str = None, - client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, - quota_project_id=None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or applicatin default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. 
- ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - if channel: - # Sanity check: Ensure that channel and credentials are not both - # provided. - credentials = False - - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - - host = ( - api_mtls_endpoint - if ":" in api_mtls_endpoint - else api_mtls_endpoint + ":443" - ) - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - ssl_credentials = SslCredentials().ssl_credentials - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - else: - host = host if ":" in host else host + ":443" - - if credentials is None: - credentials, _ = auth.default( - scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id - ) - - # create a new channel. The provided one is ignored. - self._grpc_channel = type(self).create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - ) - - # Run the base constructor. - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes or self.AUTH_SCOPES, - quota_project_id=quota_project_id, - client_info=client_info, - ) - - self._stubs = {} - - @property - def grpc_channel(self) -> aio.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. - """ - # Return the channel from cache. - return self._grpc_channel - - @property - def operations_client(self) -> operations_v1.OperationsAsyncClient: - """Create the client designed to process long-running operations. - - This property caches on the instance; repeated calls return the same - client. - """ - # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( - self.grpc_channel - ) - - # Return the client from cache. 
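# (Editor's note: the property above memoizes a single OperationsAsyncClient
# in `self.__dict__`, so repeated accesses reuse one operations client bound
# to the same gRPC channel.)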
- return self.__dict__["operations_client"] - - @property - def create_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], - Awaitable[operations.Operation], - ]: - r"""Return a callable for the create specialist pool method over gRPC. - - Creates a SpecialistPool. - - Returns: - Callable[[~.CreateSpecialistPoolRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "create_specialist_pool" not in self._stubs: - self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool", - request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["create_specialist_pool"] - - @property - def get_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.GetSpecialistPoolRequest], - Awaitable[specialist_pool.SpecialistPool], - ]: - r"""Return a callable for the get specialist pool method over gRPC. - - Gets a SpecialistPool. - - Returns: - Callable[[~.GetSpecialistPoolRequest], - Awaitable[~.SpecialistPool]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "get_specialist_pool" not in self._stubs: - self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool", - request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, - response_deserializer=specialist_pool.SpecialistPool.deserialize, - ) - return self._stubs["get_specialist_pool"] - - @property - def list_specialist_pools( - self, - ) -> Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], - ]: - r"""Return a callable for the list specialist pools method over gRPC. - - Lists SpecialistPools in a Location. - - Returns: - Callable[[~.ListSpecialistPoolsRequest], - Awaitable[~.ListSpecialistPoolsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "list_specialist_pools" not in self._stubs: - self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools", - request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, - response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, - ) - return self._stubs["list_specialist_pools"] - - @property - def delete_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], - Awaitable[operations.Operation], - ]: - r"""Return a callable for the delete specialist pool method over gRPC. - - Deletes a SpecialistPool as well as all Specialists - in the pool. 
- - Returns: - Callable[[~.DeleteSpecialistPoolRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "delete_specialist_pool" not in self._stubs: - self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool", - request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["delete_specialist_pool"] - - @property - def update_specialist_pool( - self, - ) -> Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], - Awaitable[operations.Operation], - ]: - r"""Return a callable for the update specialist pool method over gRPC. - - Updates a SpecialistPool. - - Returns: - Callable[[~.UpdateSpecialistPoolRequest], - Awaitable[~.Operation]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if "update_specialist_pool" not in self._stubs: - self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( - "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool", - request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, - response_deserializer=operations.Operation.FromString, - ) - return self._stubs["update_specialist_pool"] - - -__all__ = ("SpecialistPoolServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index e2995529d3..93508415dc 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -15,10 +15,7 @@ # limitations under the License. 
# -from .user_action_reference import UserActionReference -from .annotation import Annotation from .annotation_spec import AnnotationSpec -from .completion_stats import CompletionStats from .io import ( GcsSource, GcsDestination, @@ -26,6 +23,14 @@ BigQueryDestination, ContainerRegistryDestination, ) +from .dataset import ( + Dataset, + ImportDataConfig, + ExportDataConfig, +) +from .manual_batch_tuning_parameters import ManualBatchTuningParameters +from .completion_stats import CompletionStats +from .model_evaluation_slice import ModelEvaluationSlice from .machine_resources import ( MachineSpec, DedicatedResources, @@ -33,35 +38,8 @@ BatchDedicatedResources, ResourcesConsumed, ) -from .manual_batch_tuning_parameters import ManualBatchTuningParameters -from .batch_prediction_job import BatchPredictionJob -from .env_var import EnvVar -from .custom_job import ( - CustomJob, - CustomJobSpec, - WorkerPoolSpec, - ContainerSpec, - PythonPackageSpec, - Scheduling, -) -from .data_item import DataItem -from .specialist_pool import SpecialistPool -from .data_labeling_job import ( - DataLabelingJob, - ActiveLearningConfig, - SampleConfig, - TrainingConfig, -) -from .dataset import ( - Dataset, - ImportDataConfig, - ExportDataConfig, -) -from .operation import ( - GenericOperationMetadata, - DeleteOperationMetadata, -) from .deployed_model_ref import DeployedModelRef +from .env_var import EnvVar from .explanation_metadata import ExplanationMetadata from .explanation import ( Explanation, @@ -85,44 +63,22 @@ PredefinedSplit, TimestampSplit, ) -from .dataset_service import ( - CreateDatasetRequest, - CreateDatasetOperationMetadata, - GetDatasetRequest, - UpdateDatasetRequest, - ListDatasetsRequest, - ListDatasetsResponse, - DeleteDatasetRequest, - ImportDataRequest, - ImportDataResponse, - ImportDataOperationMetadata, - ExportDataRequest, - ExportDataResponse, - ExportDataOperationMetadata, - ListDataItemsRequest, - ListDataItemsResponse, - GetAnnotationSpecRequest, - ListAnnotationsRequest, - ListAnnotationsResponse, -) -from .endpoint import ( - Endpoint, - DeployedModel, +from .model_evaluation import ModelEvaluation +from .batch_prediction_job import BatchPredictionJob +from .custom_job import ( + CustomJob, + CustomJobSpec, + WorkerPoolSpec, + ContainerSpec, + PythonPackageSpec, + Scheduling, ) -from .endpoint_service import ( - CreateEndpointRequest, - CreateEndpointOperationMetadata, - GetEndpointRequest, - ListEndpointsRequest, - ListEndpointsResponse, - UpdateEndpointRequest, - DeleteEndpointRequest, - DeployModelRequest, - DeployModelResponse, - DeployModelOperationMetadata, - UndeployModelRequest, - UndeployModelResponse, - UndeployModelOperationMetadata, +from .specialist_pool import SpecialistPool +from .data_labeling_job import ( + DataLabelingJob, + ActiveLearningConfig, + SampleConfig, + TrainingConfig, ) from .study import ( Trial, @@ -156,8 +112,55 @@ DeleteBatchPredictionJobRequest, CancelBatchPredictionJobRequest, ) -from .model_evaluation import ModelEvaluation -from .model_evaluation_slice import ModelEvaluationSlice +from .user_action_reference import UserActionReference +from .annotation import Annotation +from .operation import ( + GenericOperationMetadata, + DeleteOperationMetadata, +) +from .endpoint import ( + Endpoint, + DeployedModel, +) +from .prediction_service import ( + PredictRequest, + PredictResponse, + ExplainRequest, + ExplainResponse, +) +from .endpoint_service import ( + CreateEndpointRequest, + CreateEndpointOperationMetadata, + GetEndpointRequest, + 
ListEndpointsRequest, + ListEndpointsResponse, + UpdateEndpointRequest, + DeleteEndpointRequest, + DeployModelRequest, + DeployModelResponse, + DeployModelOperationMetadata, + UndeployModelRequest, + UndeployModelResponse, + UndeployModelOperationMetadata, +) +from .pipeline_service import ( + CreateTrainingPipelineRequest, + GetTrainingPipelineRequest, + ListTrainingPipelinesRequest, + ListTrainingPipelinesResponse, + DeleteTrainingPipelineRequest, + CancelTrainingPipelineRequest, +) +from .specialist_pool_service import ( + CreateSpecialistPoolRequest, + CreateSpecialistPoolOperationMetadata, + GetSpecialistPoolRequest, + ListSpecialistPoolsRequest, + ListSpecialistPoolsResponse, + DeleteSpecialistPoolRequest, + UpdateSpecialistPoolRequest, + UpdateSpecialistPoolOperationMetadata, +) from .model_service import ( UploadModelRequest, UploadModelOperationMetadata, @@ -177,68 +180,49 @@ ListModelEvaluationSlicesRequest, ListModelEvaluationSlicesResponse, ) -from .pipeline_service import ( - CreateTrainingPipelineRequest, - GetTrainingPipelineRequest, - ListTrainingPipelinesRequest, - ListTrainingPipelinesResponse, - DeleteTrainingPipelineRequest, - CancelTrainingPipelineRequest, -) -from .prediction_service import ( - PredictRequest, - PredictResponse, - ExplainRequest, - ExplainResponse, -) -from .specialist_pool_service import ( - CreateSpecialistPoolRequest, - CreateSpecialistPoolOperationMetadata, - GetSpecialistPoolRequest, - ListSpecialistPoolsRequest, - ListSpecialistPoolsResponse, - DeleteSpecialistPoolRequest, - UpdateSpecialistPoolRequest, - UpdateSpecialistPoolOperationMetadata, +from .data_item import DataItem +from .dataset_service import ( + CreateDatasetRequest, + CreateDatasetOperationMetadata, + GetDatasetRequest, + UpdateDatasetRequest, + ListDatasetsRequest, + ListDatasetsResponse, + DeleteDatasetRequest, + ImportDataRequest, + ImportDataResponse, + ImportDataOperationMetadata, + ExportDataRequest, + ExportDataResponse, + ExportDataOperationMetadata, + ListDataItemsRequest, + ListDataItemsResponse, + GetAnnotationSpecRequest, + ListAnnotationsRequest, + ListAnnotationsResponse, ) __all__ = ( - "UserActionReference", - "Annotation", "AnnotationSpec", - "CompletionStats", "GcsSource", "GcsDestination", "BigQuerySource", "BigQueryDestination", "ContainerRegistryDestination", + "Dataset", + "ImportDataConfig", + "ExportDataConfig", + "ManualBatchTuningParameters", + "CompletionStats", + "ModelEvaluationSlice", "MachineSpec", "DedicatedResources", "AutomaticResources", "BatchDedicatedResources", "ResourcesConsumed", - "ManualBatchTuningParameters", - "BatchPredictionJob", - "EnvVar", - "CustomJob", - "CustomJobSpec", - "WorkerPoolSpec", - "ContainerSpec", - "PythonPackageSpec", - "Scheduling", - "DataItem", - "SpecialistPool", - "DataLabelingJob", - "ActiveLearningConfig", - "SampleConfig", - "TrainingConfig", - "Dataset", - "ImportDataConfig", - "ExportDataConfig", - "GenericOperationMetadata", - "DeleteOperationMetadata", "DeployedModelRef", + "EnvVar", "ExplanationMetadata", "Explanation", "ModelExplanation", @@ -256,39 +240,19 @@ "FilterSplit", "PredefinedSplit", "TimestampSplit", - "CreateDatasetRequest", - "CreateDatasetOperationMetadata", - "GetDatasetRequest", - "UpdateDatasetRequest", - "ListDatasetsRequest", - "ListDatasetsResponse", - "DeleteDatasetRequest", - "ImportDataRequest", - "ImportDataResponse", - "ImportDataOperationMetadata", - "ExportDataRequest", - "ExportDataResponse", - "ExportDataOperationMetadata", - "ListDataItemsRequest", - "ListDataItemsResponse", 
- "GetAnnotationSpecRequest", - "ListAnnotationsRequest", - "ListAnnotationsResponse", - "Endpoint", - "DeployedModel", - "CreateEndpointRequest", - "CreateEndpointOperationMetadata", - "GetEndpointRequest", - "ListEndpointsRequest", - "ListEndpointsResponse", - "UpdateEndpointRequest", - "DeleteEndpointRequest", - "DeployModelRequest", - "DeployModelResponse", - "DeployModelOperationMetadata", - "UndeployModelRequest", - "UndeployModelResponse", - "UndeployModelOperationMetadata", + "ModelEvaluation", + "BatchPredictionJob", + "CustomJob", + "CustomJobSpec", + "WorkerPoolSpec", + "ContainerSpec", + "PythonPackageSpec", + "Scheduling", + "SpecialistPool", + "DataLabelingJob", + "ActiveLearningConfig", + "SampleConfig", + "TrainingConfig", "Trial", "StudySpec", "Measurement", @@ -317,8 +281,43 @@ "ListBatchPredictionJobsResponse", "DeleteBatchPredictionJobRequest", "CancelBatchPredictionJobRequest", - "ModelEvaluation", - "ModelEvaluationSlice", + "UserActionReference", + "Annotation", + "GenericOperationMetadata", + "DeleteOperationMetadata", + "Endpoint", + "DeployedModel", + "PredictRequest", + "PredictResponse", + "ExplainRequest", + "ExplainResponse", + "CreateEndpointRequest", + "CreateEndpointOperationMetadata", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UpdateEndpointRequest", + "DeleteEndpointRequest", + "DeployModelRequest", + "DeployModelResponse", + "DeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UndeployModelOperationMetadata", + "CreateTrainingPipelineRequest", + "GetTrainingPipelineRequest", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "DeleteTrainingPipelineRequest", + "CancelTrainingPipelineRequest", + "CreateSpecialistPoolRequest", + "CreateSpecialistPoolOperationMetadata", + "GetSpecialistPoolRequest", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "DeleteSpecialistPoolRequest", + "UpdateSpecialistPoolRequest", + "UpdateSpecialistPoolOperationMetadata", "UploadModelRequest", "UploadModelOperationMetadata", "UploadModelResponse", @@ -336,22 +335,23 @@ "GetModelEvaluationSliceRequest", "ListModelEvaluationSlicesRequest", "ListModelEvaluationSlicesResponse", - "CreateTrainingPipelineRequest", - "GetTrainingPipelineRequest", - "ListTrainingPipelinesRequest", - "ListTrainingPipelinesResponse", - "DeleteTrainingPipelineRequest", - "CancelTrainingPipelineRequest", - "PredictRequest", - "PredictResponse", - "ExplainRequest", - "ExplainResponse", - "CreateSpecialistPoolRequest", - "CreateSpecialistPoolOperationMetadata", - "GetSpecialistPoolRequest", - "ListSpecialistPoolsRequest", - "ListSpecialistPoolsResponse", - "DeleteSpecialistPoolRequest", - "UpdateSpecialistPoolRequest", - "UpdateSpecialistPoolOperationMetadata", + "DataItem", + "CreateDatasetRequest", + "CreateDatasetOperationMetadata", + "GetDatasetRequest", + "UpdateDatasetRequest", + "ListDatasetsRequest", + "ListDatasetsResponse", + "DeleteDatasetRequest", + "ImportDataRequest", + "ImportDataResponse", + "ImportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportDataOperationMetadata", + "ListDataItemsRequest", + "ListDataItemsResponse", + "GetAnnotationSpecRequest", + "ListAnnotationsRequest", + "ListAnnotationsResponse", ) diff --git a/google/cloud/aiplatform_v1beta1/types/annotation.py b/google/cloud/aiplatform_v1beta1/types/annotation.py index b56d1f55e5..e6eca04509 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation.py +++ 
b/google/cloud/aiplatform_v1beta1/types/annotation.py @@ -88,21 +88,14 @@ class Annotation(proto.Message): """ name = proto.Field(proto.STRING, number=1) - payload_schema_uri = proto.Field(proto.STRING, number=2) - payload = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - etag = proto.Field(proto.STRING, number=8) - annotation_source = proto.Field( proto.MESSAGE, number=5, message=user_action_reference.UserActionReference, ) - labels = proto.MapField(proto.STRING, proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py index a5a4b3d489..4719fb12ce 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py @@ -52,13 +52,9 @@ class AnnotationSpec(proto.Message): """ name = proto.Field(proto.STRING, number=1) - display_name = proto.Field(proto.STRING, number=2) - create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - etag = proto.Field(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index 08e8763aaf..bc67ec8796 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -198,14 +198,10 @@ class InputConfig(proto.Message): [supported_input_storage_formats][google.cloud.aiplatform.v1beta1.Model.supported_input_storage_formats]. """ - gcs_source = proto.Field( - proto.MESSAGE, number=2, oneof="source", message=io.GcsSource, - ) - + gcs_source = proto.Field(proto.MESSAGE, number=2, message=io.GcsSource,) bigquery_source = proto.Field( - proto.MESSAGE, number=3, oneof="source", message=io.BigQuerySource, + proto.MESSAGE, number=3, message=io.BigQuerySource, ) - instances_format = proto.Field(proto.STRING, number=1) class OutputConfig(proto.Message): @@ -275,16 +271,11 @@ class OutputConfig(proto.Message): """ gcs_destination = proto.Field( - proto.MESSAGE, number=2, oneof="destination", message=io.GcsDestination, + proto.MESSAGE, number=2, message=io.GcsDestination, ) - bigquery_destination = proto.Field( - proto.MESSAGE, - number=3, - oneof="destination", - message=io.BigQueryDestination, + proto.MESSAGE, number=3, message=io.BigQueryDestination, ) - predictions_format = proto.Field(proto.STRING, number=1) class OutputInfo(proto.Message): @@ -302,64 +293,40 @@ class OutputInfo(proto.Message): prediction output is written. 
""" - gcs_output_directory = proto.Field( - proto.STRING, number=1, oneof="output_location" - ) - - bigquery_output_dataset = proto.Field( - proto.STRING, number=2, oneof="output_location" - ) + gcs_output_directory = proto.Field(proto.STRING, number=1) + bigquery_output_dataset = proto.Field(proto.STRING, number=2) name = proto.Field(proto.STRING, number=1) - display_name = proto.Field(proto.STRING, number=2) - model = proto.Field(proto.STRING, number=3) - input_config = proto.Field(proto.MESSAGE, number=4, message=InputConfig,) - model_parameters = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) - output_config = proto.Field(proto.MESSAGE, number=6, message=OutputConfig,) - dedicated_resources = proto.Field( proto.MESSAGE, number=7, message=machine_resources.BatchDedicatedResources, ) - manual_batch_tuning_parameters = proto.Field( proto.MESSAGE, number=8, message=gca_manual_batch_tuning_parameters.ManualBatchTuningParameters, ) - generate_explanation = proto.Field(proto.BOOL, number=23) - output_info = proto.Field(proto.MESSAGE, number=9, message=OutputInfo,) - state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) - error = proto.Field(proto.MESSAGE, number=11, message=status.Status,) - partial_failures = proto.RepeatedField( proto.MESSAGE, number=12, message=status.Status, ) - resources_consumed = proto.Field( proto.MESSAGE, number=13, message=machine_resources.ResourcesConsumed, ) - completion_stats = proto.Field( proto.MESSAGE, number=14, message=gca_completion_stats.CompletionStats, ) - create_time = proto.Field(proto.MESSAGE, number=15, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=16, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=17, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=18, message=timestamp.Timestamp,) - labels = proto.MapField(proto.STRING, proto.STRING, number=19) diff --git a/google/cloud/aiplatform_v1beta1/types/completion_stats.py b/google/cloud/aiplatform_v1beta1/types/completion_stats.py index 165be59634..22f1fa7975 100644 --- a/google/cloud/aiplatform_v1beta1/types/completion_stats.py +++ b/google/cloud/aiplatform_v1beta1/types/completion_stats.py @@ -46,9 +46,7 @@ class CompletionStats(proto.Message): """ successful_count = proto.Field(proto.INT64, number=1) - failed_count = proto.Field(proto.INT64, number=2) - incomplete_count = proto.Field(proto.INT64, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 0eefb760e7..870f12af71 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -86,23 +86,14 @@ class CustomJob(proto.Message): """ name = proto.Field(proto.STRING, number=1) - display_name = proto.Field(proto.STRING, number=2) - job_spec = proto.Field(proto.MESSAGE, number=4, message="CustomJobSpec",) - state = proto.Field(proto.ENUM, number=5, enum=job_state.JobState,) - create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) - labels = proto.MapField(proto.STRING, proto.STRING, number=11) @@ -151,9 +142,7 @@ class CustomJobSpec(proto.Message): 
worker_pool_specs = proto.RepeatedField( proto.MESSAGE, number=1, message="WorkerPoolSpec", ) - scheduling = proto.Field(proto.MESSAGE, number=3, message="Scheduling",) - base_output_directory = proto.Field( proto.MESSAGE, number=6, message=io.GcsDestination, ) @@ -175,18 +164,13 @@ class WorkerPoolSpec(proto.Message): use for this worker pool. """ - container_spec = proto.Field( - proto.MESSAGE, number=6, oneof="task", message="ContainerSpec", - ) - + container_spec = proto.Field(proto.MESSAGE, number=6, message="ContainerSpec",) python_package_spec = proto.Field( - proto.MESSAGE, number=7, oneof="task", message="PythonPackageSpec", + proto.MESSAGE, number=7, message="PythonPackageSpec", ) - machine_spec = proto.Field( proto.MESSAGE, number=1, message=machine_resources.MachineSpec, ) - replica_count = proto.Field(proto.INT64, number=2) @@ -208,9 +192,7 @@ class ContainerSpec(proto.Message): """ image_uri = proto.Field(proto.STRING, number=1) - command = proto.RepeatedField(proto.STRING, number=2) - args = proto.RepeatedField(proto.STRING, number=3) @@ -239,11 +221,8 @@ class PythonPackageSpec(proto.Message): """ executor_image_uri = proto.Field(proto.STRING, number=1) - package_uris = proto.RepeatedField(proto.STRING, number=2) - python_module = proto.Field(proto.STRING, number=3) - args = proto.RepeatedField(proto.STRING, number=4) @@ -263,7 +242,6 @@ class Scheduling(proto.Message): """ timeout = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) - restart_job_on_worker_restart = proto.Field(proto.BOOL, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/data_item.py b/google/cloud/aiplatform_v1beta1/types/data_item.py index e43a944d94..418e8cc739 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_item.py +++ b/google/cloud/aiplatform_v1beta1/types/data_item.py @@ -69,15 +69,10 @@ class DataItem(proto.Message): """ name = proto.Field(proto.STRING, number=1) - create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) - labels = proto.MapField(proto.STRING, proto.STRING, number=3) - payload = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) - etag = proto.Field(proto.STRING, number=7) diff --git a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py index 2e27dcaccf..9639bd070f 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py @@ -128,35 +128,20 @@ class DataLabelingJob(proto.Message): """ name = proto.Field(proto.STRING, number=1) - display_name = proto.Field(proto.STRING, number=2) - datasets = proto.RepeatedField(proto.STRING, number=3) - annotation_labels = proto.MapField(proto.STRING, proto.STRING, number=12) - labeler_count = proto.Field(proto.INT32, number=4) - instruction_uri = proto.Field(proto.STRING, number=5) - inputs_schema_uri = proto.Field(proto.STRING, number=6) - inputs = proto.Field(proto.MESSAGE, number=7, message=struct.Value,) - state = proto.Field(proto.ENUM, number=8, enum=job_state.JobState,) - labeling_progress = proto.Field(proto.INT32, number=13) - current_spend = proto.Field(proto.MESSAGE, number=14, message=money.Money,) - create_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) - labels = proto.MapField(proto.STRING, proto.STRING, number=11) - specialist_pools 
= proto.RepeatedField(proto.STRING, number=16) - active_learning_config = proto.Field( proto.MESSAGE, number=21, message="ActiveLearningConfig", ) @@ -187,16 +172,9 @@ class ActiveLearningConfig(proto.Message): select DataItems. """ - max_data_item_count = proto.Field( - proto.INT64, number=1, oneof="human_labeling_budget" - ) - - max_data_item_percentage = proto.Field( - proto.INT32, number=2, oneof="human_labeling_budget" - ) - + max_data_item_count = proto.Field(proto.INT64, number=1) + max_data_item_percentage = proto.Field(proto.INT32, number=2) sample_config = proto.Field(proto.MESSAGE, number=3, message="SampleConfig",) - training_config = proto.Field(proto.MESSAGE, number=4, message="TrainingConfig",) @@ -226,14 +204,8 @@ class SampleStrategy(proto.Enum): SAMPLE_STRATEGY_UNSPECIFIED = 0 UNCERTAINTY = 1 - initial_batch_sample_percentage = proto.Field( - proto.INT32, number=1, oneof="initial_batch_sample_size" - ) - - following_batch_sample_percentage = proto.Field( - proto.INT32, number=3, oneof="following_batch_sample_size" - ) - + initial_batch_sample_percentage = proto.Field(proto.INT32, number=1) + following_batch_sample_percentage = proto.Field(proto.INT32, number=3) sample_strategy = proto.Field(proto.ENUM, number=5, enum=SampleStrategy,) diff --git a/google/cloud/aiplatform_v1beta1/types/dataset.py b/google/cloud/aiplatform_v1beta1/types/dataset.py index 6acad441de..5f30adff8a 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset.py @@ -83,19 +83,12 @@ class Dataset(proto.Message): """ name = proto.Field(proto.STRING, number=1) - display_name = proto.Field(proto.STRING, number=2) - metadata_schema_uri = proto.Field(proto.STRING, number=3) - metadata = proto.Field(proto.MESSAGE, number=8, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) - etag = proto.Field(proto.STRING, number=6) - labels = proto.MapField(proto.STRING, proto.STRING, number=7) @@ -131,12 +124,8 @@ class ImportDataConfig(proto.Message): Object `__. """ - gcs_source = proto.Field( - proto.MESSAGE, number=1, oneof="source", message=io.GcsSource, - ) - + gcs_source = proto.Field(proto.MESSAGE, number=1, message=io.GcsSource,) data_item_labels = proto.MapField(proto.STRING, proto.STRING, number=2) - import_schema_uri = proto.Field(proto.STRING, number=4) @@ -164,10 +153,7 @@ class ExportDataConfig(proto.Message): [ListAnnotations][google.cloud.aiplatform.v1beta1.DatasetService.ListAnnotations]. 
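
The proto.MapField fields above (labels, annotation_labels, data_item_labels) behave like plain dicts once instantiated. A small sketch, assuming proto-plus is installed; ImportConfig is a hypothetical stand-in message:

import proto  # proto-plus

class ImportConfig(proto.Message):
    data_item_labels = proto.MapField(proto.STRING, proto.STRING, number=2)

cfg = ImportConfig(data_item_labels={"source": "manual"})
cfg.data_item_labels["quality"] = "gold"
assert dict(cfg.data_item_labels) == {"source": "manual", "quality": "gold"}
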
""" - gcs_destination = proto.Field( - proto.MESSAGE, number=1, oneof="destination", message=io.GcsDestination, - ) - + gcs_destination = proto.Field(proto.MESSAGE, number=1, message=io.GcsDestination,) annotations_filter = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py index 9f9e42d15e..2d4025b535 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py @@ -64,7 +64,6 @@ class CreateDatasetRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - dataset = proto.Field(proto.MESSAGE, number=2, message=gca_dataset.Dataset,) @@ -94,7 +93,6 @@ class GetDatasetRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -120,7 +118,6 @@ class UpdateDatasetRequest(proto.Message): """ dataset = proto.Field(proto.MESSAGE, number=1, message=gca_dataset.Dataset,) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -151,15 +148,10 @@ class ListDatasetsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) - order_by = proto.Field(proto.STRING, number=6) @@ -182,7 +174,6 @@ def raw_page(self): datasets = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_dataset.Dataset, ) - next_page_token = proto.Field(proto.STRING, number=2) @@ -215,7 +206,6 @@ class ImportDataRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) - import_configs = proto.RepeatedField( proto.MESSAGE, number=2, message=gca_dataset.ImportDataConfig, ) @@ -254,7 +244,6 @@ class ExportDataRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) - export_config = proto.Field( proto.MESSAGE, number=2, message=gca_dataset.ExportDataConfig, ) @@ -289,7 +278,6 @@ class ExportDataOperationMetadata(proto.Message): generic_metadata = proto.Field( proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) - gcs_output_directory = proto.Field(proto.STRING, number=2) @@ -317,15 +305,10 @@ class ListDataItemsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) - order_by = proto.Field(proto.STRING, number=6) @@ -348,7 +331,6 @@ def raw_page(self): data_items = proto.RepeatedField( proto.MESSAGE, number=1, message=data_item.DataItem, ) - next_page_token = proto.Field(proto.STRING, number=2) @@ -366,7 +348,6 @@ class GetAnnotationSpecRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) - read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -395,15 +376,10 @@ class ListAnnotationsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) - order_by = proto.Field(proto.STRING, number=6) @@ 
-426,7 +402,6 @@ def raw_page(self): annotations = proto.RepeatedField( proto.MESSAGE, number=1, message=annotation.Annotation, ) - next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py index b0ec7010a2..6a7f18850f 100644 --- a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py +++ b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py @@ -35,7 +35,6 @@ class DeployedModelRef(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) - deployed_model_id = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index 2ff4464afb..0f9eac501c 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -82,23 +82,15 @@ class Endpoint(proto.Message): """ name = proto.Field(proto.STRING, number=1) - display_name = proto.Field(proto.STRING, number=2) - description = proto.Field(proto.STRING, number=3) - deployed_models = proto.RepeatedField( proto.MESSAGE, number=4, message="DeployedModel", ) - traffic_split = proto.MapField(proto.STRING, proto.INT32, number=5) - etag = proto.Field(proto.STRING, number=6) - labels = proto.MapField(proto.STRING, proto.STRING, number=7) - create_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) @@ -167,33 +159,19 @@ class DeployedModel(proto.Message): """ dedicated_resources = proto.Field( - proto.MESSAGE, - number=7, - oneof="prediction_resources", - message=machine_resources.DedicatedResources, + proto.MESSAGE, number=7, message=machine_resources.DedicatedResources, ) - automatic_resources = proto.Field( - proto.MESSAGE, - number=8, - oneof="prediction_resources", - message=machine_resources.AutomaticResources, + proto.MESSAGE, number=8, message=machine_resources.AutomaticResources, ) - id = proto.Field(proto.STRING, number=1) - model = proto.Field(proto.STRING, number=2) - display_name = proto.Field(proto.STRING, number=3) - create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) - explanation_spec = proto.Field( proto.MESSAGE, number=9, message=explanation.ExplanationSpec, ) - enable_container_logging = proto.Field(proto.BOOL, number=12) - enable_access_logging = proto.Field(proto.BOOL, number=13) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py index 79ddf14a04..616cdf0eba 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py @@ -57,7 +57,6 @@ class CreateEndpointRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - endpoint = proto.Field(proto.MESSAGE, number=2, message=gca_endpoint.Endpoint,) @@ -136,13 +135,9 @@ class ListEndpointsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -166,7 +161,6 @@ def raw_page(self): endpoints = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_endpoint.Endpoint, ) - next_page_token = proto.Field(proto.STRING, number=2) @@ -184,7 +178,6 @@ class 
UpdateEndpointRequest(proto.Message): """ endpoint = proto.Field(proto.MESSAGE, number=1, message=gca_endpoint.Endpoint,) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -237,11 +230,9 @@ class DeployModelRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) - deployed_model = proto.Field( proto.MESSAGE, number=2, message=gca_endpoint.DeployedModel, ) - traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) @@ -298,9 +289,7 @@ class UndeployModelRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) - deployed_model_id = proto.Field(proto.STRING, number=2) - traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/env_var.py b/google/cloud/aiplatform_v1beta1/types/env_var.py index 207e8275cd..0c22313d63 100644 --- a/google/cloud/aiplatform_v1beta1/types/env_var.py +++ b/google/cloud/aiplatform_v1beta1/types/env_var.py @@ -43,7 +43,6 @@ class EnvVar(proto.Message): """ name = proto.Field(proto.STRING, number=1) - value = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation.py b/google/cloud/aiplatform_v1beta1/types/explanation.py index 1b4296b022..6abc83ce3a 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation.py @@ -203,15 +203,10 @@ class Attribution(proto.Message): """ baseline_output_value = proto.Field(proto.DOUBLE, number=1) - instance_output_value = proto.Field(proto.DOUBLE, number=2) - feature_attributions = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) - output_index = proto.RepeatedField(proto.INT32, number=4) - output_display_name = proto.Field(proto.STRING, number=5) - approximation_error = proto.Field(proto.DOUBLE, number=6) @@ -229,7 +224,6 @@ class ExplanationSpec(proto.Message): """ parameters = proto.Field(proto.MESSAGE, number=1, message="ExplanationParameters",) - metadata = proto.Field( proto.MESSAGE, number=2, message=explanation_metadata.ExplanationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index 257b9a9f99..12c7c4bc6f 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -119,21 +119,16 @@ class OutputMetadata(proto.Message): """ index_display_name_mapping = proto.Field( - proto.MESSAGE, number=1, oneof="display_name_mapping", message=struct.Value, - ) - - display_name_mapping_key = proto.Field( - proto.STRING, number=2, oneof="display_name_mapping" + proto.MESSAGE, number=1, message=struct.Value, ) + display_name_mapping_key = proto.Field(proto.STRING, number=2) inputs = proto.MapField( proto.STRING, proto.MESSAGE, number=1, message=InputMetadata, ) - outputs = proto.MapField( proto.STRING, proto.MESSAGE, number=2, message=OutputMetadata, ) - feature_attributions_schema_uri = proto.Field(proto.STRING, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py index 78af635e79..171e37ad09 100644 --- a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py +++ b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py @@ -96,35 +96,21 @@ class HyperparameterTuningJob(proto.Message): """ name = proto.Field(proto.STRING, number=1) - display_name = proto.Field(proto.STRING, 
number=2) - study_spec = proto.Field(proto.MESSAGE, number=4, message=study.StudySpec,) - max_trial_count = proto.Field(proto.INT32, number=5) - parallel_trial_count = proto.Field(proto.INT32, number=6) - max_failed_trial_count = proto.Field(proto.INT32, number=7) - trial_job_spec = proto.Field( proto.MESSAGE, number=8, message=custom_job.CustomJobSpec, ) - trials = proto.RepeatedField(proto.MESSAGE, number=9, message=study.Trial,) - state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) - create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) - error = proto.Field(proto.MESSAGE, number=15, message=status.Status,) - labels = proto.MapField(proto.STRING, proto.STRING, number=16) diff --git a/google/cloud/aiplatform_v1beta1/types/job_service.py b/google/cloud/aiplatform_v1beta1/types/job_service.py index 15cb490682..302d909b08 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_service.py +++ b/google/cloud/aiplatform_v1beta1/types/job_service.py @@ -76,7 +76,6 @@ class CreateCustomJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - custom_job = proto.Field(proto.MESSAGE, number=2, message=gca_custom_job.CustomJob,) @@ -133,13 +132,9 @@ class ListCustomJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -163,7 +158,6 @@ def raw_page(self): custom_jobs = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_custom_job.CustomJob, ) - next_page_token = proto.Field(proto.STRING, number=2) @@ -207,7 +201,6 @@ class CreateDataLabelingJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - data_labeling_job = proto.Field( proto.MESSAGE, number=2, message=gca_data_labeling_job.DataLabelingJob, ) @@ -268,15 +261,10 @@ class ListDataLabelingJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) - order_by = proto.Field(proto.STRING, number=6) @@ -299,7 +287,6 @@ def raw_page(self): data_labeling_jobs = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_data_labeling_job.DataLabelingJob, ) - next_page_token = proto.Field(proto.STRING, number=2) @@ -347,7 +334,6 @@ class CreateHyperparameterTuningJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - hyperparameter_tuning_job = proto.Field( proto.MESSAGE, number=2, @@ -410,13 +396,9 @@ class ListHyperparameterTuningJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -444,7 +426,6 @@ def raw_page(self): number=1, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) - next_page_token = proto.Field(proto.STRING, number=2) @@ 
-492,7 +473,6 @@ class CreateBatchPredictionJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - batch_prediction_job = proto.Field( proto.MESSAGE, number=2, message=gca_batch_prediction_job.BatchPredictionJob, ) @@ -553,13 +533,9 @@ class ListBatchPredictionJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -584,7 +560,6 @@ def raw_page(self): batch_prediction_jobs = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_batch_prediction_job.BatchPredictionJob, ) - next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py index a094e3896e..ff73b6db72 100644 --- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py @@ -89,11 +89,9 @@ class MachineSpec(proto.Message): """ machine_type = proto.Field(proto.STRING, number=1) - accelerator_type = proto.Field( proto.ENUM, number=2, enum=gca_accelerator_type.AcceleratorType, ) - accelerator_count = proto.Field(proto.INT32, number=3) @@ -131,9 +129,7 @@ class DedicatedResources(proto.Message): """ machine_spec = proto.Field(proto.MESSAGE, number=1, message=MachineSpec,) - min_replica_count = proto.Field(proto.INT32, number=2) - max_replica_count = proto.Field(proto.INT32, number=3) @@ -170,7 +166,6 @@ class AutomaticResources(proto.Message): """ min_replica_count = proto.Field(proto.INT32, number=1) - max_replica_count = proto.Field(proto.INT32, number=2) @@ -195,9 +190,7 @@ class BatchDedicatedResources(proto.Message): """ machine_spec = proto.Field(proto.MESSAGE, number=1, message=MachineSpec,) - starting_replica_count = proto.Field(proto.INT32, number=2) - max_replica_count = proto.Field(proto.INT32, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index 77f2901b0d..e2a8d2995b 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -272,55 +272,36 @@ class ExportableContent(proto.Enum): IMAGE = 2 id = proto.Field(proto.STRING, number=1) - exportable_contents = proto.RepeatedField( proto.ENUM, number=2, enum="Model.ExportFormat.ExportableContent", ) name = proto.Field(proto.STRING, number=1) - display_name = proto.Field(proto.STRING, number=2) - description = proto.Field(proto.STRING, number=3) - predict_schemata = proto.Field(proto.MESSAGE, number=4, message="PredictSchemata",) - metadata_schema_uri = proto.Field(proto.STRING, number=5) - metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) - supported_export_formats = proto.RepeatedField( proto.MESSAGE, number=20, message=ExportFormat, ) - training_pipeline = proto.Field(proto.STRING, number=7) - container_spec = proto.Field(proto.MESSAGE, number=9, message="ModelContainerSpec",) - artifact_uri = proto.Field(proto.STRING, number=26) - supported_deployment_resources_types = proto.RepeatedField( proto.ENUM, number=10, enum=DeploymentResourcesType, ) - supported_input_storage_formats = proto.RepeatedField(proto.STRING, number=11) - supported_output_storage_formats = proto.RepeatedField(proto.STRING, number=12) - create_time = proto.Field(proto.MESSAGE, number=13, 
message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) - deployed_models = proto.RepeatedField( proto.MESSAGE, number=15, message=deployed_model_ref.DeployedModelRef, ) - explanation_spec = proto.Field( proto.MESSAGE, number=23, message=explanation.ExplanationSpec, ) - etag = proto.Field(proto.STRING, number=16) - labels = proto.MapField(proto.STRING, proto.STRING, number=17) @@ -382,9 +363,7 @@ class PredictSchemata(proto.Message): """ instance_schema_uri = proto.Field(proto.STRING, number=1) - parameters_schema_uri = proto.Field(proto.STRING, number=2) - prediction_schema_uri = proto.Field(proto.STRING, number=3) @@ -448,17 +427,11 @@ class ModelContainerSpec(proto.Message): """ image_uri = proto.Field(proto.STRING, number=1) - command = proto.RepeatedField(proto.STRING, number=2) - args = proto.RepeatedField(proto.STRING, number=3) - env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,) - ports = proto.RepeatedField(proto.MESSAGE, number=5, message="Port",) - predict_route = proto.Field(proto.STRING, number=6) - health_route = proto.Field(proto.STRING, number=7) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py index 839ca3a191..5e54055a9e 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py @@ -68,15 +68,10 @@ class ModelEvaluation(proto.Message): """ name = proto.Field(proto.STRING, number=1) - metrics_schema_uri = proto.Field(proto.STRING, number=2) - metrics = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) - slice_dimensions = proto.RepeatedField(proto.STRING, number=5) - model_explanation = proto.Field( proto.MESSAGE, number=8, message=explanation.ModelExplanation, ) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py index 59e033047e..8f8125d51d 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py @@ -36,7 +36,7 @@ class ModelEvaluationSlice(proto.Message): name (str): Output only. The resource name of the ModelEvaluationSlice. - slice_ (~.model_evaluation_slice.ModelEvaluationSlice.Slice): + slice (~.model_evaluation_slice.ModelEvaluationSlice.Slice): Output only. The slice of the test data that is used to evaluate the Model. 
metrics_schema_uri (str): @@ -74,17 +74,12 @@ class Slice(proto.Message): """ dimension = proto.Field(proto.STRING, number=1) - value = proto.Field(proto.STRING, number=2) name = proto.Field(proto.STRING, number=1) - - slice_ = proto.Field(proto.MESSAGE, number=2, message=Slice,) - + slice = proto.Field(proto.MESSAGE, number=2, message=Slice,) metrics_schema_uri = proto.Field(proto.STRING, number=3) - metrics = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) - create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) diff --git a/google/cloud/aiplatform_v1beta1/types/model_service.py b/google/cloud/aiplatform_v1beta1/types/model_service.py index 8f581aadc3..4a5978045e 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_service.py +++ b/google/cloud/aiplatform_v1beta1/types/model_service.py @@ -64,7 +64,6 @@ class UploadModelRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - model = proto.Field(proto.MESSAGE, number=2, message=gca_model.Model,) @@ -134,13 +133,9 @@ class ListModelsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -162,7 +157,6 @@ def raw_page(self): return self models = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_model.Model,) - next_page_token = proto.Field(proto.STRING, number=2) @@ -183,7 +177,6 @@ class UpdateModelRequest(proto.Message): """ model = proto.Field(proto.MESSAGE, number=1, message=gca_model.Model,) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -244,17 +237,14 @@ class OutputConfig(proto.Message): """ export_format_id = proto.Field(proto.STRING, number=1) - artifact_destination = proto.Field( proto.MESSAGE, number=3, message=io.GcsDestination, ) - image_destination = proto.Field( proto.MESSAGE, number=4, message=io.ContainerRegistryDestination, ) name = proto.Field(proto.STRING, number=1) - output_config = proto.Field(proto.MESSAGE, number=2, message=OutputConfig,) @@ -289,13 +279,11 @@ class OutputInfo(proto.Message): """ artifact_output_uri = proto.Field(proto.STRING, number=2) - image_output_uri = proto.Field(proto.STRING, number=3) generic_metadata = proto.Field( proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) - output_info = proto.Field(proto.MESSAGE, number=2, message=OutputInfo,) @@ -344,13 +332,9 @@ class ListModelEvaluationsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -375,7 +359,6 @@ def raw_page(self): model_evaluations = proto.RepeatedField( proto.MESSAGE, number=1, message=model_evaluation.ModelEvaluation, ) - next_page_token = proto.Field(proto.STRING, number=2) @@ -421,13 +404,9 @@ class ListModelEvaluationSlicesRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -452,7 +431,6 @@ def raw_page(self): model_evaluation_slices = proto.RepeatedField( 
proto.MESSAGE, number=1, message=model_evaluation_slice.ModelEvaluationSlice, ) - next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/operation.py b/google/cloud/aiplatform_v1beta1/types/operation.py index 12b2150c35..3451fb9c8c 100644 --- a/google/cloud/aiplatform_v1beta1/types/operation.py +++ b/google/cloud/aiplatform_v1beta1/types/operation.py @@ -49,9 +49,7 @@ class GenericOperationMetadata(proto.Message): partial_failures = proto.RepeatedField( proto.MESSAGE, number=1, message=status.Status, ) - create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py index 33447d232f..089d6185c8 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py @@ -51,7 +51,6 @@ class CreateTrainingPipelineRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - training_pipeline = proto.Field( proto.MESSAGE, number=2, message=gca_training_pipeline.TrainingPipeline, ) @@ -109,13 +108,9 @@ class ListTrainingPipelinesRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - filter = proto.Field(proto.STRING, number=2) - page_size = proto.Field(proto.INT32, number=3) - page_token = proto.Field(proto.STRING, number=4) - read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -140,7 +135,6 @@ def raw_page(self): training_pipelines = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_training_pipeline.TrainingPipeline, ) - next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index cd2301395e..efff997a32 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -64,9 +64,7 @@ class PredictRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) - parameters = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) @@ -88,7 +86,6 @@ class PredictResponse(proto.Message): """ predictions = proto.RepeatedField(proto.MESSAGE, number=1, message=struct.Value,) - deployed_model_id = proto.Field(proto.STRING, number=2) @@ -127,11 +124,8 @@ class ExplainRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) - instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) - parameters = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) - deployed_model_id = proto.Field(proto.STRING, number=3) @@ -155,7 +149,6 @@ class ExplainResponse(proto.Message): explanations = proto.RepeatedField( proto.MESSAGE, number=1, message=explanation.Explanation, ) - deployed_model_id = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py index 4ac8c6a709..21ab5f9c47 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py @@ -55,13 +55,9 @@ class SpecialistPool(proto.Message): """ name = proto.Field(proto.STRING, number=1) - display_name = proto.Field(proto.STRING, 
number=2) - specialist_managers_count = proto.Field(proto.INT32, number=3) - specialist_manager_emails = proto.RepeatedField(proto.STRING, number=4) - pending_data_labeling_jobs = proto.RepeatedField(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py index 9cb5de8edd..fad6429b9a 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py @@ -52,7 +52,6 @@ class CreateSpecialistPoolRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - specialist_pool = proto.Field( proto.MESSAGE, number=2, message=gca_specialist_pool.SpecialistPool, ) @@ -109,11 +108,8 @@ class ListSpecialistPoolsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) - page_size = proto.Field(proto.INT32, number=2) - page_token = proto.Field(proto.STRING, number=3) - read_mask = proto.Field(proto.MESSAGE, number=4, message=field_mask.FieldMask,) @@ -136,7 +132,6 @@ def raw_page(self): specialist_pools = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, ) - next_page_token = proto.Field(proto.STRING, number=2) @@ -157,7 +152,6 @@ class DeleteSpecialistPoolRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) - force = proto.Field(proto.BOOL, number=2) @@ -177,7 +171,6 @@ class UpdateSpecialistPoolRequest(proto.Message): specialist_pool = proto.Field( proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, ) - update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -196,7 +189,6 @@ class UpdateSpecialistPoolOperationMetadata(proto.Message): """ specialist_pool = proto.Field(proto.STRING, number=1) - generic_metadata = proto.Field( proto.MESSAGE, number=2, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/study.py b/google/cloud/aiplatform_v1beta1/types/study.py index 5d053e7162..55f20127ef 100644 --- a/google/cloud/aiplatform_v1beta1/types/study.py +++ b/google/cloud/aiplatform_v1beta1/types/study.py @@ -81,21 +81,14 @@ class Parameter(proto.Message): """ parameter_id = proto.Field(proto.STRING, number=1) - value = proto.Field(proto.MESSAGE, number=2, message=struct.Value,) id = proto.Field(proto.STRING, number=2) - state = proto.Field(proto.ENUM, number=3, enum=State,) - parameters = proto.RepeatedField(proto.MESSAGE, number=4, message=Parameter,) - final_measurement = proto.Field(proto.MESSAGE, number=5, message="Measurement",) - start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) - custom_job = proto.Field(proto.STRING, number=11) @@ -137,7 +130,6 @@ class GoalType(proto.Enum): MINIMIZE = 2 metric_id = proto.Field(proto.STRING, number=1) - goal = proto.Field(proto.ENUM, number=2, enum="StudySpec.MetricSpec.GoalType",) class ParameterSpec(proto.Message): @@ -181,7 +173,6 @@ class DoubleValueSpec(proto.Message): """ min_value = proto.Field(proto.DOUBLE, number=1) - max_value = proto.Field(proto.DOUBLE, number=2) class IntegerValueSpec(proto.Message): @@ -197,7 +188,6 @@ class IntegerValueSpec(proto.Message): """ min_value = proto.Field(proto.INT64, number=1) - max_value = proto.Field(proto.INT64, number=2) class CategoricalValueSpec(proto.Message): @@ -226,43 +216,28 @@ class DiscreteValueSpec(proto.Message): values = 
proto.RepeatedField(proto.DOUBLE, number=1) double_value_spec = proto.Field( - proto.MESSAGE, - number=2, - oneof="parameter_value_spec", - message="StudySpec.ParameterSpec.DoubleValueSpec", + proto.MESSAGE, number=2, message="StudySpec.ParameterSpec.DoubleValueSpec", ) - integer_value_spec = proto.Field( - proto.MESSAGE, - number=3, - oneof="parameter_value_spec", - message="StudySpec.ParameterSpec.IntegerValueSpec", + proto.MESSAGE, number=3, message="StudySpec.ParameterSpec.IntegerValueSpec", ) - categorical_value_spec = proto.Field( proto.MESSAGE, number=4, - oneof="parameter_value_spec", message="StudySpec.ParameterSpec.CategoricalValueSpec", ) - discrete_value_spec = proto.Field( proto.MESSAGE, number=5, - oneof="parameter_value_spec", message="StudySpec.ParameterSpec.DiscreteValueSpec", ) - parameter_id = proto.Field(proto.STRING, number=1) - scale_type = proto.Field( proto.ENUM, number=6, enum="StudySpec.ParameterSpec.ScaleType", ) metrics = proto.RepeatedField(proto.MESSAGE, number=1, message=MetricSpec,) - parameters = proto.RepeatedField(proto.MESSAGE, number=2, message=ParameterSpec,) - algorithm = proto.Field(proto.ENUM, number=3, enum=Algorithm,) @@ -295,11 +270,9 @@ class Metric(proto.Message): """ metric_id = proto.Field(proto.STRING, number=1) - value = proto.Field(proto.DOUBLE, number=2) step_count = proto.Field(proto.INT64, number=2) - metrics = proto.RepeatedField(proto.MESSAGE, number=3, message=Metric,) diff --git a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py index 8a16001567..cd637cde9e 100644 --- a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py @@ -143,31 +143,18 @@ class TrainingPipeline(proto.Message): """ name = proto.Field(proto.STRING, number=1) - display_name = proto.Field(proto.STRING, number=2) - input_data_config = proto.Field(proto.MESSAGE, number=3, message="InputDataConfig",) - training_task_definition = proto.Field(proto.STRING, number=4) - training_task_inputs = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) - training_task_metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) - model_to_upload = proto.Field(proto.MESSAGE, number=7, message=model.Model,) - state = proto.Field(proto.ENUM, number=9, enum=pipeline_state.PipelineState,) - error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) - create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) - start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) - end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) - update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) - labels = proto.MapField(proto.STRING, proto.STRING, number=15) @@ -253,30 +240,13 @@ class InputDataConfig(proto.Message): [annotation_schema_uri][google.cloud.aiplatform.v1beta1.InputDataConfig.annotation_schema_uri]. 
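
The List*Request/List*Response pairs edited throughout this patch all follow the same page-token protocol: send page_token, read next_page_token from the response, and stop when it comes back empty. A minimal consumer sketch in which list_datasets is a hypothetical stand-in for the bound RPC:

def iterate_all(list_datasets, parent):
    """Yield every item from a paged List RPC, following next_page_token."""
    page_token = ""
    while True:
        response = list_datasets(parent=parent, page_token=page_token)
        yield from response.datasets
        page_token = response.next_page_token
        if not page_token:  # empty token means the last page was reached
            break
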
""" - fraction_split = proto.Field( - proto.MESSAGE, number=2, oneof="split", message="FractionSplit", - ) - - filter_split = proto.Field( - proto.MESSAGE, number=3, oneof="split", message="FilterSplit", - ) - - predefined_split = proto.Field( - proto.MESSAGE, number=4, oneof="split", message="PredefinedSplit", - ) - - timestamp_split = proto.Field( - proto.MESSAGE, number=5, oneof="split", message="TimestampSplit", - ) - - gcs_destination = proto.Field( - proto.MESSAGE, number=8, oneof="destination", message=io.GcsDestination, - ) - + fraction_split = proto.Field(proto.MESSAGE, number=2, message="FractionSplit",) + filter_split = proto.Field(proto.MESSAGE, number=3, message="FilterSplit",) + predefined_split = proto.Field(proto.MESSAGE, number=4, message="PredefinedSplit",) + timestamp_split = proto.Field(proto.MESSAGE, number=5, message="TimestampSplit",) + gcs_destination = proto.Field(proto.MESSAGE, number=8, message=io.GcsDestination,) dataset_id = proto.Field(proto.STRING, number=1) - annotations_filter = proto.Field(proto.STRING, number=6) - annotation_schema_uri = proto.Field(proto.STRING, number=9) @@ -302,9 +272,7 @@ class FractionSplit(proto.Message): """ training_fraction = proto.Field(proto.DOUBLE, number=1) - validation_fraction = proto.Field(proto.DOUBLE, number=2) - test_fraction = proto.Field(proto.DOUBLE, number=3) @@ -347,9 +315,7 @@ class FilterSplit(proto.Message): """ training_filter = proto.Field(proto.STRING, number=1) - validation_filter = proto.Field(proto.STRING, number=2) - test_filter = proto.Field(proto.STRING, number=3) @@ -400,11 +366,8 @@ class TimestampSplit(proto.Message): """ training_fraction = proto.Field(proto.DOUBLE, number=1) - validation_fraction = proto.Field(proto.DOUBLE, number=2) - test_fraction = proto.Field(proto.DOUBLE, number=3) - key = proto.Field(proto.STRING, number=4) diff --git a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py index 710e4a6d16..ce868edc27 100644 --- a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py +++ b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py @@ -44,10 +44,8 @@ class UserActionReference(proto.Message): "/google.cloud.aiplatform.v1alpha1.DatasetService.CreateDataset". 
""" - operation = proto.Field(proto.STRING, number=1, oneof="reference") - - data_labeling_job = proto.Field(proto.STRING, number=2, oneof="reference") - + operation = proto.Field(proto.STRING, number=1) + data_labeling_job = proto.Field(proto.STRING, number=2) method = proto.Field(proto.STRING, number=3) diff --git a/mypy.ini b/mypy.ini index 4505b48543..f23e6b533a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,3 @@ [mypy] -python_version = 3.6 +python_version = 3.5 namespace_packages = True diff --git a/noxfile.py b/noxfile.py index 0e83aab566..b203d2e28c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -26,9 +26,9 @@ BLACK_VERSION = "black==19.10b0" BLACK_PATHS = ["docs", "google", "tests", "noxfile.py", "setup.py"] -DEFAULT_PYTHON_VERSION = "3.8" -SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"] -UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8"] +DEFAULT_PYTHON_VERSION = "" +SYSTEM_TEST_PYTHON_VERSIONS = [] +UNIT_TEST_PYTHON_VERSIONS = [] @nox.session(python=DEFAULT_PYTHON_VERSION) diff --git a/synth.metadata b/synth.metadata index a410fba17c..3a478068c2 100644 --- a/synth.metadata +++ b/synth.metadata @@ -3,22 +3,15 @@ { "git": { "name": ".", - "remote": "sso://devrel/cloud/libraries/python/python-aiplatform", - "sha": "5457df5ec6aed011a3ffe23db8e9c7b6fb74548f" + "remote": "https://989977d47c15cdc28bf193434047b8ad35f6c849@github.com/dizcology/python-aiplatform", + "sha": "0610b3259bf3239d27ccb7c0cb57eee905317922" } }, { "git": { "name": "synthtool", "remote": "https://github.com/googleapis/synthtool.git", - "sha": "f3c04883d6c43261ff13db1f52d03a283be06871" - } - }, - { - "git": { - "name": "synthtool", - "remote": "https://github.com/googleapis/synthtool.git", - "sha": "f3c04883d6c43261ff13db1f52d03a283be06871" + "sha": "916c10e8581804df2b48a0f0457d848f3faa582e" } } ], @@ -29,7 +22,7 @@ "apiName": "aiplatform", "apiVersion": "v1beta1", "language": "python", - "generator": "bazel" + "generator": "gapic-generator-python" } } ] diff --git a/synth.py b/synth.py index 862a121f8d..5aad5b1a77 100644 --- a/synth.py +++ b/synth.py @@ -20,20 +20,59 @@ import synthtool.gcp as gcp from synthtool.languages import python -gapic = gcp.GAPICBazel() +# Use the microgenerator for now since we want to pin the generator version. 
+# gapic = gcp.GAPICBazel() +gapic = gcp.GAPICMicrogenerator() + common = gcp.CommonTemplates() # ---------------------------------------------------------------------------- # Generate AI Platform GAPIC layer # ---------------------------------------------------------------------------- +# library = gapic.py_library( +# service="aiplatform", +# version="v1beta1", +# bazel_target="//google/cloud/aiplatform/v1beta1:aiplatform-v1beta1-py", +# ) library = gapic.py_library( - service="aiplatform", - version="v1beta1", - bazel_target="//google/cloud/aiplatform/v1beta1:aiplatform-v1beta1-py", + 'aiplatform', + 'v1beta1', + generator_version='0.20' +) + +s.move( + library, + excludes=[ + "setup.py", + "README.rst", + "docs/index.rst", + ] +) + +# ---------------------------------------------------------------------------- +# Patch the library +# ---------------------------------------------------------------------------- + +# https://github.com/googleapis/gapic-generator-python/issues/336 +s.replace( + '**/client.py', + ' operation.from_gapic', + ' ga_operation.from_gapic' ) -s.move(library, excludes=["setup.py", "README.rst", "docs/index.rst"]) +s.replace( + '**/client.py', + 'client_options: ClientOptions = ', + 'client_options: ClientOptions.ClientOptions = ' +) + +# https://github.com/googleapis/gapic-generator-python/issues/413 +s.replace( + 'google/cloud/aiplatform_v1alpha1/services/prediction_service/client.py', + 'request.instances = instances', + 'request.instances.extend(instances)' +) # ---------------------------------------------------------------------------- # Add templated files @@ -44,4 +83,4 @@ templated_files, excludes=[".coveragerc"] ) # the microgenerator has a good coveragerc file -s.shell.run(["nox", "-s", "blacken"], hide_output=False) \ No newline at end of file +s.shell.run(["nox", "-s", "blacken"], hide_output=False) diff --git a/tests/unit/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/aiplatform_v1beta1/test_dataset_service.py new file mode 100644 index 0000000000..bdf4410884 --- /dev/null +++ b/tests/unit/aiplatform_v1beta1/test_dataset_service.py @@ -0,0 +1,1140 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from unittest import mock + +import grpc +import math +import pytest + +from google import auth +from google.api_core import client_options +from google.api_core import future +from google.api_core import operations_v1 +from google.auth import credentials +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + DatasetServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers +from google.cloud.aiplatform_v1beta1.services.dataset_service import transports +from google.cloud.aiplatform_v1beta1.types import annotation +from google.cloud.aiplatform_v1beta1.types import annotation_spec +from google.cloud.aiplatform_v1beta1.types import data_item +from google.cloud.aiplatform_v1beta1.types import dataset +from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1beta1.types import dataset_service +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def test_dataset_service_client_from_service_account_file(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = DatasetServiceClient.from_service_account_file("dummy/file/path.json") + assert client._transport._credentials == creds + + client = DatasetServiceClient.from_service_account_json("dummy/file/path.json") + assert client._transport._credentials == creds + + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def test_dataset_service_client_client_options(): + # Check the default options have their expected values. + assert ( + DatasetServiceClient.DEFAULT_OPTIONS.api_endpoint == "aiplatform.googleapis.com" + ) + + # Check that options can be customized. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.DatasetServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = DatasetServiceClient(client_options=options) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_dataset_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.DatasetServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = DatasetServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_create_dataset(transport: str = "grpc"): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.CreateDatasetRequest() + + # Mock the actual call within the gRPC stub, and fake the request. 
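    # Why __call__ is patched on the type: the transport exposes each RPC
    # as a grpc.UnaryUnaryMultiCallable, and Python resolves special
    # methods on the class rather than the instance, so patching the
    # instance alone would not intercept the invocation.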
+ with mock.patch.object(type(client._transport.create_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_dataset_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.create_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.create_dataset( + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + assert args[0].dataset == gca_dataset.Dataset(name="name_value") + + +def test_create_dataset_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_dataset( + dataset_service.CreateDatasetRequest(), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), + ) + + +def test_get_dataset(transport: str = "grpc"): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.GetDatasetRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + + response = client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, dataset.Dataset) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.etag == "etag_value" + + +def test_get_dataset_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetDatasetRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: + call.return_value = dataset.Dataset() + client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_dataset_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset.Dataset() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.get_dataset(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_dataset_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_dataset( + dataset_service.GetDatasetRequest(), name="name_value", + ) + + +def test_update_dataset(transport: str = "grpc"): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.UpdateDatasetRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.update_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + + response = client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_dataset.Dataset) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.etag == "etag_value" + + +def test_update_dataset_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.update_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_dataset.Dataset() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ response = client.update_dataset( + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].dataset == gca_dataset.Dataset(name="name_value") + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_dataset_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_dataset( + dataset_service.UpdateDatasetRequest(), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_list_datasets(transport: str = "grpc"): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.ListDatasetsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDatasetsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDatasetsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_datasets_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDatasetsRequest(parent="parent/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: + call.return_value = dataset_service.ListDatasetsResponse() + client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_datasets_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDatasetsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.list_datasets(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
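    # The flattened-argument convention under test: generated methods
    # accept either a full request message or individual keyword fields,
    # copying the keywords into a fresh request. Passing both at once
    # raises ValueError, which the *_flattened_error tests exercise.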
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_list_datasets_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_datasets( + dataset_service.ListDatasetsRequest(), parent="parent_value", + ) + + +def test_list_datasets_pager(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", + ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(),], next_page_token="ghi", + ), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(),], + ), + RuntimeError, + ) + results = [i for i in client.list_datasets(request={},)] + assert len(results) == 6 + assert all(isinstance(i, dataset.Dataset) for i in results) + + +def test_list_datasets_pages(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", + ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(),], next_page_token="ghi", + ), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(),], + ), + RuntimeError, + ) + pages = list(client.list_datasets(request={}).pages) + for page, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page.raw_page.next_page_token == token + + +def test_delete_dataset(transport: str = "grpc"): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.DeleteDatasetRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.delete_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_dataset_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client._transport.delete_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.delete_dataset(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_dataset_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_dataset( + dataset_service.DeleteDatasetRequest(), name="name_value", + ) + + +def test_import_data(transport: str = "grpc"): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.ImportDataRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.import_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.import_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_import_data_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.import_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.import_data( + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + assert args[0].import_configs == [ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ] + + +def test_import_data_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.import_data( + dataset_service.ImportDataRequest(), + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], + ) + + +def test_export_data(transport: str = "grpc"): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.ExportDataRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.export_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.export_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_export_data_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.export_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.export_data( + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + assert args[0].export_config == dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ) + + +def test_export_data_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.export_data( + dataset_service.ExportDataRequest(), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), + ) + + +def test_list_data_items(transport: str = "grpc"): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.ListDataItemsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDataItemsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. 
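    # Each entry in call.mock_calls unpacks as a (name, args, kwargs)
    # triple, so args[0] below is the request message the client actually
    # handed to the stub.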
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDataItemsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_data_items_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDataItemsRequest(parent="parent/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: + call.return_value = dataset_service.ListDataItemsResponse() + client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_data_items_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDataItemsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.list_data_items(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_list_data_items_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_data_items( + dataset_service.ListDataItemsRequest(), parent="parent_value", + ) + + +def test_list_data_items_pager(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + data_item.DataItem(), + ], + next_page_token="abc", + ), + dataset_service.ListDataItemsResponse( + data_items=[], next_page_token="def", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(),], next_page_token="ghi", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(), data_item.DataItem(),], + ), + RuntimeError, + ) + results = [i for i in client.list_data_items(request={},)] + assert len(results) == 6 + assert all(isinstance(i, data_item.DataItem) for i in results) + + +def test_list_data_items_pages(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. 
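    # Pager mechanics being simulated here: iterating .pages re-issues
    # the RPC with request.page_token set to the previous response's
    # next_page_token and stops once the token is empty, so the
    # RuntimeError sentinel ending the side_effect list below only fires
    # if an unexpected extra call slips through.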
+ with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + data_item.DataItem(), + ], + next_page_token="abc", + ), + dataset_service.ListDataItemsResponse( + data_items=[], next_page_token="def", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(),], next_page_token="ghi", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(), data_item.DataItem(),], + ), + RuntimeError, + ) + pages = list(client.list_data_items(request={}).pages) + for page, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page.raw_page.next_page_token == token + + +def test_get_annotation_spec(transport: str = "grpc"): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.GetAnnotationSpecRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = annotation_spec.AnnotationSpec( + name="name_value", display_name="display_name_value", etag="etag_value", + ) + + response = client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, annotation_spec.AnnotationSpec) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.etag == "etag_value" + + +def test_get_annotation_spec_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetAnnotationSpecRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_annotation_spec), "__call__" + ) as call: + call.return_value = annotation_spec.AnnotationSpec() + client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_annotation_spec_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = annotation_spec.AnnotationSpec() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ response = client.get_annotation_spec(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_annotation_spec_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_annotation_spec( + dataset_service.GetAnnotationSpecRequest(), name="name_value", + ) + + +def test_list_annotations(transport: str = "grpc"): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = dataset_service.ListAnnotationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_annotations), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListAnnotationsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_annotations_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListAnnotationsRequest(parent="parent/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_annotations), "__call__" + ) as call: + call.return_value = dataset_service.ListAnnotationsResponse() + client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_annotations_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_annotations), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListAnnotationsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.list_annotations(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_list_annotations_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_annotations( + dataset_service.ListAnnotationsRequest(), parent="parent_value", + ) + + +def test_list_annotations_pager(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_annotations), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + annotation.Annotation(), + ], + next_page_token="abc", + ), + dataset_service.ListAnnotationsResponse( + annotations=[], next_page_token="def", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(),], next_page_token="ghi", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(), annotation.Annotation(),], + ), + RuntimeError, + ) + results = [i for i in client.list_annotations(request={},)] + assert len(results) == 6 + assert all(isinstance(i, annotation.Annotation) for i in results) + + +def test_list_annotations_pages(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_annotations), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + annotation.Annotation(), + ], + next_page_token="abc", + ), + dataset_service.ListAnnotationsResponse( + annotations=[], next_page_token="def", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(),], next_page_token="ghi", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(), annotation.Annotation(),], + ), + RuntimeError, + ) + pages = list(client.list_annotations(request={}).pages) + for page, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page.raw_page.next_page_token == token + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = DatasetServiceClient(transport=transport) + assert client._transport is transport + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client._transport, transports.DatasetServiceGrpcTransport,) + + +def test_dataset_service_base_transport(): + # Instantiate the base transport. 
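    # The base transport is the abstract contract for the service:
    # concrete transports (gRPC here) override every RPC attribute, so
    # invoking one on the base class is expected to raise
    # NotImplementedError, as is touching the operations_client property
    # that backs long-running operations.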
+ transport = transports.DatasetServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_dataset", + "get_dataset", + "update_dataset", + "list_datasets", + "delete_dataset", + "import_data", + "export_data", + "list_data_items", + "get_annotation_spec", + "list_annotations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError. + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_dataset_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + DatasetServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",) + ) + + +def test_dataset_service_host_no_port(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + transport="grpc", + ) + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def test_dataset_service_host_with_port(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + transport="grpc", + ) + assert client._transport._host == "aiplatform.googleapis.com:8000" + + +def test_dataset_service_grpc_transport_channel(): + channel = grpc.insecure_channel("http://localhost/") + transport = transports.DatasetServiceGrpcTransport(channel=channel,) + assert transport.grpc_channel is channel + + +def test_dataset_service_grpc_lro_client(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + transport = client._transport + + # Ensure that we have an api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_dataset_path(): + project = "squid" + location = "clam" + dataset = "whelk" + + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) + actual = DatasetServiceClient.dataset_path(project, location, dataset) + assert expected == actual diff --git a/tests/unit/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/aiplatform_v1beta1/test_endpoint_service.py new file mode 100644 index 0000000000..4059cdb819 --- /dev/null +++ b/tests/unit/aiplatform_v1beta1/test_endpoint_service.py @@ -0,0 +1,771 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from unittest import mock + +import grpc +import math +import pytest + +from google import auth +from google.api_core import client_options +from google.api_core import future +from google.api_core import operations_v1 +from google.auth import credentials +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + EndpointServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers +from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports +from google.cloud.aiplatform_v1beta1.types import accelerator_type +from google.cloud.aiplatform_v1beta1.types import endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint_service +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import explanation_metadata +from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def test_endpoint_service_client_from_service_account_file(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = EndpointServiceClient.from_service_account_file("dummy/file/path.json") + assert client._transport._credentials == creds + + client = EndpointServiceClient.from_service_account_json("dummy/file/path.json") + assert client._transport._credentials == creds + + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def test_endpoint_service_client_client_options(): + # Check the default options have their expected values. + assert ( + EndpointServiceClient.DEFAULT_OPTIONS.api_endpoint + == "aiplatform.googleapis.com" + ) + + # Check that options can be customized. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.EndpointServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = EndpointServiceClient(client_options=options) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_endpoint_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.EndpointServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = EndpointServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_create_endpoint(transport: str = "grpc"): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+ request = endpoint_service.CreateEndpointRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.create_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_endpoint_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.create_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.create_endpoint( + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + + +def test_create_endpoint_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_endpoint( + endpoint_service.CreateEndpointRequest(), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), + ) + + +def test_get_endpoint(transport: str = "grpc"): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = endpoint_service.GetEndpointRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + + response = client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, endpoint.Endpoint) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.etag == "etag_value" + + +def test_get_endpoint_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = endpoint_service.GetEndpointRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call: + call.return_value = endpoint.Endpoint() + client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_endpoint_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint.Endpoint() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.get_endpoint(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_endpoint_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_endpoint( + endpoint_service.GetEndpointRequest(), name="name_value", + ) + + +def test_list_endpoints(transport: str = "grpc"): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = endpoint_service.ListEndpointsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint_service.ListEndpointsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListEndpointsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_endpoints_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.ListEndpointsRequest(parent="parent/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: + call.return_value = endpoint_service.ListEndpointsResponse() + client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. 
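    # The ("x-goog-request-params", ...) pair asserted below is the
    # routing header GAPIC clients derive from resource names in the
    # request so the API frontend can route the call.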
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_endpoints_flattened():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = endpoint_service.ListEndpointsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = client.list_endpoints(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0].parent == "parent_value"
+
+
+def test_list_endpoints_flattened_error():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_endpoints(
+            endpoint_service.ListEndpointsRequest(), parent="parent_value",
+        )
+
+
+def test_list_endpoints_pager():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                ],
+                next_page_token="abc",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[], next_page_token="def",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
+            ),
+            RuntimeError,
+        )
+        results = [i for i in client.list_endpoints(request={},)]
+        assert len(results) == 6
+        assert all(isinstance(i, endpoint.Endpoint) for i in results)
+
+
+def test_list_endpoints_pages():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call:
+        # Set the response to a series of pages.
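+        # (A tuple side_effect makes each stub invocation return the next
+        # item, with the trailing RuntimeError guarding against over-fetching;
+        # in a passing run the empty final next_page_token stops the pager
+        # before it is ever raised.)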
+ call.side_effect = ( + endpoint_service.ListEndpointsResponse( + endpoints=[ + endpoint.Endpoint(), + endpoint.Endpoint(), + endpoint.Endpoint(), + ], + next_page_token="abc", + ), + endpoint_service.ListEndpointsResponse( + endpoints=[], next_page_token="def", + ), + endpoint_service.ListEndpointsResponse( + endpoints=[endpoint.Endpoint(),], next_page_token="ghi", + ), + endpoint_service.ListEndpointsResponse( + endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], + ), + RuntimeError, + ) + pages = list(client.list_endpoints(request={}).pages) + for page, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page.raw_page.next_page_token == token + + +def test_update_endpoint(transport: str = "grpc"): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = endpoint_service.UpdateEndpointRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.update_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + + response = client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_endpoint.Endpoint) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.etag == "etag_value" + + +def test_update_endpoint_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.update_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_endpoint.Endpoint() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.update_endpoint( + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_endpoint_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
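+    # (The client cannot tell which value should win, so it raises ValueError
+    # locally, before any RPC is attempted.)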
+ with pytest.raises(ValueError): + client.update_endpoint( + endpoint_service.UpdateEndpointRequest(), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_delete_endpoint(transport: str = "grpc"): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = endpoint_service.DeleteEndpointRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.delete_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_endpoint_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.delete_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.delete_endpoint(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_endpoint_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_endpoint( + endpoint_service.DeleteEndpointRequest(), name="name_value", + ) + + +def test_deploy_model(transport: str = "grpc"): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = endpoint_service.DeployModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.deploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_deploy_model_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. 
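+    # (client._transport.deploy_model is a gRPC multicallable; patching
+    # __call__ on its type intercepts the RPC at the channel boundary, so
+    # these tests never generate network traffic.)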
+ with mock.patch.object(type(client._transport.deploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.deploy_model( + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].endpoint == "endpoint_value" + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) + assert args[0].traffic_split == {"key_value": 541} + + +def test_deploy_model_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.deploy_model( + endpoint_service.DeployModelRequest(), + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, + ) + + +def test_undeploy_model(transport: str = "grpc"): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = endpoint_service.UndeployModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.undeploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_undeploy_model_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.undeploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.undeploy_model( + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].endpoint == "endpoint_value" + assert args[0].deployed_model_id == "deployed_model_id_value" + assert args[0].traffic_split == {"key_value": 541} + + +def test_undeploy_model_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.undeploy_model( + endpoint_service.UndeployModelRequest(), + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = EndpointServiceClient(transport=transport) + assert client._transport is transport + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client._transport, transports.EndpointServiceGrpcTransport,) + + +def test_endpoint_service_base_transport(): + # Instantiate the base transport. + transport = transports.EndpointServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_endpoint", + "get_endpoint", + "list_endpoints", + "update_endpoint", + "delete_endpoint", + "deploy_model", + "undeploy_model", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_endpoint_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. 
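+    # (auth.default is patched so the test never resolves real Application
+    # Default Credentials; it only asserts that ADC is queried with the
+    # cloud-platform scope.)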
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        EndpointServiceClient()
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",)
+        )
+
+
+def test_endpoint_service_host_no_port():
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
+        transport="grpc",
+    )
+    assert client._transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_endpoint_service_host_with_port():
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
+        transport="grpc",
+    )
+    assert client._transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_endpoint_service_grpc_transport_channel():
+    channel = grpc.insecure_channel("http://localhost/")
+    transport = transports.EndpointServiceGrpcTransport(channel=channel,)
+    assert transport.grpc_channel is channel
+
+
+def test_endpoint_service_grpc_lro_client():
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client._transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property return the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_endpoint_path():
+    project = "squid"
+    location = "clam"
+    endpoint = "whelk"
+
+    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+        project=project, location=location, endpoint=endpoint,
+    )
+    actual = EndpointServiceClient.endpoint_path(project, location, endpoint)
+    assert expected == actual
diff --git a/tests/unit/aiplatform_v1beta1/test_job_service.py b/tests/unit/aiplatform_v1beta1/test_job_service.py
new file mode 100644
index 0000000000..9072ccd396
--- /dev/null
+++ b/tests/unit/aiplatform_v1beta1/test_job_service.py
@@ -0,0 +1,2118 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +from unittest import mock + +import grpc +import math +import pytest + +from google import auth +from google.api_core import client_options +from google.api_core import future +from google.api_core import operations_v1 +from google.auth import credentials +from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient +from google.cloud.aiplatform_v1beta1.services.job_service import pagers +from google.cloud.aiplatform_v1beta1.services.job_service import transports +from google.cloud.aiplatform_v1beta1.types import accelerator_type +from google.cloud.aiplatform_v1beta1.types import ( + accelerator_type as gca_accelerator_type, +) +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) +from google.cloud.aiplatform_v1beta1.types import completion_stats +from google.cloud.aiplatform_v1beta1.types import ( + completion_stats as gca_completion_stats, +) +from google.cloud.aiplatform_v1beta1.types import custom_job +from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job +from google.cloud.aiplatform_v1beta1.types import data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import job_service +from google.cloud.aiplatform_v1beta1.types import job_state +from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters +from google.cloud.aiplatform_v1beta1.types import ( + manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters, +) +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import study +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import any_pb2 as any # type: ignore +from google.protobuf import duration_pb2 as duration # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore +from google.type import money_pb2 as money # type: ignore + + +def test_job_service_client_from_service_account_file(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = JobServiceClient.from_service_account_file("dummy/file/path.json") + assert client._transport._credentials == creds + + client = JobServiceClient.from_service_account_json("dummy/file/path.json") + assert client._transport._credentials == creds + + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def test_job_service_client_client_options(): + # Check the default options have their expected values. + assert JobServiceClient.DEFAULT_OPTIONS.api_endpoint == "aiplatform.googleapis.com" + + # Check that options can be customized. 
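+    # (get_transport_class is patched below so the test can observe the host
+    # the client hands to the transport constructor without opening a channel.)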
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.JobServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = JobServiceClient(client_options=options) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_job_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.JobServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_create_custom_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CreateCustomJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_custom_job.CustomJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_custom_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_custom_job.CustomJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.create_custom_job( + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") + + +def test_create_custom_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_custom_job( + job_service.CreateCustomJobRequest(), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), + ) + + +def test_get_custom_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.GetCustomJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, custom_job.CustomJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_custom_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetCustomJobRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: + call.return_value = custom_job.CustomJob() + client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_custom_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = custom_job.CustomJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.get_custom_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_custom_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.get_custom_job(
+            job_service.GetCustomJobRequest(), name="name_value",
+        )
+
+
+def test_list_custom_jobs(transport: str = "grpc"):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = job_service.ListCustomJobsRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_custom_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListCustomJobsResponse(
+            next_page_token="next_page_token_value",
+        )
+
+        response = client.list_custom_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListCustomJobsPager)
+    assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_custom_jobs_field_headers():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = job_service.ListCustomJobsRequest(parent="parent/value",)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_custom_jobs), "__call__"
+    ) as call:
+        call.return_value = job_service.ListCustomJobsResponse()
+        client.list_custom_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_custom_jobs_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_custom_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListCustomJobsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = client.list_custom_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0].parent == "parent_value"
+
+
+def test_list_custom_jobs_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_custom_jobs(
+            job_service.ListCustomJobsRequest(), parent="parent_value",
+        )
+
+
+def test_list_custom_jobs_pager():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_custom_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
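+        # (Pages of 3, 0, 1 and 2 jobs follow, so iterating the pager should
+        # flatten them into the 6 results asserted below.)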
+        call.side_effect = (
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
+            ),
+            RuntimeError,
+        )
+        results = [i for i in client.list_custom_jobs(request={},)]
+        assert len(results) == 6
+        assert all(isinstance(i, custom_job.CustomJob) for i in results)
+
+
+def test_list_custom_jobs_pages():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_custom_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_custom_jobs(request={}).pages)
+        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page.raw_page.next_page_token == token
+
+
+def test_delete_custom_job(transport: str = "grpc"):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = job_service.DeleteCustomJobRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.delete_custom_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.delete_custom_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+def test_delete_custom_job_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.delete_custom_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = client.delete_custom_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_custom_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_custom_job( + job_service.DeleteCustomJobRequest(), name="name_value", + ) + + +def test_cancel_custom_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CancelCustomJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.cancel_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_custom_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.cancel_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.cancel_custom_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_cancel_custom_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_custom_job( + job_service.CancelCustomJobRequest(), name="name_value", + ) + + +def test_create_data_labeling_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CreateDataLabelingJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
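+        # (The scalar values below, e.g. labeler_count=1375, are arbitrary
+        # generated fixtures; the test only checks that they survive the
+        # round trip onto the response object.)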
+ call.return_value = gca_data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + + response = client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_data_labeling_job.DataLabelingJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.datasets == ["datasets_value"] + assert response.labeler_count == 1375 + assert response.instruction_uri == "instruction_uri_value" + assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.state == job_state.JobState.JOB_STATE_QUEUED + assert response.labeling_progress == 1810 + assert response.specialist_pools == ["specialist_pools_value"] + + +def test_create_data_labeling_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_data_labeling_job.DataLabelingJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.create_data_labeling_job( + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( + name="name_value" + ) + + +def test_create_data_labeling_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_data_labeling_job( + job_service.CreateDataLabelingJobRequest(), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + ) + + +def test_get_data_labeling_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.GetDataLabelingJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + + response = client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, data_labeling_job.DataLabelingJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.datasets == ["datasets_value"] + assert response.labeler_count == 1375 + assert response.instruction_uri == "instruction_uri_value" + assert response.inputs_schema_uri == "inputs_schema_uri_value" + assert response.state == job_state.JobState.JOB_STATE_QUEUED + assert response.labeling_progress == 1810 + assert response.specialist_pools == ["specialist_pools_value"] + + +def test_get_data_labeling_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetDataLabelingJobRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_data_labeling_job), "__call__" + ) as call: + call.return_value = data_labeling_job.DataLabelingJob() + client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_data_labeling_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = data_labeling_job.DataLabelingJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.get_data_labeling_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_data_labeling_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.get_data_labeling_job(
+            job_service.GetDataLabelingJobRequest(), name="name_value",
+        )
+
+
+def test_list_data_labeling_jobs(transport: str = "grpc"):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = job_service.ListDataLabelingJobsRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_data_labeling_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListDataLabelingJobsResponse(
+            next_page_token="next_page_token_value",
+        )
+
+        response = client.list_data_labeling_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListDataLabelingJobsPager)
+    assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_data_labeling_jobs_field_headers():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = job_service.ListDataLabelingJobsRequest(parent="parent/value",)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_data_labeling_jobs), "__call__"
+    ) as call:
+        call.return_value = job_service.ListDataLabelingJobsResponse()
+        client.list_data_labeling_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_data_labeling_jobs_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_data_labeling_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListDataLabelingJobsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = client.list_data_labeling_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0].parent == "parent_value"
+
+
+def test_list_data_labeling_jobs_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_data_labeling_jobs(
+            job_service.ListDataLabelingJobsRequest(), parent="parent_value",
+        )
+
+
+def test_list_data_labeling_jobs_pager():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
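+    # (Passing request={} lets the method coerce the dict into a
+    # ListDataLabelingJobsRequest; the first page is fetched by the call
+    # itself and later pages lazily, as iteration demands them.)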
+    with mock.patch.object(
+        type(client._transport.list_data_labeling_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[], next_page_token="def",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        results = [i for i in client.list_data_labeling_jobs(request={},)]
+        assert len(results) == 6
+        assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results)
+
+
+def test_list_data_labeling_jobs_pages():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_data_labeling_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[], next_page_token="def",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_data_labeling_jobs(request={}).pages)
+        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page.raw_page.next_page_token == token
+
+
+def test_delete_data_labeling_job(transport: str = "grpc"):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = job_service.DeleteDataLabelingJobRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.delete_data_labeling_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.delete_data_labeling_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+def test_delete_data_labeling_job_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.delete_data_labeling_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
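+        # (Delete is a long-running operation: the stub yields a raw
+        # operations_pb2.Operation, which the client wraps in the api-core
+        # future asserted in the non-flattened test above.)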
+ call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.delete_data_labeling_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_data_labeling_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_data_labeling_job( + job_service.DeleteDataLabelingJobRequest(), name="name_value", + ) + + +def test_cancel_data_labeling_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CancelDataLabelingJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.cancel_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_data_labeling_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.cancel_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.cancel_data_labeling_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_cancel_data_labeling_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_data_labeling_job( + job_service.CancelDataLabelingJobRequest(), name="name_value", + ) + + +def test_create_hyperparameter_tuning_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CreateHyperparameterTuningJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
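+        # (The plain and gca_-prefixed module aliases in the imports refer to
+        # the same proto messages; the generator imports some modules twice to
+        # avoid local name collisions between request- and response-side types.)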
+ call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.max_trial_count == 1609 + assert response.parallel_trial_count == 2128 + assert response.max_failed_trial_count == 2317 + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_hyperparameter_tuning_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.create_hyperparameter_tuning_job( + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + assert args[ + 0 + ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ) + + +def test_create_hyperparameter_tuning_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_hyperparameter_tuning_job( + job_service.CreateHyperparameterTuningJobRequest(), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), + ) + + +def test_get_hyperparameter_tuning_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.GetHyperparameterTuningJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.max_trial_count == 1609 + assert response.parallel_trial_count == 2128 + assert response.max_failed_trial_count == 2317 + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetHyperparameterTuningJobRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() + client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_hyperparameter_tuning_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.get_hyperparameter_tuning_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_hyperparameter_tuning_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_hyperparameter_tuning_job( + job_service.GetHyperparameterTuningJobRequest(), name="name_value", + ) + + +def test_list_hyperparameter_tuning_jobs(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.ListHyperparameterTuningJobsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
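+        # Only next_page_token needs to be populated here: the client wraps
+        # the raw response in a pager, which echoes the token (asserted below).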
+        call.return_value = job_service.ListHyperparameterTuningJobsResponse(
+            next_page_token="next_page_token_value",
+        )
+
+        response = client.list_hyperparameter_tuning_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListHyperparameterTuningJobsPager)
+    assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_hyperparameter_tuning_jobs_field_headers():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = job_service.ListHyperparameterTuningJobsRequest(parent="parent/value",)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_hyperparameter_tuning_jobs), "__call__"
+    ) as call:
+        call.return_value = job_service.ListHyperparameterTuningJobsResponse()
+        client.list_hyperparameter_tuning_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_hyperparameter_tuning_jobs_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_hyperparameter_tuning_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListHyperparameterTuningJobsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = client.list_hyperparameter_tuning_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0].parent == "parent_value"
+
+
+def test_list_hyperparameter_tuning_jobs_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_hyperparameter_tuning_jobs(
+            job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value",
+        )
+
+
+def test_list_hyperparameter_tuning_jobs_pager():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_hyperparameter_tuning_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
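+        # The trailing RuntimeError acts as a sentinel: if the pager asks
+        # for more pages than are provided, the mock raises and fails fast.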
+        call.side_effect = (
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[], next_page_token="def",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+                next_page_token="ghi",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        results = [i for i in client.list_hyperparameter_tuning_jobs(request={},)]
+        assert len(results) == 6
+        assert all(
+            isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob)
+            for i in results
+        )
+
+
+def test_list_hyperparameter_tuning_jobs_pages():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_hyperparameter_tuning_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[], next_page_token="def",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+                next_page_token="ghi",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages)
+        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page.raw_page.next_page_token == token
+
+
+def test_delete_hyperparameter_tuning_job(transport: str = "grpc"):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = job_service.DeleteHyperparameterTuningJobRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.delete_hyperparameter_tuning_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.delete_hyperparameter_tuning_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
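+    # Delete is a long-running operation, so the client returns a
+    # future-compatible wrapper around operations_pb2.Operation.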
+ assert isinstance(response, future.Future) + + +def test_delete_hyperparameter_tuning_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.delete_hyperparameter_tuning_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_hyperparameter_tuning_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_hyperparameter_tuning_job( + job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", + ) + + +def test_cancel_hyperparameter_tuning_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CancelHyperparameterTuningJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_hyperparameter_tuning_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.cancel_hyperparameter_tuning_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_cancel_hyperparameter_tuning_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
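+    # The two calling styles are mutually exclusive, and the generated
+    # client enforces this before any RPC is attempted.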
+ with pytest.raises(ValueError): + client.cancel_hyperparameter_tuning_job( + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", + ) + + +def test_create_batch_prediction_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CreateBatchPredictionJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.model == "model_value" + + assert response.generate_explanation is True + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_batch_prediction_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_batch_prediction_job.BatchPredictionJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.create_batch_prediction_job( + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) + + +def test_create_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_batch_prediction_job( + job_service.CreateBatchPredictionJobRequest(), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), + ) + + +def test_get_batch_prediction_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+ request = job_service.GetBatchPredictionJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, batch_prediction_job.BatchPredictionJob) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.model == "model_value" + + assert response.generate_explanation is True + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_batch_prediction_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetBatchPredictionJobRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_batch_prediction_job), "__call__" + ) as call: + call.return_value = batch_prediction_job.BatchPredictionJob() + client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_batch_prediction_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = batch_prediction_job.BatchPredictionJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.get_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_batch_prediction_job( + job_service.GetBatchPredictionJobRequest(), name="name_value", + ) + + +def test_list_batch_prediction_jobs(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+    request = job_service.ListBatchPredictionJobsRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListBatchPredictionJobsResponse(
+            next_page_token="next_page_token_value",
+        )
+
+        response = client.list_batch_prediction_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListBatchPredictionJobsPager)
+    assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_batch_prediction_jobs_field_headers():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = job_service.ListBatchPredictionJobsRequest(parent="parent/value",)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        call.return_value = job_service.ListBatchPredictionJobsResponse()
+        client.list_batch_prediction_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_batch_prediction_jobs_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListBatchPredictionJobsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = client.list_batch_prediction_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0].parent == "parent_value"
+
+
+def test_list_batch_prediction_jobs_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_batch_prediction_jobs(
+            job_service.ListBatchPredictionJobsRequest(), parent="parent_value",
+        )
+
+
+def test_list_batch_prediction_jobs_pager():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
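+        # Four pages (3, 0, 1 and 2 jobs) exercise a non-trivial walk: an
+        # empty middle page and a final page with no next_page_token.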
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[], next_page_token="def",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        results = [i for i in client.list_batch_prediction_jobs(request={},)]
+        assert len(results) == 6
+        assert all(
+            isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results
+        )
+
+
+def test_list_batch_prediction_jobs_pages():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[], next_page_token="def",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_batch_prediction_jobs(request={}).pages)
+        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page.raw_page.next_page_token == token
+
+
+def test_delete_batch_prediction_job(transport: str = "grpc"):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = job_service.DeleteBatchPredictionJobRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.delete_batch_prediction_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.delete_batch_prediction_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+def test_delete_batch_prediction_job_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.delete_batch_prediction_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
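+        # "operations/op" is just a placeholder LRO resource name; only the
+        # Operation wrapper type matters for the flattened-call assertions.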
+ call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.delete_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_batch_prediction_job( + job_service.DeleteBatchPredictionJobRequest(), name="name_value", + ) + + +def test_cancel_batch_prediction_job(transport: str = "grpc"): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = job_service.CancelBatchPredictionJobRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.cancel_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_batch_prediction_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.cancel_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.cancel_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_cancel_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_batch_prediction_job( + job_service.CancelBatchPredictionJobRequest(), name="name_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. 
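+    # When a transport is supplied directly, credentials belong on the
+    # transport itself (contrast with test_credentials_transport_error).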
+    transport = transports.JobServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    client = JobServiceClient(transport=transport)
+    assert client._transport is transport
+
+
+def test_transport_grpc_default():
+    # A client should use the gRPC transport by default.
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+    assert isinstance(client._transport, transports.JobServiceGrpcTransport,)
+
+
+def test_job_service_base_transport():
+    # Instantiate the base transport.
+    transport = transports.JobServiceTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Every method on the transport should just blindly
+    # raise NotImplementedError.
+    methods = (
+        "create_custom_job",
+        "get_custom_job",
+        "list_custom_jobs",
+        "delete_custom_job",
+        "cancel_custom_job",
+        "create_data_labeling_job",
+        "get_data_labeling_job",
+        "list_data_labeling_jobs",
+        "delete_data_labeling_job",
+        "cancel_data_labeling_job",
+        "create_hyperparameter_tuning_job",
+        "get_hyperparameter_tuning_job",
+        "list_hyperparameter_tuning_jobs",
+        "delete_hyperparameter_tuning_job",
+        "cancel_hyperparameter_tuning_job",
+        "create_batch_prediction_job",
+        "get_batch_prediction_job",
+        "list_batch_prediction_jobs",
+        "delete_batch_prediction_job",
+        "cancel_batch_prediction_job",
+    )
+    for method in methods:
+        with pytest.raises(NotImplementedError):
+            getattr(transport, method)(request=object())
+
+    # Additionally, the LRO client (a property) should
+    # also raise NotImplementedError.
+    with pytest.raises(NotImplementedError):
+        transport.operations_client
+
+
+def test_job_service_auth_adc():
+    # If no credentials are provided, we should use ADC credentials.
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        JobServiceClient()
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",)
+        )
+
+
+def test_job_service_host_no_port():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
+        transport="grpc",
+    )
+    assert client._transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_job_service_host_with_port():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
+        transport="grpc",
+    )
+    assert client._transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_job_service_grpc_transport_channel():
+    channel = grpc.insecure_channel("http://localhost/")
+    transport = transports.JobServiceGrpcTransport(channel=channel,)
+    assert transport.grpc_channel is channel
+
+
+def test_job_service_grpc_lro_client():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client._transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property return the exact same object.
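+    # The property caches its OperationsClient, so identity (not just
+    # equality) is expected to hold across accesses.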
+ assert transport.operations_client is transport.operations_client + + +def test_hyperparameter_tuning_job_path(): + project = "squid" + location = "clam" + hyperparameter_tuning_job = "whelk" + + expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) + actual = JobServiceClient.hyperparameter_tuning_job_path( + project, location, hyperparameter_tuning_job + ) + assert expected == actual + + +def test_custom_job_path(): + project = "squid" + location = "clam" + custom_job = "whelk" + + expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) + actual = JobServiceClient.custom_job_path(project, location, custom_job) + assert expected == actual + + +def test_data_labeling_job_path(): + project = "squid" + location = "clam" + data_labeling_job = "whelk" + + expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, location=location, data_labeling_job=data_labeling_job, + ) + actual = JobServiceClient.data_labeling_job_path( + project, location, data_labeling_job + ) + assert expected == actual + + +def test_batch_prediction_job_path(): + project = "squid" + location = "clam" + batch_prediction_job = "whelk" + + expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( + project=project, location=location, batch_prediction_job=batch_prediction_job, + ) + actual = JobServiceClient.batch_prediction_job_path( + project, location, batch_prediction_job + ) + assert expected == actual diff --git a/tests/unit/aiplatform_v1beta1/test_model_service.py b/tests/unit/aiplatform_v1beta1/test_model_service.py new file mode 100644 index 0000000000..854272f8a5 --- /dev/null +++ b/tests/unit/aiplatform_v1beta1/test_model_service.py @@ -0,0 +1,1223 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from unittest import mock + +import grpc +import math +import pytest + +from google import auth +from google.api_core import client_options +from google.api_core import future +from google.api_core import operations_v1 +from google.auth import credentials +from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient +from google.cloud.aiplatform_v1beta1.services.model_service import pagers +from google.cloud.aiplatform_v1beta1.services.model_service import transports +from google.cloud.aiplatform_v1beta1.types import deployed_model_ref +from google.cloud.aiplatform_v1beta1.types import env_var +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import explanation_metadata +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import model as gca_model +from google.cloud.aiplatform_v1beta1.types import model_evaluation +from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice +from google.cloud.aiplatform_v1beta1.types import model_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def test_model_service_client_from_service_account_file(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = ModelServiceClient.from_service_account_file("dummy/file/path.json") + assert client._transport._credentials == creds + + client = ModelServiceClient.from_service_account_json("dummy/file/path.json") + assert client._transport._credentials == creds + + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def test_model_service_client_client_options(): + # Check the default options have their expected values. + assert ( + ModelServiceClient.DEFAULT_OPTIONS.api_endpoint == "aiplatform.googleapis.com" + ) + + # Check that options can be customized. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.ModelServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = ModelServiceClient(client_options=options) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_model_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.ModelServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_upload_model(transport: str = "grpc"): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
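+    # upload_model is a long-running operation, which is why the mocked
+    # stub below returns an operations_pb2.Operation rather than a Model.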
+ request = model_service.UploadModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.upload_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_upload_model_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.upload_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.upload_model( + parent="parent_value", model=gca_model.Model(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + assert args[0].model == gca_model.Model(name="name_value") + + +def test_upload_model_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.upload_model( + model_service.UploadModelRequest(), + parent="parent_value", + model=gca_model.Model(name="name_value"), + ) + + +def test_get_model(transport: str = "grpc"): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.GetModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=["supported_input_storage_formats_value"], + supported_output_storage_formats=["supported_output_storage_formats_value"], + etag="etag_value", + ) + + response = client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
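+    # Repeated proto fields compare equal to plain Python lists, which is
+    # why the assertions below can use list literals directly.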
+ assert isinstance(response, model.Model) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.training_pipeline == "training_pipeline_value" + assert response.artifact_uri == "artifact_uri_value" + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] + assert response.etag == "etag_value" + + +def test_get_model_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.get_model), "__call__") as call: + call.return_value = model.Model() + client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_model_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.get_model(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_model_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_model( + model_service.GetModelRequest(), name="name_value", + ) + + +def test_list_models(transport: str = "grpc"): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.ListModelsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_models(request) + + # Establish that the underlying gRPC stub method was called. 
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListModelsPager)
+    assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_models_field_headers():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.ListModelsRequest(parent="parent/value",)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client._transport.list_models), "__call__") as call:
+        call.return_value = model_service.ListModelsResponse()
+        client.list_models(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_models_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client._transport.list_models), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = client.list_models(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0].parent == "parent_value"
+
+
+def test_list_models_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_models(
+            model_service.ListModelsRequest(), parent="parent_value",
+        )
+
+
+def test_list_models_pager():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client._transport.list_models), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
+            ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
+            model_service.ListModelsResponse(
+                models=[model.Model(),], next_page_token="ghi",
+            ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
+            RuntimeError,
+        )
+        results = [i for i in client.list_models(request={},)]
+        assert len(results) == 6
+        assert all(isinstance(i, model.Model) for i in results)
+
+
+def test_list_models_pages():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client._transport.list_models), "__call__") as call:
+        # Set the response to a series of pages.
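+        # The .pages iterator used below yields one full response per page;
+        # raw_page exposes the underlying response proto so each page's
+        # next_page_token can be inspected.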
+ call.side_effect = ( + model_service.ListModelsResponse( + models=[model.Model(), model.Model(), model.Model(),], + next_page_token="abc", + ), + model_service.ListModelsResponse(models=[], next_page_token="def",), + model_service.ListModelsResponse( + models=[model.Model(),], next_page_token="ghi", + ), + model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), + RuntimeError, + ) + pages = list(client.list_models(request={}).pages) + for page, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page.raw_page.next_page_token == token + + +def test_update_model(transport: str = "grpc"): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.UpdateModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.update_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=["supported_input_storage_formats_value"], + supported_output_storage_formats=["supported_output_storage_formats_value"], + etag="etag_value", + ) + + response = client.update_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_model.Model) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.metadata_schema_uri == "metadata_schema_uri_value" + assert response.training_pipeline == "training_pipeline_value" + assert response.artifact_uri == "artifact_uri_value" + assert response.supported_deployment_resources_types == [ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] + assert response.etag == "etag_value" + + +def test_update_model_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.update_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_model.Model() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.update_model( + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
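+        # The flattened model and update_mask kwargs are copied onto the
+        # request proto, which the assertions below verify field by field.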
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].model == gca_model.Model(name="name_value") + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_model_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_model( + model_service.UpdateModelRequest(), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_delete_model(transport: str = "grpc"): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.DeleteModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.delete_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_model_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.delete_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.delete_model(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_model_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_model( + model_service.DeleteModelRequest(), name="name_value", + ) + + +def test_export_model(transport: str = "grpc"): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.ExportModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.export_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.export_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_export_model_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.export_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.export_model( + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ) + + +def test_export_model_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.export_model( + model_service.ExportModelRequest(), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), + ) + + +def test_get_model_evaluation(transport: str = "grpc"): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.GetModelEvaluationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_model_evaluation), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation.ModelEvaluation( + name="name_value", + metrics_schema_uri="metrics_schema_uri_value", + slice_dimensions=["slice_dimensions_value"], + ) + + response = client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, model_evaluation.ModelEvaluation) + assert response.name == "name_value" + assert response.metrics_schema_uri == "metrics_schema_uri_value" + assert response.slice_dimensions == ["slice_dimensions_value"] + + +def test_get_model_evaluation_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelEvaluationRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. 
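+    # The client is expected to derive the x-goog-request-params metadata
+    # entry from request.name so the backend can route the call.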
+ with mock.patch.object( + type(client._transport.get_model_evaluation), "__call__" + ) as call: + call.return_value = model_evaluation.ModelEvaluation() + client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_model_evaluation_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_model_evaluation), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation.ModelEvaluation() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.get_model_evaluation(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_model_evaluation_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_model_evaluation( + model_service.GetModelEvaluationRequest(), name="name_value", + ) + + +def test_list_model_evaluations(transport: str = "grpc"): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.ListModelEvaluationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_model_evaluations), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelEvaluationsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelEvaluationsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_model_evaluations_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelEvaluationsRequest(parent="parent/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_model_evaluations), "__call__" + ) as call: + call.return_value = model_service.ListModelEvaluationsResponse() + client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_model_evaluations_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_model_evaluations), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelEvaluationsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.list_model_evaluations(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_list_model_evaluations_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_model_evaluations( + model_service.ListModelEvaluationsRequest(), parent="parent_value", + ) + + +def test_list_model_evaluations_pager(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_model_evaluations), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + model_service.ListModelEvaluationsResponse( + model_evaluations=[ + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + ], + next_page_token="abc", + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[], next_page_token="def", + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[ + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + ], + ), + RuntimeError, + ) + results = [i for i in client.list_model_evaluations(request={},)] + assert len(results) == 6 + assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results) + + +def test_list_model_evaluations_pages(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_model_evaluations), "__call__" + ) as call: + # Set the response to a series of pages.
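+ # side_effect yields one response per page request; the final page carries an + # empty next_page_token, so the trailing RuntimeError sentinel is only hit if + # the pager erroneously asks for a fifth page.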
+ call.side_effect = ( + model_service.ListModelEvaluationsResponse( + model_evaluations=[ + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + ], + next_page_token="abc", + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[], next_page_token="def", + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[ + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + ], + ), + RuntimeError, + ) + pages = list(client.list_model_evaluations(request={}).pages) + for page, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page.raw_page.next_page_token == token + + +def test_get_model_evaluation_slice(transport: str = "grpc"): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.GetModelEvaluationSliceRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_model_evaluation_slice), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation_slice.ModelEvaluationSlice( + name="name_value", metrics_schema_uri="metrics_schema_uri_value", + ) + + response = client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) + assert response.name == "name_value" + assert response.metrics_schema_uri == "metrics_schema_uri_value" + + +def test_get_model_evaluation_slice_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelEvaluationSliceRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_model_evaluation_slice), "__call__" + ) as call: + call.return_value = model_evaluation_slice.ModelEvaluationSlice() + client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_model_evaluation_slice_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_model_evaluation_slice), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation_slice.ModelEvaluationSlice() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ response = client.get_model_evaluation_slice(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_model_evaluation_slice_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_model_evaluation_slice( + model_service.GetModelEvaluationSliceRequest(), name="name_value", + ) + + +def test_list_model_evaluation_slices(transport: str = "grpc"): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = model_service.ListModelEvaluationSlicesRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_model_evaluation_slices), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelEvaluationSlicesResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelEvaluationSlicesPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_model_evaluation_slices_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelEvaluationSlicesRequest(parent="parent/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_model_evaluation_slices), "__call__" + ) as call: + call.return_value = model_service.ListModelEvaluationSlicesResponse() + client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_model_evaluation_slices_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_model_evaluation_slices), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelEvaluationSlicesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.list_model_evaluation_slices(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
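+ # The flattened keyword arguments are coerced into a single + # ListModelEvaluationSlicesRequest before reaching the transport, which is + # why the assertions below inspect fields on args[0].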
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_list_model_evaluation_slices_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_model_evaluation_slices( + model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", + ) + + +def test_list_model_evaluation_slices_pager(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_model_evaluation_slices), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + model_service.ListModelEvaluationSlicesResponse( + model_evaluation_slices=[ + model_evaluation_slice.ModelEvaluationSlice(), + model_evaluation_slice.ModelEvaluationSlice(), + model_evaluation_slice.ModelEvaluationSlice(), + ], + next_page_token="abc", + ), + model_service.ListModelEvaluationSlicesResponse( + model_evaluation_slices=[], next_page_token="def", + ), + model_service.ListModelEvaluationSlicesResponse( + model_evaluation_slices=[ + model_evaluation_slice.ModelEvaluationSlice(), + ], + next_page_token="ghi", + ), + model_service.ListModelEvaluationSlicesResponse( + model_evaluation_slices=[ + model_evaluation_slice.ModelEvaluationSlice(), + model_evaluation_slice.ModelEvaluationSlice(), + ], + ), + RuntimeError, + ) + results = [i for i in client.list_model_evaluation_slices(request={},)] + assert len(results) == 6 + assert all( + isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results + ) + + +def test_list_model_evaluation_slices_pages(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_model_evaluation_slices), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + model_service.ListModelEvaluationSlicesResponse( + model_evaluation_slices=[ + model_evaluation_slice.ModelEvaluationSlice(), + model_evaluation_slice.ModelEvaluationSlice(), + model_evaluation_slice.ModelEvaluationSlice(), + ], + next_page_token="abc", + ), + model_service.ListModelEvaluationSlicesResponse( + model_evaluation_slices=[], next_page_token="def", + ), + model_service.ListModelEvaluationSlicesResponse( + model_evaluation_slices=[ + model_evaluation_slice.ModelEvaluationSlice(), + ], + next_page_token="ghi", + ), + model_service.ListModelEvaluationSlicesResponse( + model_evaluation_slices=[ + model_evaluation_slice.ModelEvaluationSlice(), + model_evaluation_slice.ModelEvaluationSlice(), + ], + ), + RuntimeError, + ) + pages = list(client.list_model_evaluation_slices(request={}).pages) + for page, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page.raw_page.next_page_token == token + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance.
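+ # A fully-configured transport already carries its own credentials, so also + # passing credentials to the client would be ambiguous; e.g. + # ModelServiceClient(credentials=creds, transport=transport) must raise.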
+ transport = transports.ModelServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = ModelServiceClient(transport=transport) + assert client._transport is transport + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client._transport, transports.ModelServiceGrpcTransport,) + + +def test_model_service_base_transport(): + # Instantiate the base transport. + transport = transports.ModelServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "upload_model", + "get_model", + "list_models", + "update_model", + "delete_model", + "export_model", + "get_model_evaluation", + "list_model_evaluations", + "get_model_evaluation_slice", + "list_model_evaluation_slices", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_model_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + ModelServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",) + ) + + +def test_model_service_host_no_port(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + transport="grpc", + ) + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def test_model_service_host_with_port(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + transport="grpc", + ) + assert client._transport._host == "aiplatform.googleapis.com:8000" + + +def test_model_service_grpc_transport_channel(): + channel = grpc.insecure_channel("http://localhost/") + transport = transports.ModelServiceGrpcTransport(channel=channel,) + assert transport.grpc_channel is channel + + +def test_model_service_grpc_lro_client(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + transport = client._transport + + # Ensure that we have an api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + + # Ensure that subsequent calls to the property return the exact same object.
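+ # The transport memoizes the operations client, so every access shares one + # underlying channel; the identity check below confirms the cached instance.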
+ assert transport.operations_client is transport.operations_client + + +def test_model_path(): + project = "squid" + location = "clam" + model = "whelk" + + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) + actual = ModelServiceClient.model_path(project, location, model) + assert expected == actual diff --git a/tests/unit/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/aiplatform_v1beta1/test_pipeline_service.py new file mode 100644 index 0000000000..842564c259 --- /dev/null +++ b/tests/unit/aiplatform_v1beta1/test_pipeline_service.py @@ -0,0 +1,675 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from unittest import mock + +import grpc +import math +import pytest + +from google import auth +from google.api_core import client_options +from google.api_core import future +from google.api_core import operations_v1 +from google.auth import credentials +from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( + PipelineServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers +from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports +from google.cloud.aiplatform_v1beta1.types import deployed_model_ref +from google.cloud.aiplatform_v1beta1.types import env_var +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import explanation_metadata +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import pipeline_service +from google.cloud.aiplatform_v1beta1.types import pipeline_state +from google.cloud.aiplatform_v1beta1.types import training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import any_pb2 as any # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + + +def test_pipeline_service_client_from_service_account_file(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = PipelineServiceClient.from_service_account_file("dummy/file/path.json") + assert client._transport._credentials == creds + + client = PipelineServiceClient.from_service_account_json("dummy/file/path.json") + assert client._transport._credentials == creds + + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def 
test_pipeline_service_client_client_options(): + # Check the default options have their expected values. + assert ( + PipelineServiceClient.DEFAULT_OPTIONS.api_endpoint + == "aiplatform.googleapis.com" + ) + + # Check that options can be customized. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.PipelineServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = PipelineServiceClient(client_options=options) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_pipeline_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.PipelineServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = PipelineServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_create_training_pipeline(transport: str = "grpc"): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = pipeline_service.CreateTrainingPipelineRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + + response = client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_training_pipeline.TrainingPipeline) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.training_task_definition == "training_task_definition_value" + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +def test_create_training_pipeline_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_training_pipeline.TrainingPipeline() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.create_training_pipeline( + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( + name="name_value" + ) + + +def test_create_training_pipeline_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_training_pipeline( + pipeline_service.CreateTrainingPipelineRequest(), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + ) + + +def test_get_training_pipeline(transport: str = "grpc"): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = pipeline_service.GetTrainingPipelineRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + + response = client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, training_pipeline.TrainingPipeline) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.training_task_definition == "training_task_definition_value" + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +def test_get_training_pipeline_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.GetTrainingPipelineRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_training_pipeline), "__call__" + ) as call: + call.return_value = training_pipeline.TrainingPipeline() + client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_training_pipeline_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
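+ # An empty TrainingPipeline message is sufficient here: this test only + # verifies request plumbing, not response field handling.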
+ call.return_value = training_pipeline.TrainingPipeline() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.get_training_pipeline(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_training_pipeline_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_training_pipeline( + pipeline_service.GetTrainingPipelineRequest(), name="name_value", + ) + + +def test_list_training_pipelines(transport: str = "grpc"): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = pipeline_service.ListTrainingPipelinesRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_training_pipelines), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_service.ListTrainingPipelinesResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTrainingPipelinesPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_training_pipelines_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.ListTrainingPipelinesRequest(parent="parent/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_training_pipelines), "__call__" + ) as call: + call.return_value = pipeline_service.ListTrainingPipelinesResponse() + client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_training_pipelines_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_training_pipelines), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_service.ListTrainingPipelinesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ response = client.list_training_pipelines(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + + +def test_list_training_pipelines_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_training_pipelines( + pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", + ) + + +def test_list_training_pipelines_pager(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_training_pipelines), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[ + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + ], + next_page_token="abc", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[], next_page_token="def", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[ + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + ], + ), + RuntimeError, + ) + results = [i for i in client.list_training_pipelines(request={},)] + assert len(results) == 6 + assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results) + + +def test_list_training_pipelines_pages(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_training_pipelines), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[ + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + ], + next_page_token="abc", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[], next_page_token="def", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[ + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + ], + ), + RuntimeError, + ) + pages = list(client.list_training_pipelines(request={}).pages) + for page, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page.raw_page.next_page_token == token + + +def test_delete_training_pipeline(transport: str = "grpc"): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = pipeline_service.DeleteTrainingPipelineRequest() + + # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object( + type(client._transport.delete_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_training_pipeline_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.delete_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.delete_training_pipeline(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_training_pipeline_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_training_pipeline( + pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", + ) + + +def test_cancel_training_pipeline(transport: str = "grpc"): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = pipeline_service.CancelTrainingPipelineRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.cancel_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_training_pipeline_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.cancel_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.cancel_training_pipeline(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_cancel_training_pipeline_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_training_pipeline( + pipeline_service.CancelTrainingPipelineRequest(), name="name_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.PipelineServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.PipelineServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = PipelineServiceClient(transport=transport) + assert client._transport is transport + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client._transport, transports.PipelineServiceGrpcTransport,) + + +def test_pipeline_service_base_transport(): + # Instantiate the base transport. + transport = transports.PipelineServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_training_pipeline", + "get_training_pipeline", + "list_training_pipelines", + "delete_training_pipeline", + "cancel_training_pipeline", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_pipeline_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. 
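+ # ADC (Application Default Credentials) are resolved through + # google.auth.default(); patching it lets the test assert that the client + # requests the cloud-platform scope without touching a real environment.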
+ with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + PipelineServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",) + ) + + +def test_pipeline_service_host_no_port(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + transport="grpc", + ) + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def test_pipeline_service_host_with_port(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + transport="grpc", + ) + assert client._transport._host == "aiplatform.googleapis.com:8000" + + +def test_pipeline_service_grpc_transport_channel(): + channel = grpc.insecure_channel("http://localhost/") + transport = transports.PipelineServiceGrpcTransport(channel=channel,) + assert transport.grpc_channel is channel + + +def test_pipeline_service_grpc_lro_client(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + transport = client._transport + + # Ensure that we have an api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + + # Ensure that subsequent calls to the property return the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_training_pipeline_path(): + project = "squid" + location = "clam" + training_pipeline = "whelk" + + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) + actual = PipelineServiceClient.training_pipeline_path( + project, location, training_pipeline + ) + assert expected == actual + + +def test_model_path(): + project = "squid" + location = "clam" + model = "whelk" + + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) + actual = PipelineServiceClient.model_path(project, location, model) + assert expected == actual diff --git a/tests/unit/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/aiplatform_v1beta1/test_prediction_service.py new file mode 100644 index 0000000000..96614bed70 --- /dev/null +++ b/tests/unit/aiplatform_v1beta1/test_prediction_service.py @@ -0,0 +1,309 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# + +from unittest import mock + +import grpc +import math +import pytest + +from google import auth +from google.api_core import client_options +from google.auth import credentials +from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + PredictionServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.prediction_service import transports +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import prediction_service +from google.oauth2 import service_account +from google.protobuf import struct_pb2 as struct # type: ignore + + +def test_prediction_service_client_from_service_account_file(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = PredictionServiceClient.from_service_account_file( + "dummy/file/path.json" + ) + assert client._transport._credentials == creds + + client = PredictionServiceClient.from_service_account_json( + "dummy/file/path.json" + ) + assert client._transport._credentials == creds + + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def test_prediction_service_client_client_options(): + # Check the default options have their expected values. + assert ( + PredictionServiceClient.DEFAULT_OPTIONS.api_endpoint + == "aiplatform.googleapis.com" + ) + + # Check that options can be customized. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.prediction_service.PredictionServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = PredictionServiceClient(client_options=options) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_prediction_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.prediction_service.PredictionServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = PredictionServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_predict(transport: str = "grpc"): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = prediction_service.PredictRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.predict), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.PredictResponse( + deployed_model_id="deployed_model_id_value", + ) + + response = client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
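+ # The mocked PredictResponse passes through the GAPIC layer unchanged, so + # both the response type and the deployed_model_id field should survive + # the round trip.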
+ assert isinstance(response, prediction_service.PredictResponse) + assert response.deployed_model_id == "deployed_model_id_value" + + +def test_predict_flattened(): + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.predict), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.PredictResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.predict( + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].endpoint == "endpoint_value" + assert args[0].instances == [ + struct.Value(null_value=struct.NullValue.NULL_VALUE) + ] + assert args[0].parameters == struct.Value( + null_value=struct.NullValue.NULL_VALUE + ) + + +def test_predict_flattened_error(): + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.predict( + prediction_service.PredictRequest(), + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + ) + + +def test_explain(transport: str = "grpc"): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = prediction_service.ExplainRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.explain), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.ExplainResponse( + deployed_model_id="deployed_model_id_value", + ) + + response = client.explain(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, prediction_service.ExplainResponse) + assert response.deployed_model_id == "deployed_model_id_value" + + +def test_explain_flattened(): + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client._transport.explain), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.ExplainResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
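+ # explain() takes the same endpoint/instances/parameters trio as predict(), + # plus deployed_model_id to pin the explanation to a specific DeployedModel.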
+ response = client.explain( + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + deployed_model_id="deployed_model_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].endpoint == "endpoint_value" + assert args[0].instances == [ + struct.Value(null_value=struct.NullValue.NULL_VALUE) + ] + assert args[0].parameters == struct.Value( + null_value=struct.NullValue.NULL_VALUE + ) + assert args[0].deployed_model_id == "deployed_model_id_value" + + +def test_explain_flattened_error(): + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.explain( + prediction_service.ExplainRequest(), + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + deployed_model_id="deployed_model_id_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = PredictionServiceClient(transport=transport) + assert client._transport is transport + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client._transport, transports.PredictionServiceGrpcTransport,) + + +def test_prediction_service_base_transport(): + # Instantiate the base transport. + transport = transports.PredictionServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "predict", + "explain", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + +def test_prediction_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. 
+ with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + PredictionServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",) + ) + + +def test_prediction_service_host_no_port(): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + transport="grpc", + ) + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def test_prediction_service_host_with_port(): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + transport="grpc", + ) + assert client._transport._host == "aiplatform.googleapis.com:8000" + + +def test_prediction_service_grpc_transport_channel(): + channel = grpc.insecure_channel("http://localhost/") + transport = transports.PredictionServiceGrpcTransport(channel=channel,) + assert transport.grpc_channel is channel diff --git a/tests/unit/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/aiplatform_v1beta1/test_specialist_pool_service.py new file mode 100644 index 0000000000..f8467edf33 --- /dev/null +++ b/tests/unit/aiplatform_v1beta1/test_specialist_pool_service.py @@ -0,0 +1,681 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from unittest import mock + +import grpc +import math +import pytest + +from google import auth +from google.api_core import client_options +from google.api_core import future +from google.api_core import operations_v1 +from google.auth import credentials +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + SpecialistPoolServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import specialist_pool +from google.cloud.aiplatform_v1beta1.types import specialist_pool as gca_specialist_pool +from google.cloud.aiplatform_v1beta1.types import specialist_pool_service +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +def test_specialist_pool_service_client_from_service_account_file(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = SpecialistPoolServiceClient.from_service_account_file( + "dummy/file/path.json" + ) + assert client._transport._credentials == creds + + client = SpecialistPoolServiceClient.from_service_account_json( + "dummy/file/path.json" + ) + assert client._transport._credentials == creds + + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def test_specialist_pool_service_client_client_options(): + # Check the default options have their expected values. + assert ( + SpecialistPoolServiceClient.DEFAULT_OPTIONS.api_endpoint + == "aiplatform.googleapis.com" + ) + + # Check that options can be customized. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.SpecialistPoolServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = SpecialistPoolServiceClient(client_options=options) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_specialist_pool_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.SpecialistPoolServiceClient.get_transport_class" + ) as gtc: + transport = gtc.return_value = mock.MagicMock() + client = SpecialistPoolServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") + + +def test_create_specialist_pool(transport: str = "grpc"): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = specialist_pool_service.CreateSpecialistPoolRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
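+        # CreateSpecialistPool is a long-running operation, so the stub hands
+        # back a raw operations_pb2.Operation; the client is expected to wrap
+        # it in a future, which the final assertion of this test verifies.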
+ call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_specialist_pool_flattened(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.create_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.create_specialist_pool( + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].parent == "parent_value" + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) + + +def test_create_specialist_pool_flattened_error(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_specialist_pool( + specialist_pool_service.CreateSpecialistPoolRequest(), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + ) + + +def test_get_specialist_pool(transport: str = "grpc"): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = specialist_pool_service.GetSpecialistPoolRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool.SpecialistPool( + name="name_value", + display_name="display_name_value", + specialist_managers_count=2662, + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + ) + + response = client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
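+    # The mocked stub echoes back the fake SpecialistPool designated above,
+    # so every field set on it should survive the round trip unchanged.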
+ assert isinstance(response, specialist_pool.SpecialistPool) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.specialist_managers_count == 2662 + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] + + +def test_get_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.GetSpecialistPoolRequest(name="name/value",) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_specialist_pool), "__call__" + ) as call: + call.return_value = specialist_pool.SpecialistPool() + client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_specialist_pool_flattened(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.get_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool.SpecialistPool() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.get_specialist_pool(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_get_specialist_pool_flattened_error(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_specialist_pool( + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", + ) + + +def test_list_specialist_pools(transport: str = "grpc"): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = specialist_pool_service.ListSpecialistPoolsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.list_specialist_pools), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. 
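+    # List methods return a pager wrapper rather than the raw response
+    # message; iterating the pager lazily fetches any remaining pages.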
+    assert isinstance(response, pagers.ListSpecialistPoolsPager)
+    assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_specialist_pools_field_headers():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = specialist_pool_service.ListSpecialistPoolsRequest(parent="parent/value",)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_specialist_pools), "__call__"
+    ) as call:
+        call.return_value = specialist_pool_service.ListSpecialistPoolsResponse()
+        client.list_specialist_pools(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_specialist_pools_flattened():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_specialist_pools), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = specialist_pool_service.ListSpecialistPoolsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = client.list_specialist_pools(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0].parent == "parent_value"
+
+
+def test_list_specialist_pools_flattened_error():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_specialist_pools(
+            specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value",
+        )
+
+
+def test_list_specialist_pools_pager():
+    client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_specialist_pools), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
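+        # Four fake pages are queued up, ending with a RuntimeError so the
+        # test fails loudly if the pager ever requests a page past the last
+        # one (the final page carries an empty next_page_token).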
+        call.side_effect = (
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token="abc",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[], next_page_token="def",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[specialist_pool.SpecialistPool(),],
+                next_page_token="ghi",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+            ),
+            RuntimeError,
+        )
+        results = [i for i in client.list_specialist_pools(request={},)]
+        assert len(results) == 6
+        assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results)
+
+
+def test_list_specialist_pools_pages():
+    client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.list_specialist_pools), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token="abc",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[], next_page_token="def",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[specialist_pool.SpecialistPool(),],
+                next_page_token="ghi",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_specialist_pools(request={}).pages)
+        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page.raw_page.next_page_token == token
+
+
+def test_delete_specialist_pool(transport: str = "grpc"):
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = specialist_pool_service.DeleteSpecialistPoolRequest()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.delete_specialist_pool), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.delete_specialist_pool(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == request
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+def test_delete_specialist_pool_flattened():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client._transport.delete_specialist_pool), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.delete_specialist_pool(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].name == "name_value" + + +def test_delete_specialist_pool_flattened_error(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_specialist_pool( + specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", + ) + + +def test_update_specialist_pool(transport: str = "grpc"): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = specialist_pool_service.UpdateSpecialistPoolRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.update_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_update_specialist_pool_flattened(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client._transport.update_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = client.update_specialist_pool( + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_specialist_pool_flattened_error(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_specialist_pool( + specialist_pool_service.UpdateSpecialistPoolRequest(), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. 
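+    # A transport instance already encapsulates its own credentials, so
+    # supplying a second set on the client would be ambiguous.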
+ transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = SpecialistPoolServiceClient(transport=transport) + assert client._transport is transport + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance(client._transport, transports.SpecialistPoolServiceGrpcTransport,) + + +def test_specialist_pool_service_base_transport(): + # Instantiate the base transport. + transport = transports.SpecialistPoolServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_specialist_pool", + "get_specialist_pool", + "list_specialist_pools", + "delete_specialist_pool", + "update_specialist_pool", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_specialist_pool_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + SpecialistPoolServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",) + ) + + +def test_specialist_pool_service_host_no_port(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + transport="grpc", + ) + assert client._transport._host == "aiplatform.googleapis.com:443" + + +def test_specialist_pool_service_host_with_port(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + transport="grpc", + ) + assert client._transport._host == "aiplatform.googleapis.com:8000" + + +def test_specialist_pool_service_grpc_transport_channel(): + channel = grpc.insecure_channel("http://localhost/") + transport = transports.SpecialistPoolServiceGrpcTransport(channel=channel,) + assert transport.grpc_channel is channel + + +def test_specialist_pool_service_grpc_lro_client(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + transport = client._transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + + # Ensure that subsequent calls to the property send the exact same object. 
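+    # The identity check verifies that the property caches the operations
+    # client rather than constructing a new one on every access.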
+ assert transport.operations_client is transport.operations_client + + +def test_specialist_pool_path(): + project = "squid" + location = "clam" + specialist_pool = "whelk" + + expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( + project=project, location=location, specialist_pool=specialist_pool, + ) + actual = SpecialistPoolServiceClient.specialist_pool_path( + project, location, specialist_pool + ) + assert expected == actual diff --git a/tests/unit/gapic/aiplatform_v1beta1/__init__.py b/tests/unit/gapic/aiplatform_v1beta1/__init__.py deleted file mode 100644 index 8b13789179..0000000000 --- a/tests/unit/gapic/aiplatform_v1beta1/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py deleted file mode 100644 index 35996fd5c4..0000000000 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ /dev/null @@ -1,3299 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import os -import mock - -import grpc -from grpc.experimental import aio -import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule - -from google import auth -from google.api_core import client_options -from google.api_core import exceptions -from google.api_core import future -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import operation_async # type: ignore -from google.api_core import operations_v1 -from google.auth import credentials -from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.dataset_service import ( - DatasetServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.dataset_service import ( - DatasetServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers -from google.cloud.aiplatform_v1beta1.services.dataset_service import transports -from google.cloud.aiplatform_v1beta1.types import annotation -from google.cloud.aiplatform_v1beta1.types import annotation_spec -from google.cloud.aiplatform_v1beta1.types import data_item -from google.cloud.aiplatform_v1beta1.types import dataset -from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset -from google.cloud.aiplatform_v1beta1.types import dataset_service -from google.cloud.aiplatform_v1beta1.types import io -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - - -def client_cert_source_callback(): - return b"cert bytes", b"key bytes" - - -# If default 
endpoint is localhost, then default mtls endpoint will be the same. -# This method modifies the default endpoint so the client can produce a different -# mtls endpoint for endpoint testing purposes. -def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) - - -def test__get_default_mtls_endpoint(): - api_endpoint = "example.googleapis.com" - api_mtls_endpoint = "example.mtls.googleapis.com" - sandbox_endpoint = "example.sandbox.googleapis.com" - sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" - non_googleapi = "api.example.com" - - assert DatasetServiceClient._get_default_mtls_endpoint(None) is None - assert ( - DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) - - -@pytest.mark.parametrize( - "client_class", [DatasetServiceClient, DatasetServiceAsyncClient] -) -def test_dataset_service_client_from_service_account_file(client_class): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = client_class.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_dataset_service_client_get_transport_class(): - transport = DatasetServiceClient.get_transport_class() - assert transport == transports.DatasetServiceGrpcTransport - - transport = DatasetServiceClient.get_transport_class("grpc") - assert transport == transports.DatasetServiceGrpcTransport - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - DatasetServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceClient), -) -@mock.patch.object( - DatasetServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceAsyncClient), -) -def test_dataset_service_client_client_options( - client_class, transport_class, transport_name -): - # Check that if channel is provided we won't create a new one. - with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) - client = client_class(transport=transport) - gtc.assert_not_called() - - # Check that if channel is provided via str we will create a new one. - with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: - client = client_class(transport=transport_name) - gtc.assert_called() - - # Check the case api_endpoint is provided. 
- options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has - # unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class() - - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError): - client = client_class() - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id="octopus", - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - DatasetServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceClient), -) -@mock.patch.object( - DatasetServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(DatasetServiceAsyncClient), -) -@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_dataset_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): - # This tests the endpoint autoswitch behavior. 
Endpoint is autoswitched to the default - # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. - - # Check the case client_cert_source is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) - - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT - - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case ADC client cert is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_dataset_service_client_client_options_scopes( - client_class, transport_class, transport_name -): - # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=["1", "2"], - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), - ( - DatasetServiceAsyncClient, - transports.DatasetServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_dataset_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): - # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_dataset_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__" - ) as grpc_transport: - grpc_transport.return_value = None - client = DatasetServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - grpc_transport.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_create_dataset( - transport: str = "grpc", request_type=dataset_service.CreateDatasetRequest -): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
- request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.create_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == dataset_service.CreateDatasetRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_create_dataset_from_dict(): - test_create_dataset(request_type=dict) - - -@pytest.mark.asyncio -async def test_create_dataset_async(transport: str = "grpc_asyncio"): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.CreateDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_dataset), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.create_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_create_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.CreateDatasetRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_dataset), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.create_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_create_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.CreateDatasetRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_dataset), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.create_dataset(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_create_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.create_dataset( - parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].dataset == gca_dataset.Dataset(name="name_value") - - -def test_create_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_dataset( - dataset_service.CreateDatasetRequest(), - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), - ) - - -@pytest.mark.asyncio -async def test_create_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_dataset), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.create_dataset( - parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].dataset == gca_dataset.Dataset(name="name_value") - - -@pytest.mark.asyncio -async def test_create_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.create_dataset( - dataset_service.CreateDatasetRequest(), - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), - ) - - -def test_get_dataset( - transport: str = "grpc", request_type=dataset_service.GetDatasetRequest -): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - - response = client.get_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == dataset_service.GetDatasetRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, dataset.Dataset) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.metadata_schema_uri == "metadata_schema_uri_value" - - assert response.etag == "etag_value" - - -def test_get_dataset_from_dict(): - test_get_dataset(request_type=dict) - - -@pytest.mark.asyncio -async def test_get_dataset_async(transport: str = "grpc_asyncio"): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.GetDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_dataset), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - ) - - response = await client.get_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, dataset.Dataset) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.metadata_schema_uri == "metadata_schema_uri_value" - - assert response.etag == "etag_value" - - -def test_get_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.GetDatasetRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: - call.return_value = dataset.Dataset() - - client.get_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_get_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. 
- request = dataset_service.GetDatasetRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_dataset), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) - - await client.get_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset.Dataset() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_dataset(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_get_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_dataset( - dataset_service.GetDatasetRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_get_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_dataset), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = dataset.Dataset() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.get_dataset(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_get_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.get_dataset( - dataset_service.GetDatasetRequest(), name="name_value", - ) - - -def test_update_dataset( - transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest -): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client._transport.update_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - - response = client.update_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == dataset_service.UpdateDatasetRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_dataset.Dataset) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.metadata_schema_uri == "metadata_schema_uri_value" - - assert response.etag == "etag_value" - - -def test_update_dataset_from_dict(): - test_update_dataset(request_type=dict) - - -@pytest.mark.asyncio -async def test_update_dataset_async(transport: str = "grpc_asyncio"): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.UpdateDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_dataset), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - ) - - response = await client.update_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_dataset.Dataset) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.metadata_schema_uri == "metadata_schema_uri_value" - - assert response.etag == "etag_value" - - -def test_update_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.UpdateDatasetRequest() - request.dataset.name = "dataset.name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_dataset), "__call__") as call: - call.return_value = gca_dataset.Dataset() - - client.update_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ - "metadata" - ] - - -@pytest.mark.asyncio -async def test_update_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. 
Set these to a non-empty value. - request = dataset_service.UpdateDatasetRequest() - request.dataset.name = "dataset.name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_dataset), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) - - await client.update_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ - "metadata" - ] - - -def test_update_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_dataset.Dataset() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.update_dataset( - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].dataset == gca_dataset.Dataset(name="name_value") - - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -def test_update_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.update_dataset( - dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -@pytest.mark.asyncio -async def test_update_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_dataset), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_dataset.Dataset() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.update_dataset( - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].dataset == gca_dataset.Dataset(name="name_value") - - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -@pytest.mark.asyncio -async def test_update_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - await client.update_dataset( - dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -def test_list_datasets( - transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest -): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDatasetsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_datasets(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == dataset_service.ListDatasetsRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDatasetsPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_datasets_from_dict(): - test_list_datasets(request_type=dict) - - -@pytest.mark.asyncio -async def test_list_datasets_async(transport: str = "grpc_asyncio"): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ListDatasetsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_datasets), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse( - next_page_token="next_page_token_value", - ) - ) - - response = await client.list_datasets(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDatasetsAsyncPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_datasets_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ListDatasetsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - call.return_value = dataset_service.ListDatasetsResponse() - - client.list_datasets(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
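# The metadata checked below is exactly what the client builds from
# URI-bound request fields via gapic_v1.routing_header (the same helper
# the pager tests use); a standalone sketch:
#
#     from google.api_core import gapic_v1
#
#     gapic_v1.routing_header.to_grpc_metadata((("parent", "parent/value"),))
#     # -> ("x-goog-request-params", "parent=parent/value")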
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_list_datasets_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ListDatasetsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_datasets), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse() - ) - - await client.list_datasets(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_datasets_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDatasetsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_datasets(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -def test_list_datasets_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_datasets( - dataset_service.ListDatasetsRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_datasets_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_datasets), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDatasetsResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDatasetsResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_datasets(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_datasets_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - await client.list_datasets( - dataset_service.ListDatasetsRequest(), parent="parent_value", - ) - - -def test_list_datasets_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", - ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", - ), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_datasets(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all(isinstance(i, dataset.Dataset) for i in results) - - -def test_list_datasets_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", - ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", - ), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], - ), - RuntimeError, - ) - pages = list(client.list_datasets(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_datasets_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_datasets), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", - ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", - ), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], - ), - RuntimeError, - ) - async_pager = await client.list_datasets(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) - - assert len(responses) == 6 - assert all(isinstance(i, dataset.Dataset) for i in responses) - - -@pytest.mark.asyncio -async def test_list_datasets_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. 
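# The async pager obtains each page by awaiting the transport method, so
# these tests install the mock with new_callable=mock.AsyncMock and let
# side_effect hand back one plain response per await, instead of wrapping
# responses in FakeUnaryUnaryCall. Note also that the pager tests above
# construct their clients with credentials.AnonymousCredentials (the
# class, not an instance); that appears harmless only because the patched
# stub keeps the credentials from ever being exercised.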
- with mock.patch.object( - type(client._client._transport.list_datasets), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", - ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", - ), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], - ), - RuntimeError, - ) - pages = [] - async for page_ in (await client.list_datasets(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -def test_delete_dataset( - transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest -): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == dataset_service.DeleteDatasetRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_dataset_from_dict(): - test_delete_dataset(request_type=dict) - - -@pytest.mark.asyncio -async def test_delete_dataset_async(transport: str = "grpc_asyncio"): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.DeleteDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_dataset), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.delete_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.DeleteDatasetRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. 
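# At the transport layer delete_dataset yields a raw
# operations_pb2.Operation; the client wraps it in an api_core future
# (hence the isinstance(response, future.Future) assertions above). A
# hedged sketch of how a caller would typically resolve the LRO:
#
#     operation = client.delete_dataset(name="name_value")
#     operation.result()  # blocks until the server-side deletion finishes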
- with mock.patch.object(type(client._transport.delete_dataset), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.delete_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_delete_dataset_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.DeleteDatasetRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_dataset), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.delete_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_delete_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.delete_dataset(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_delete_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_dataset( - dataset_service.DeleteDatasetRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_delete_dataset_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_dataset), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.delete_dataset(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
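# Note: the mock's return value is assigned twice above; only the second,
# awaitable FakeUnaryUnaryCall is ever observed, because the async client
# awaits whatever the patched stub returns and awaiting the fake yields
# the wrapped Operation message.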
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_delete_dataset_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.delete_dataset( - dataset_service.DeleteDatasetRequest(), name="name_value", - ) - - -def test_import_data( - transport: str = "grpc", request_type=dataset_service.ImportDataRequest -): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.import_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.import_data(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == dataset_service.ImportDataRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_import_data_from_dict(): - test_import_data(request_type=dict) - - -@pytest.mark.asyncio -async def test_import_data_async(transport: str = "grpc_asyncio"): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ImportDataRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.import_data), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.import_data(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_import_data_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ImportDataRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.import_data), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.import_data(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_import_data_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ImportDataRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.import_data), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.import_data(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_import_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.import_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.import_data( - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - assert args[0].import_configs == [ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ] - - -def test_import_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.import_data( - dataset_service.ImportDataRequest(), - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], - ) - - -@pytest.mark.asyncio -async def test_import_data_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.import_data), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.import_data( - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], - ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - assert args[0].import_configs == [ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ] - - -@pytest.mark.asyncio -async def test_import_data_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.import_data( - dataset_service.ImportDataRequest(), - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], - ) - - -def test_export_data( - transport: str = "grpc", request_type=dataset_service.ExportDataRequest -): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.export_data(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == dataset_service.ExportDataRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_export_data_from_dict(): - test_export_data(request_type=dict) - - -@pytest.mark.asyncio -async def test_export_data_async(transport: str = "grpc_asyncio"): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ExportDataRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.export_data), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.export_data(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_export_data_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ExportDataRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_data), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.export_data(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_export_data_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ExportDataRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.export_data), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.export_data(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_export_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.export_data( - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - assert args[0].export_config == dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ) - - -def test_export_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.export_data( - dataset_service.ExportDataRequest(), - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), - ) - - -@pytest.mark.asyncio -async def test_export_data_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.export_data), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
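# Flattened keyword arguments are folded by the client into a freshly
# constructed request proto, so the call below is equivalent to passing
# dataset_service.ExportDataRequest(name=..., export_config=...) as the
# request object; that is also why supplying both forms at once raises
# the ValueError exercised by the *_flattened_error tests.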
- response = await client.export_data( - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - assert args[0].export_config == dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ) - - -@pytest.mark.asyncio -async def test_export_data_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.export_data( - dataset_service.ExportDataRequest(), - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), - ) - - -def test_list_data_items( - transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest -): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_data_items(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == dataset_service.ListDataItemsRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDataItemsPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_data_items_from_dict(): - test_list_data_items(request_type=dict) - - -@pytest.mark.asyncio -async def test_list_data_items_async(transport: str = "grpc_asyncio"): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ListDataItemsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_data_items), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", - ) - ) - - response = await client.list_data_items(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
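# The async surface returns an AsyncPager, which is consumed with
# `async for` rather than a plain loop; a minimal usage sketch:
#
#     async_pager = await client.list_data_items(parent="parent_value")
#     items = [item async for item in async_pager]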
- assert isinstance(response, pagers.ListDataItemsAsyncPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_data_items_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ListDataItemsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - call.return_value = dataset_service.ListDataItemsResponse() - - client.list_data_items(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_list_data_items_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ListDataItemsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_data_items), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse() - ) - - await client.list_data_items(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_data_items_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDataItemsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_data_items(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -def test_list_data_items_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_data_items( - dataset_service.ListDataItemsRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_data_items_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._client._transport.list_data_items), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDataItemsResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListDataItemsResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_data_items(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_data_items_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.list_data_items( - dataset_service.ListDataItemsRequest(), parent="parent_value", - ) - - -def test_list_data_items_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_data_items(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all(isinstance(i, data_item.DataItem) for i in results) - - -def test_list_data_items_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], - ), - RuntimeError, - ) - pages = list(client.list_data_items(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_data_items_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._client._transport.list_data_items), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], - ), - RuntimeError, - ) - async_pager = await client.list_data_items(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) - - assert len(responses) == 6 - assert all(isinstance(i, data_item.DataItem) for i in responses) - - -@pytest.mark.asyncio -async def test_list_data_items_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_data_items), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], - ), - RuntimeError, - ) - pages = [] - async for page_ in (await client.list_data_items(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -def test_get_annotation_spec( - transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest -): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_annotation_spec), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = annotation_spec.AnnotationSpec( - name="name_value", display_name="display_name_value", etag="etag_value", - ) - - response = client.get_annotation_spec(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == dataset_service.GetAnnotationSpecRequest() - - # Establish that the response is the type that we expect. 
- assert isinstance(response, annotation_spec.AnnotationSpec) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.etag == "etag_value" - - -def test_get_annotation_spec_from_dict(): - test_get_annotation_spec(request_type=dict) - - -@pytest.mark.asyncio -async def test_get_annotation_spec_async(transport: str = "grpc_asyncio"): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.GetAnnotationSpecRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_annotation_spec), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec( - name="name_value", display_name="display_name_value", etag="etag_value", - ) - ) - - response = await client.get_annotation_spec(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, annotation_spec.AnnotationSpec) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.etag == "etag_value" - - -def test_get_annotation_spec_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.GetAnnotationSpecRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_annotation_spec), "__call__" - ) as call: - call.return_value = annotation_spec.AnnotationSpec() - - client.get_annotation_spec(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_get_annotation_spec_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.GetAnnotationSpecRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_annotation_spec), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec() - ) - - await client.get_annotation_spec(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_annotation_spec_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_annotation_spec), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = annotation_spec.AnnotationSpec() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_annotation_spec(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_get_annotation_spec_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_get_annotation_spec_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_annotation_spec), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = annotation_spec.AnnotationSpec() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - annotation_spec.AnnotationSpec() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.get_annotation_spec(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_get_annotation_spec_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), name="name_value", - ) - - -def test_list_annotations( - transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest -): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_annotations), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_annotations(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == dataset_service.ListAnnotationsRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListAnnotationsPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_annotations_from_dict(): - test_list_annotations(request_type=dict) - - -@pytest.mark.asyncio -async def test_list_annotations_async(transport: str = "grpc_asyncio"): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ListAnnotationsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_annotations), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", - ) - ) - - response = await client.list_annotations(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListAnnotationsAsyncPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_annotations_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ListAnnotationsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_annotations), "__call__" - ) as call: - call.return_value = dataset_service.ListAnnotationsResponse() - - client.list_annotations(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_list_annotations_field_headers_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ListAnnotationsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_annotations), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse() - ) - - await client.list_annotations(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_annotations_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_annotations), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListAnnotationsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_annotations(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -def test_list_annotations_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_annotations( - dataset_service.ListAnnotationsRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_annotations_flattened_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_annotations), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListAnnotationsResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - dataset_service.ListAnnotationsResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_annotations(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_annotations_flattened_error_async(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.list_annotations( - dataset_service.ListAnnotationsRequest(), parent="parent_value", - ) - - -def test_list_annotations_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_annotations), "__call__" - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - annotation.Annotation(), - ], - next_page_token="abc", - ), - dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_annotations(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all(isinstance(i, annotation.Annotation) for i in results) - - -def test_list_annotations_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_annotations), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - annotation.Annotation(), - ], - next_page_token="abc", - ), - dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], - ), - RuntimeError, - ) - pages = list(client.list_annotations(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_annotations_async_pager(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_annotations), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - annotation.Annotation(), - ], - next_page_token="abc", - ), - dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], - ), - RuntimeError, - ) - async_pager = await client.list_annotations(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) - - assert len(responses) == 6 - assert all(isinstance(i, annotation.Annotation) for i in responses) - - -@pytest.mark.asyncio -async def test_list_annotations_async_pages(): - client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_annotations), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. 
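# How a pager consumes the faked pages configured in these tests: iterating it
# re-invokes the stub until next_page_token comes back empty, so the trailing
# RuntimeError proves no extra RPC is issued after the last page. (Note these
# pager tests pass the AnonymousCredentials class itself; outside a fully
# mocked transport you would instantiate it: AnonymousCredentials().)
#
#     pager = client.list_annotations(request={})
#     items = list(pager)        # six Annotations, crossing page boundaries
#     for page in client.list_annotations(request={}).pages:
#         page.raw_page.next_page_token   # "abc", "def", "ghi", ""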
- call.side_effect = ( - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - annotation.Annotation(), - ], - next_page_token="abc", - ), - dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], - ), - RuntimeError, - ) - pages = [] - async for page_ in (await client.list_annotations(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.DatasetServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # It is an error to provide a credentials file and a transport instance. - transport = transports.DatasetServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = DatasetServiceClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, - ) - - # It is an error to provide scopes and a transport instance. - transport = transports.DatasetServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = DatasetServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.DatasetServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = DatasetServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_get_channel(): - # A client may be instantiated with a custom transport instance. - transport = transports.DatasetServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - - transport = transports.DatasetServiceGrpcAsyncIOTransport( - credentials=credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) -def test_transport_adc(transport_class): - # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - transport_class() - adc.assert_called_once() - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. 
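# A minimal runnable sketch of the custom-transport wiring exercised above:
# when a transport instance is supplied, credentials (or a credentials file,
# or scopes) must be configured on the transport itself; passing them to the
# client as well is the ValueError case tested earlier.
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.dataset_service import (
    DatasetServiceClient,
    transports,
)

transport = transports.DatasetServiceGrpcTransport(
    credentials=credentials.AnonymousCredentials(),
)
client = DatasetServiceClient(transport=transport)
assert client._transport is transport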
- client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.DatasetServiceGrpcTransport,) - - -def test_dataset_service_base_transport_error(): - # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(exceptions.DuplicateCredentialArgs): - transport = transports.DatasetServiceTransport( - credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", - ) - - -def test_dataset_service_base_transport(): - # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__" - ) as Transport: - Transport.return_value = None - transport = transports.DatasetServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_dataset", - "get_dataset", - "update_dataset", - "list_datasets", - "delete_dataset", - "import_data", - "export_data", - "list_data_items", - "get_annotation_spec", - "list_annotations", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_dataset_service_base_transport_with_credentials_file(): - # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" - ) as Transport: - Transport.return_value = None - load_creds.return_value = (credentials.AnonymousCredentials(), None) - transport = transports.DatasetServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", - ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id="octopus", - ) - - -def test_dataset_service_base_transport_with_adc(): - # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" - ) as Transport: - Transport.return_value = None - adc.return_value = (credentials.AnonymousCredentials(), None) - transport = transports.DatasetServiceTransport() - adc.assert_called_once() - - -def test_dataset_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - DatasetServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id=None, - ) - - -def test_dataset_service_transport_auth_adc(): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
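# For reference: google.auth.default() resolves Application Default
# Credentials and returns a (credentials, project_id) 2-tuple, which is why
# the ADC mocks in these tests set adc.return_value to a pair. A minimal
# sketch, assuming an environment where ADC is actually configured:
import google.auth

creds, project = google.auth.default(
    scopes=("https://www.googleapis.com/auth/cloud-platform",)
)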
- with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id="octopus", - ) - - -def test_dataset_service_host_no_port(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), - ) - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_dataset_service_host_with_port(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), - ) - assert client._transport._host == "aiplatform.googleapis.com:8000" - - -def test_dataset_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - - # Check that channel is used if provided. - transport = transports.DatasetServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - - -def test_dataset_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") - - # Check that channel is used if provided. - transport = transports.DatasetServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) -def test_dataset_service_transport_channel_mtls_with_client_cert_source( - transport_class, -): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - cred = credentials.AnonymousCredentials() - with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: - adc.return_value = (cred, None) - transport = transport_class( - host="squid.clam.whelk", - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - adc.assert_called_once() - - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.DatasetServiceGrpcTransport, - transports.DatasetServiceGrpcAsyncIOTransport, - ], -) -def test_dataset_service_transport_channel_mtls_with_adc(transport_class): - mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - with mock.patch.object( - transport_class, 
"create_channel", autospec=True - ) as grpc_create_channel: - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - mock_cred = mock.Mock() - - with pytest.warns(DeprecationWarning): - transport = transport_class( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=None, - ) - - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -def test_dataset_service_grpc_lro_client(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", - ) - transport = client._transport - - # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) - - # Ensure that subsequent calls to the property send the exact same object. - assert transport.operations_client is transport.operations_client - - -def test_dataset_service_grpc_lro_async_client(): - client = DatasetServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", - ) - transport = client._client._transport - - # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) - - # Ensure that subsequent calls to the property send the exact same object. - assert transport.operations_client is transport.operations_client - - -def test_dataset_path(): - project = "squid" - location = "clam" - dataset = "whelk" - - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) - actual = DatasetServiceClient.dataset_path(project, location, dataset) - assert expected == actual - - -def test_parse_dataset_path(): - expected = { - "project": "octopus", - "location": "oyster", - "dataset": "nudibranch", - } - path = DatasetServiceClient.dataset_path(**expected) - - # Check that the path construction is reversible. - actual = DatasetServiceClient.parse_dataset_path(path) - assert expected == actual - - -def test_client_withDEFAULT_CLIENT_INFO(): - client_info = gapic_v1.client_info.ClientInfo() - - with mock.patch.object( - transports.DatasetServiceTransport, "_prep_wrapped_messages" - ) as prep: - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, - ) - prep.assert_called_once_with(client_info) - - with mock.patch.object( - transports.DatasetServiceTransport, "_prep_wrapped_messages" - ) as prep: - transport_class = DatasetServiceClient.get_transport_class() - transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, - ) - prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py deleted file mode 100644 index d8e29265bb..0000000000 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ /dev/null @@ -1,2447 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import os -import mock - -import grpc -from grpc.experimental import aio -import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule - -from google import auth -from google.api_core import client_options -from google.api_core import exceptions -from google.api_core import future -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import operation_async # type: ignore -from google.api_core import operations_v1 -from google.auth import credentials -from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( - EndpointServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( - EndpointServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers -from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports -from google.cloud.aiplatform_v1beta1.types import accelerator_type -from google.cloud.aiplatform_v1beta1.types import endpoint -from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint -from google.cloud.aiplatform_v1beta1.types import endpoint_service -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import explanation_metadata -from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - - -def client_cert_source_callback(): - return b"cert bytes", b"key bytes" - - -# If default endpoint is localhost, then default mtls endpoint will be the same. -# This method modifies the default endpoint so the client can produce a different -# mtls endpoint for endpoint testing purposes. 
-def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) - - -def test__get_default_mtls_endpoint(): - api_endpoint = "example.googleapis.com" - api_mtls_endpoint = "example.mtls.googleapis.com" - sandbox_endpoint = "example.sandbox.googleapis.com" - sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" - non_googleapi = "api.example.com" - - assert EndpointServiceClient._get_default_mtls_endpoint(None) is None - assert ( - EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) - - -@pytest.mark.parametrize( - "client_class", [EndpointServiceClient, EndpointServiceAsyncClient] -) -def test_endpoint_service_client_from_service_account_file(client_class): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = client_class.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_endpoint_service_client_get_transport_class(): - transport = EndpointServiceClient.get_transport_class() - assert transport == transports.EndpointServiceGrpcTransport - - transport = EndpointServiceClient.get_transport_class("grpc") - assert transport == transports.EndpointServiceGrpcTransport - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - EndpointServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceClient), -) -@mock.patch.object( - EndpointServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceAsyncClient), -) -def test_endpoint_service_client_client_options( - client_class, transport_class, transport_name -): - # Check that if channel is provided we won't create a new one. - with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) - client = client_class(transport=transport) - gtc.assert_not_called() - - # Check that if channel is provided via str we will create a new one. - with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: - client = client_class(transport=transport_name) - gtc.assert_called() - - # Check the case api_endpoint is provided. 
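# A minimal runnable sketch of the api_endpoint override checked here, using
# the placeholder host these tests use throughout:
from google.api_core import client_options
from google.auth import credentials
from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    EndpointServiceClient,
)

client = EndpointServiceClient(
    credentials=credentials.AnonymousCredentials(),
    client_options=client_options.ClientOptions(api_endpoint="squid.clam.whelk"),
)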
- options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has - # unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class() - - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError): - client = client_class() - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id="octopus", - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - EndpointServiceClient, - transports.EndpointServiceGrpcTransport, - "grpc", - "true", - ), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - EndpointServiceClient, - transports.EndpointServiceGrpcTransport, - "grpc", - "false", - ), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - EndpointServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceClient), -) -@mock.patch.object( - EndpointServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(EndpointServiceAsyncClient), -) -@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_endpoint_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): - # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default - # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. - - # Check the case client_cert_source is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) - - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT - - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case ADC client cert is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
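# In brief, the endpoint/client-cert matrix these cases walk through:
#   GOOGLE_API_USE_MTLS_ENDPOINT="never"  -> DEFAULT_ENDPOINT, always
#   GOOGLE_API_USE_MTLS_ENDPOINT="always" -> DEFAULT_MTLS_ENDPOINT, always
#   GOOGLE_API_USE_MTLS_ENDPOINT="auto"   -> DEFAULT_MTLS_ENDPOINT only when
#       GOOGLE_API_USE_CLIENT_CERTIFICATE="true" and a client certificate is
#       available (from client_options or from ADC); otherwise
#       DEFAULT_ENDPOINT with no client cert.
#   Any other GOOGLE_API_USE_MTLS_ENDPOINT value raises MutualTLSChannelError;
#   any other GOOGLE_API_USE_CLIENT_CERTIFICATE value raises ValueError.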
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_endpoint_service_client_client_options_scopes( - client_class, transport_class, transport_name -): - # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=["1", "2"], - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), - ( - EndpointServiceAsyncClient, - transports.EndpointServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_endpoint_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): - # Check the case credentials file is provided. 
- options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_endpoint_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__" - ) as grpc_transport: - grpc_transport.return_value = None - client = EndpointServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - grpc_transport.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_create_endpoint( - transport: str = "grpc", request_type=endpoint_service.CreateEndpointRequest -): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.create_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == endpoint_service.CreateEndpointRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_create_endpoint_from_dict(): - test_create_endpoint(request_type=dict) - - -@pytest.mark.asyncio -async def test_create_endpoint_async(transport: str = "grpc_asyncio"): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.CreateEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_endpoint), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.create_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_create_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. 
- request = endpoint_service.CreateEndpointRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_endpoint), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.create_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_create_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.CreateEndpointRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_endpoint), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.create_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_create_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.create_endpoint( - parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - - -def test_create_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_endpoint( - endpoint_service.CreateEndpointRequest(), - parent="parent_value", - endpoint=gca_endpoint.Endpoint(name="name_value"), - ) - - -@pytest.mark.asyncio -async def test_create_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_endpoint), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
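# create_endpoint is a long-running operation: the transport hands back a raw
# operations_pb2.Operation and the client wraps it in an api-core future,
# hence the isinstance(response, future.Future) assertions above. A sketch of
# an unmocked call, with hypothetical resource names:
#
#     lro = client.create_endpoint(
#         parent="projects/p/locations/us-central1",
#         endpoint=gca_endpoint.Endpoint(display_name="my-endpoint"),
#     )
#     created = lro.result()  # blocks until the server completes the operation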
- call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.create_endpoint( - parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - - -@pytest.mark.asyncio -async def test_create_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.create_endpoint( - endpoint_service.CreateEndpointRequest(), - parent="parent_value", - endpoint=gca_endpoint.Endpoint(name="name_value"), - ) - - -def test_get_endpoint( - transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest -): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - - response = client.get_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == endpoint_service.GetEndpointRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, endpoint.Endpoint) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.description == "description_value" - - assert response.etag == "etag_value" - - -def test_get_endpoint_from_dict(): - test_get_endpoint(request_type=dict) - - -@pytest.mark.asyncio -async def test_get_endpoint_async(transport: str = "grpc_asyncio"): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.GetEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_endpoint), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - ) - - response = await client.get_endpoint(request) - - # Establish that the underlying gRPC stub method was called. 
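# grpc_helpers_async.FakeUnaryUnaryCall, used throughout these async tests, is
# the api-core test helper that wraps a plain message in an awaitable so the
# async client can run against a synchronous mock. A minimal sketch of that
# behavior (Python 3.7+ for asyncio.run):
import asyncio

from google.api_core import grpc_helpers_async
from google.cloud.aiplatform_v1beta1.types import endpoint


async def _unwrap():
    fake_call = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint(name="n"))
    return await fake_call  # resolves to the wrapped Endpoint message


assert asyncio.run(_unwrap()).name == "n"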
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, endpoint.Endpoint) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.description == "description_value" - - assert response.etag == "etag_value" - - -def test_get_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.GetEndpointRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call: - call.return_value = endpoint.Endpoint() - - client.get_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_get_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.GetEndpointRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_endpoint), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) - - await client.get_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = endpoint.Endpoint() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_endpoint(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_get_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_endpoint( - endpoint_service.GetEndpointRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_get_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._client._transport.get_endpoint), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = endpoint.Endpoint() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.get_endpoint(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_get_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.get_endpoint( - endpoint_service.GetEndpointRequest(), name="name_value", - ) - - -def test_list_endpoints( - transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest -): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = endpoint_service.ListEndpointsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_endpoints(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == endpoint_service.ListEndpointsRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListEndpointsPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_endpoints_from_dict(): - test_list_endpoints(request_type=dict) - - -@pytest.mark.asyncio -async def test_list_endpoints_async(transport: str = "grpc_asyncio"): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.ListEndpointsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_endpoints), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse( - next_page_token="next_page_token_value", - ) - ) - - response = await client.list_endpoints(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListEndpointsAsyncPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_endpoints_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.ListEndpointsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: - call.return_value = endpoint_service.ListEndpointsResponse() - - client.list_endpoints(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_list_endpoints_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.ListEndpointsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_endpoints), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse() - ) - - await client.list_endpoints(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_endpoints_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = endpoint_service.ListEndpointsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_endpoints(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -def test_list_endpoints_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_endpoints( - endpoint_service.ListEndpointsRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_endpoints_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._client._transport.list_endpoints), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = endpoint_service.ListEndpointsResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - endpoint_service.ListEndpointsResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_endpoints(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_endpoints_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.list_endpoints( - endpoint_service.ListEndpointsRequest(), parent="parent_value", - ) - - -def test_list_endpoints_pager(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - endpoint.Endpoint(), - ], - next_page_token="abc", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_endpoints(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all(isinstance(i, endpoint.Endpoint) for i in results) - - -def test_list_endpoints_pages(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - endpoint.Endpoint(), - ], - next_page_token="abc", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], - ), - RuntimeError, - ) - pages = list(client.list_endpoints(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_endpoints_async_pager(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. 
-    with mock.patch.object(
-        type(client._client._transport.list_endpoints),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[
-                    endpoint.Endpoint(),
-                    endpoint.Endpoint(),
-                    endpoint.Endpoint(),
-                ],
-                next_page_token="abc",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[], next_page_token="def",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
-            ),
-            RuntimeError,
-        )
-        async_pager = await client.list_endpoints(request={},)
-        assert async_pager.next_page_token == "abc"
-        responses = []
-        async for response in async_pager:
-            responses.append(response)
-
-        assert len(responses) == 6
-        assert all(isinstance(i, endpoint.Endpoint) for i in responses)
-
-
-@pytest.mark.asyncio
-async def test_list_endpoints_async_pages():
-    client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_endpoints),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[
-                    endpoint.Endpoint(),
-                    endpoint.Endpoint(),
-                    endpoint.Endpoint(),
-                ],
-                next_page_token="abc",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[], next_page_token="def",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
-            ),
-            endpoint_service.ListEndpointsResponse(
-                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
-            ),
-            RuntimeError,
-        )
-        pages = []
-        async for page_ in (await client.list_endpoints(request={})).pages:
-            pages.append(page_)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page_.raw_page.next_page_token == token
-
-
-def test_update_endpoint(
-    transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest
-):
-    client = EndpointServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = request_type()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.update_endpoint), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = gca_endpoint.Endpoint(
-            name="name_value",
-            display_name="display_name_value",
-            description="description_value",
-            etag="etag_value",
-        )
-
-        response = client.update_endpoint(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == endpoint_service.UpdateEndpointRequest()
-
-    # Establish that the response is the type that we expect.
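# --- sketch: the async counterpart asserted in the pager tests above. The
# AsyncPager supports `async for`, and `.pages` yields raw responses the same
# way. Placeholder parent; ADC credentials assumed.
from google.cloud.aiplatform_v1beta1 import (
    EndpointServiceAsyncClient as _AsyncPagerClient,
)

async def _sketch_async_pager():
    client = _AsyncPagerClient()
    parent = "projects/p/locations/us-central1"  # placeholder
    pager = await client.list_endpoints(parent=parent)
    endpoints = [ep async for ep in pager]  # pages fetched lazily
    async for page in (await client.list_endpoints(parent=parent)).pages:
        _ = page.raw_page.next_page_token  # "" on the final page
    return endpoints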
- assert isinstance(response, gca_endpoint.Endpoint) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.description == "description_value" - - assert response.etag == "etag_value" - - -def test_update_endpoint_from_dict(): - test_update_endpoint(request_type=dict) - - -@pytest.mark.asyncio -async def test_update_endpoint_async(transport: str = "grpc_asyncio"): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.UpdateEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_endpoint), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - ) - - response = await client.update_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_endpoint.Endpoint) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.description == "description_value" - - assert response.etag == "etag_value" - - -def test_update_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = "endpoint.name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_endpoint), "__call__") as call: - call.return_value = gca_endpoint.Endpoint() - - client.update_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ - "metadata" - ] - - -@pytest.mark.asyncio -async def test_update_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.UpdateEndpointRequest() - request.endpoint.name = "endpoint.name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_endpoint), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint() - ) - - await client.update_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
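# --- sketch: how the routing header checked in these field-header tests is
# built. Request fields named in the RPC's URI template are mirrored into
# request metadata as "x-goog-request-params" so the frontend can route the
# call; the resource name below is a placeholder.
from google.api_core import gapic_v1 as _gapic_v1

def _sketch_routing_header():
    header = _gapic_v1.routing_header.to_grpc_metadata(
        (("endpoint.name", "projects/p/locations/l/endpoints/e"),)
    )
    # header == ("x-goog-request-params",
    #            "endpoint.name=projects/p/locations/l/endpoints/e")
    return header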
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ - "metadata" - ] - - -def test_update_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_endpoint.Endpoint() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -def test_update_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.update_endpoint( - endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -@pytest.mark.asyncio -async def test_update_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_endpoint), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_endpoint.Endpoint() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_endpoint.Endpoint() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -@pytest.mark.asyncio -async def test_update_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.update_endpoint( - endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -def test_delete_endpoint( - transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest -): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
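# --- sketch: the update_mask used in the flattened update tests above is a
# protobuf FieldMask; only the listed paths are overwritten, and (per the
# usual update-mask semantics) other Endpoint fields are left as-is on the
# server. The path and resource name below are illustrative.
from google.protobuf import field_mask_pb2 as _field_mask
from google.cloud.aiplatform_v1beta1.types import endpoint as _gca_endpoint_types

def _sketch_update_mask():
    updated = _gca_endpoint_types.Endpoint(
        name="projects/p/locations/l/endpoints/e",  # placeholder
        display_name="renamed endpoint",
    )
    mask = _field_mask.FieldMask(paths=["display_name"])
    # client.update_endpoint(endpoint=updated, update_mask=mask)
    return updated, mask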
- request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == endpoint_service.DeleteEndpointRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_endpoint_from_dict(): - test_delete_endpoint(request_type=dict) - - -@pytest.mark.asyncio -async def test_delete_endpoint_async(transport: str = "grpc_asyncio"): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.DeleteEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_endpoint), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.delete_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.DeleteEndpointRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_endpoint), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.delete_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_delete_endpoint_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.DeleteEndpointRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_endpoint), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.delete_endpoint(request) - - # Establish that the underlying gRPC stub method was called. 
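# --- sketch: delete_endpoint is a long-running operation. The transport
# returns a google.longrunning Operation, which the client wraps in a
# future-like object; .result() blocks until the server finishes and raises
# on failure. Placeholder name; ADC credentials assumed.
from google.cloud.aiplatform_v1beta1 import EndpointServiceClient as _LroClient

def _sketch_delete_lro():
    client = _LroClient()
    operation = client.delete_endpoint(
        name="projects/p/locations/l/endpoints/e"  # placeholder
    )
    operation.result(timeout=300)  # delete yields an empty response on success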
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_delete_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.delete_endpoint(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_delete_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_delete_endpoint_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_endpoint), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.delete_endpoint(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_delete_endpoint_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), name="name_value", - ) - - -def test_deploy_model( - transport: str = "grpc", request_type=endpoint_service.DeployModelRequest -): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.deploy_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.deploy_model(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == endpoint_service.DeployModelRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_deploy_model_from_dict(): - test_deploy_model(request_type=dict) - - -@pytest.mark.asyncio -async def test_deploy_model_async(transport: str = "grpc_asyncio"): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.DeployModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.deploy_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.deploy_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_deploy_model_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.DeployModelRequest() - request.endpoint = "endpoint/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.deploy_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.deploy_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_deploy_model_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.DeployModelRequest() - request.endpoint = "endpoint/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.deploy_model), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.deploy_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] - - -def test_deploy_model_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
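# --- sketch: the shape of a deploy_model call. A DeployedModel bundles the
# Model resource name with its serving resources, and traffic_split maps
# deployed-model IDs to traffic percentages; by API convention "0" refers to
# the model being deployed in this request and the percentages are expected
# to sum to 100. All values below are placeholders.
from google.cloud.aiplatform_v1beta1.types import endpoint as _ep_types
from google.cloud.aiplatform_v1beta1.types import machine_resources as _mr

def _sketch_deploy_model_args():
    deployed_model = _ep_types.DeployedModel(
        model="projects/p/locations/l/models/m",  # placeholder
        dedicated_resources=_mr.DedicatedResources(
            machine_spec=_mr.MachineSpec(machine_type="n1-standard-2"),
            min_replica_count=1,
        ),
    )
    traffic_split = {"0": 100}  # route all traffic to the new deployment
    # client.deploy_model(endpoint=..., deployed_model=deployed_model,
    #                     traffic_split=traffic_split)
    return deployed_model, traffic_split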
- with mock.patch.object(type(client._transport.deploy_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.deploy_model( - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].endpoint == "endpoint_value" - - assert args[0].deployed_model == gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ) - - assert args[0].traffic_split == {"key_value": 541} - - -def test_deploy_model_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.deploy_model( - endpoint_service.DeployModelRequest(), - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, - ) - - -@pytest.mark.asyncio -async def test_deploy_model_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.deploy_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.deploy_model( - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].endpoint == "endpoint_value" - - assert args[0].deployed_model == gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ) - - assert args[0].traffic_split == {"key_value": 541} - - -@pytest.mark.asyncio -async def test_deploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - await client.deploy_model( - endpoint_service.DeployModelRequest(), - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, - ) - - -def test_undeploy_model( - transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest -): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.undeploy_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.undeploy_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == endpoint_service.UndeployModelRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_undeploy_model_from_dict(): - test_undeploy_model(request_type=dict) - - -@pytest.mark.asyncio -async def test_undeploy_model_async(transport: str = "grpc_asyncio"): - client = EndpointServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.UndeployModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.undeploy_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.undeploy_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_undeploy_model_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.UndeployModelRequest() - request.endpoint = "endpoint/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.undeploy_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.undeploy_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_undeploy_model_field_headers_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.UndeployModelRequest() - request.endpoint = "endpoint/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.undeploy_model), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.undeploy_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] - - -def test_undeploy_model_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.undeploy_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.undeploy_model( - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].endpoint == "endpoint_value" - - assert args[0].deployed_model_id == "deployed_model_id_value" - - assert args[0].traffic_split == {"key_value": 541} - - -def test_undeploy_model_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.undeploy_model( - endpoint_service.UndeployModelRequest(), - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, - ) - - -@pytest.mark.asyncio -async def test_undeploy_model_flattened_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.undeploy_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.undeploy_model( - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, - ) - - # Establish that the underlying call was made with the expected - # request object values. 
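# --- sketch: the transport-level contract exercised by the tests that
# follow. A pre-built transport can be handed to the client, but combining it
# with credentials, a credentials_file, or scopes raises ValueError, since
# the transport already owns its channel and credentials.
# AnonymousCredentials stands in for real credentials here.
from google.auth import credentials as _ga_credentials
from google.cloud.aiplatform_v1beta1 import EndpointServiceClient as _TClient
from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    transports as _es_transports,
)

def _sketch_custom_transport():
    transport = _es_transports.EndpointServiceGrpcTransport(
        credentials=_ga_credentials.AnonymousCredentials(),
    )
    client = _TClient(transport=transport)  # ok: transport supplies everything
    # _TClient(credentials=_ga_credentials.AnonymousCredentials(),
    #          transport=transport)  # would raise ValueError
    return client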
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].endpoint == "endpoint_value" - - assert args[0].deployed_model_id == "deployed_model_id_value" - - assert args[0].traffic_split == {"key_value": 541} - - -@pytest.mark.asyncio -async def test_undeploy_model_flattened_error_async(): - client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.undeploy_model( - endpoint_service.UndeployModelRequest(), - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.EndpointServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # It is an error to provide a credentials file and a transport instance. - transport = transports.EndpointServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = EndpointServiceClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, - ) - - # It is an error to provide scopes and a transport instance. - transport = transports.EndpointServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = EndpointServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.EndpointServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = EndpointServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_get_channel(): - # A client may be instantiated with a custom transport instance. - transport = transports.EndpointServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - - transport = transports.EndpointServiceGrpcAsyncIOTransport( - credentials=credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) -def test_transport_adc(transport_class): - # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - transport_class() - adc.assert_called_once() - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. 
- client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.EndpointServiceGrpcTransport,) - - -def test_endpoint_service_base_transport_error(): - # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(exceptions.DuplicateCredentialArgs): - transport = transports.EndpointServiceTransport( - credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", - ) - - -def test_endpoint_service_base_transport(): - # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__" - ) as Transport: - Transport.return_value = None - transport = transports.EndpointServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_endpoint", - "get_endpoint", - "list_endpoints", - "update_endpoint", - "delete_endpoint", - "deploy_model", - "undeploy_model", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_endpoint_service_base_transport_with_credentials_file(): - # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" - ) as Transport: - Transport.return_value = None - load_creds.return_value = (credentials.AnonymousCredentials(), None) - transport = transports.EndpointServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", - ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id="octopus", - ) - - -def test_endpoint_service_base_transport_with_adc(): - # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" - ) as Transport: - Transport.return_value = None - adc.return_value = (credentials.AnonymousCredentials(), None) - transport = transports.EndpointServiceTransport() - adc.assert_called_once() - - -def test_endpoint_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - EndpointServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id=None, - ) - - -def test_endpoint_service_transport_auth_adc(): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
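# --- sketch: "ADC" in these tests is google.auth.default(), which resolves
# Application Default Credentials and returns a (credentials, project) pair;
# the transports request the single cloud-platform scope asserted above.
import google.auth as _google_auth

def _sketch_adc():
    creds, project = _google_auth.default(
        scopes=("https://www.googleapis.com/auth/cloud-platform",)
    )
    return creds, project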
- with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id="octopus", - ) - - -def test_endpoint_service_host_no_port(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), - ) - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_endpoint_service_host_with_port(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), - ) - assert client._transport._host == "aiplatform.googleapis.com:8000" - - -def test_endpoint_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - - # Check that channel is used if provided. - transport = transports.EndpointServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - - -def test_endpoint_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") - - # Check that channel is used if provided. - transport = transports.EndpointServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) -def test_endpoint_service_transport_channel_mtls_with_client_cert_source( - transport_class, -): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - cred = credentials.AnonymousCredentials() - with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: - adc.return_value = (cred, None) - transport = transport_class( - host="squid.clam.whelk", - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - adc.assert_called_once() - - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.EndpointServiceGrpcTransport, - transports.EndpointServiceGrpcAsyncIOTransport, - ], -) -def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): - mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - with mock.patch.object( - 
        transport_class, "create_channel", autospec=True
-        ) as grpc_create_channel:
-            mock_grpc_channel = mock.Mock()
-            grpc_create_channel.return_value = mock_grpc_channel
-            mock_cred = mock.Mock()
-
-            with pytest.warns(DeprecationWarning):
-                transport = transport_class(
-                    host="squid.clam.whelk",
-                    credentials=mock_cred,
-                    api_mtls_endpoint="mtls.squid.clam.whelk",
-                    client_cert_source=None,
-                )
-
-            grpc_create_channel.assert_called_once_with(
-                "mtls.squid.clam.whelk:443",
-                credentials=mock_cred,
-                credentials_file=None,
-                scopes=("https://www.googleapis.com/auth/cloud-platform",),
-                ssl_credentials=mock_ssl_cred,
-                quota_project_id=None,
-            )
-            assert transport.grpc_channel == mock_grpc_channel
-
-
-def test_endpoint_service_grpc_lro_client():
-    client = EndpointServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc",
-    )
-    transport = client._transport
-
-    # Ensure that we have an api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
-    assert transport.operations_client is transport.operations_client
-
-
-def test_endpoint_service_grpc_lro_async_client():
-    client = EndpointServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
-    )
-    transport = client._client._transport
-
-    # Ensure that we have an api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
-    assert transport.operations_client is transport.operations_client
-
-
-def test_endpoint_path():
-    project = "squid"
-    location = "clam"
-    endpoint = "whelk"
-
-    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
-        project=project, location=location, endpoint=endpoint,
-    )
-    actual = EndpointServiceClient.endpoint_path(project, location, endpoint)
-    assert expected == actual
-
-
-def test_parse_endpoint_path():
-    expected = {
-        "project": "octopus",
-        "location": "oyster",
-        "endpoint": "nudibranch",
-    }
-    path = EndpointServiceClient.endpoint_path(**expected)
-
-    # Check that the path construction is reversible.
-    actual = EndpointServiceClient.parse_endpoint_path(path)
-    assert expected == actual
-
-
-def test_client_withDEFAULT_CLIENT_INFO():
-    client_info = gapic_v1.client_info.ClientInfo()
-
-    with mock.patch.object(
-        transports.EndpointServiceTransport, "_prep_wrapped_messages"
-    ) as prep:
-        client = EndpointServiceClient(
-            credentials=credentials.AnonymousCredentials(), client_info=client_info,
-        )
-        prep.assert_called_once_with(client_info)
-
-    with mock.patch.object(
-        transports.EndpointServiceTransport, "_prep_wrapped_messages"
-    ) as prep:
-        transport_class = EndpointServiceClient.get_transport_class()
-        transport = transport_class(
-            credentials=credentials.AnonymousCredentials(), client_info=client_info,
-        )
-        prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py
deleted file mode 100644
index ef4bd521bb..0000000000
--- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py
+++ /dev/null
@@ -1,5808 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import os -import mock - -import grpc -from grpc.experimental import aio -import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule - -from google import auth -from google.api_core import client_options -from google.api_core import exceptions -from google.api_core import future -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import operation_async # type: ignore -from google.api_core import operations_v1 -from google.auth import credentials -from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceAsyncClient -from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient -from google.cloud.aiplatform_v1beta1.services.job_service import pagers -from google.cloud.aiplatform_v1beta1.services.job_service import transports -from google.cloud.aiplatform_v1beta1.types import accelerator_type -from google.cloud.aiplatform_v1beta1.types import ( - accelerator_type as gca_accelerator_type, -) -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) -from google.cloud.aiplatform_v1beta1.types import completion_stats -from google.cloud.aiplatform_v1beta1.types import ( - completion_stats as gca_completion_stats, -) -from google.cloud.aiplatform_v1beta1.types import custom_job -from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) -from google.cloud.aiplatform_v1beta1.types import io -from google.cloud.aiplatform_v1beta1.types import job_service -from google.cloud.aiplatform_v1beta1.types import job_state -from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters -from google.cloud.aiplatform_v1beta1.types import ( - manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters, -) -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.cloud.aiplatform_v1beta1.types import study -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import any_pb2 as gp_any # type: ignore -from google.protobuf import duration_pb2 as duration # type: ignore -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from google.rpc import status_pb2 as status # type: ignore -from google.type import money_pb2 as money # type: ignore - - 
-def client_cert_source_callback(): - return b"cert bytes", b"key bytes" - - -# If default endpoint is localhost, then default mtls endpoint will be the same. -# This method modifies the default endpoint so the client can produce a different -# mtls endpoint for endpoint testing purposes. -def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) - - -def test__get_default_mtls_endpoint(): - api_endpoint = "example.googleapis.com" - api_mtls_endpoint = "example.mtls.googleapis.com" - sandbox_endpoint = "example.sandbox.googleapis.com" - sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" - non_googleapi = "api.example.com" - - assert JobServiceClient._get_default_mtls_endpoint(None) is None - assert ( - JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - - -@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient]) -def test_job_service_client_from_service_account_file(client_class): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = client_class.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_job_service_client_get_transport_class(): - transport = JobServiceClient.get_transport_class() - assert transport == transports.JobServiceGrpcTransport - - transport = JobServiceClient.get_transport_class("grpc") - assert transport == transports.JobServiceGrpcTransport - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) -) -@mock.patch.object( - JobServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(JobServiceAsyncClient), -) -def test_job_service_client_client_options( - client_class, transport_class, transport_name -): - # Check that if channel is provided we won't create a new one. - with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) - client = client_class(transport=transport) - gtc.assert_not_called() - - # Check that if channel is provided via str we will create a new one. - with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: - client = client_class(transport=transport_name) - gtc.assert_called() - - # Check the case api_endpoint is provided. 
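# --- sketch: the endpoint rewriting asserted in test__get_default_mtls_endpoint
# above. This standalone helper mirrors the observable behavior (it is not the
# library's implementation): "*.googleapis.com" hosts gain a ".mtls" label,
# an existing ".sandbox" label is preserved, and non-Google hosts pass
# through untouched.
import re as _re

def _sketch_default_mtls_endpoint(api_endpoint):
    if not api_endpoint:
        return api_endpoint
    match = _re.fullmatch(
        r"(?P<name>[^.]+)(\.mtls)?(?P<sandbox>\.sandbox)?\.googleapis\.com",
        api_endpoint,
    )
    if not match:
        return api_endpoint  # e.g. "api.example.com" is left alone
    return "{}.mtls{}.googleapis.com".format(
        match.group("name"), match.group("sandbox") or ""
    )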
- options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has - # unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class() - - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError): - client = client_class() - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id="octopus", - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) -) -@mock.patch.object( - JobServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(JobServiceAsyncClient), -) -@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_job_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): - # This tests the endpoint autoswitch behavior. 
Endpoint is autoswitched to the default - # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. - - # Check the case client_cert_source is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) - - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT - - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case ADC client cert is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case client_cert_source and ADC client cert are not provided. 
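# --- sketch: the decision table the three branches above walk through,
# condensed (simplified; not the library's code).
# GOOGLE_API_USE_CLIENT_CERTIFICATE ("false" by default) gates whether a
# client certificate may be used at all; GOOGLE_API_USE_MTLS_ENDPOINT
# ("auto" by default) picks the endpoint, with "auto" choosing the mTLS
# endpoint only when a certificate is actually available.
import os as _os

def _sketch_pick_endpoint(default_endpoint, mtls_endpoint, cert_available):
    use_cert = _os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true"
    mode = _os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
    if mode == "always":
        return mtls_endpoint
    if mode == "auto" and use_cert and cert_available:
        return mtls_endpoint
    return default_endpoint  # covers "never" and every remaining case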
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_job_service_client_client_options_scopes( - client_class, transport_class, transport_name -): - # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=["1", "2"], - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), - ( - JobServiceAsyncClient, - transports.JobServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_job_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): - # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_job_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__" - ) as grpc_transport: - grpc_transport.return_value = None - client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) - grpc_transport.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_create_custom_job( - transport: str = "grpc", request_type=job_service.CreateCustomJobRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. 
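Taken together, the option-handling tests above pin down a single endpoint-resolution rule. A minimal sketch of that rule, assuming the same import paths this test module already uses (`resolved_host` is just an illustrative helper):

    import os
    from unittest import mock

    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
    from google.cloud.aiplatform_v1beta1.services.job_service import transports

    def resolved_host(env_value):
        # Patch the transport constructor so no channel is built and no ADC
        # lookup happens; the client still computes the host it would use.
        with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": env_value}):
            with mock.patch.object(
                transports.JobServiceGrpcTransport, "__init__", return_value=None
            ) as patched:
                JobServiceClient()
                return patched.call_args[1]["host"]

    assert resolved_host("never") == JobServiceClient.DEFAULT_ENDPOINT
    assert resolved_host("always") == JobServiceClient.DEFAULT_MTLS_ENDPOINT

Any value other than "never", "always", or "auto" raises, which is what the MutualTLSChannelError case above asserts.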
- with mock.patch.object( - type(client._transport.create_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.create_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.CreateCustomJobRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_custom_job.CustomJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_create_custom_job_from_dict(): - test_create_custom_job(request_type=dict) - - -@pytest.mark.asyncio -async def test_create_custom_job_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateCustomJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) - - response = await client.create_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_custom_job.CustomJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_create_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.CreateCustomJobRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_custom_job), "__call__" - ) as call: - call.return_value = gca_custom_job.CustomJob() - - client.create_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_create_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. 
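Every RPC test in this file uses the same interception recipe: the per-method gRPC stub is a callable object hanging off the transport, so patching `__call__` on its type catches the request before it reaches the wire. Condensed, under the same assumptions as the sketch above:

    from unittest import mock

    from google.auth import credentials
    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
    from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job
    from google.cloud.aiplatform_v1beta1.types import job_service

    client = JobServiceClient(credentials=credentials.AnonymousCredentials())

    with mock.patch.object(
        type(client._transport.create_custom_job), "__call__"
    ) as call:
        call.return_value = gca_custom_job.CustomJob(name="name_value")
        response = client.create_custom_job(job_service.CreateCustomJobRequest())

    # Exactly one outbound call; its first positional argument is the request.
    _, args, _ = call.mock_calls[0]
    assert args[0] == job_service.CreateCustomJobRequest()
    assert response.name == "name_value"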
- request = job_service.CreateCustomJobRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_custom_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob() - ) - - await client.create_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_create_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_custom_job.CustomJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.create_custom_job( - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") - - -def test_create_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_custom_job( - job_service.CreateCustomJobRequest(), - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), - ) - - -@pytest.mark.asyncio -async def test_create_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_custom_job.CustomJob() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_custom_job.CustomJob() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.create_custom_job( - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") - - -@pytest.mark.asyncio -async def test_create_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
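The recurring `x-goog-request-params` assertions check the routing header that GAPIC derives from URI-bound request fields; it is built by `google.api_core.gapic_v1.routing_header`, for instance:

    from google.api_core import gapic_v1

    # to_grpc_metadata returns the (key, value) pair sent with the request;
    # path separators stay verbatim (urlencode(..., safe="/")).
    key, value = gapic_v1.routing_header.to_grpc_metadata(
        [("parent", "projects/p/locations/us-central1")]
    )
    assert key == "x-goog-request-params"
    assert value == "parent=projects/p/locations/us-central1"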
- with pytest.raises(ValueError): - await client.create_custom_job( - job_service.CreateCustomJobRequest(), - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), - ) - - -def test_get_custom_job( - transport: str = "grpc", request_type=job_service.GetCustomJobRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.get_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.GetCustomJobRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, custom_job.CustomJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_get_custom_job_from_dict(): - test_get_custom_job(request_type=dict) - - -@pytest.mark.asyncio -async def test_get_custom_job_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetCustomJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) - - response = await client.get_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, custom_job.CustomJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_get_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetCustomJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: - call.return_value = custom_job.CustomJob() - - client.get_custom_job(request) - - # Establish that the underlying gRPC stub method was called. 
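The `*_flattened*` cases encode the GAPIC calling convention: each method accepts either a fully formed request object or the flattened keyword arguments, and mixing the two fails fast on the client side. Roughly, with the same mocked-stub setup as above:

    from unittest import mock

    import pytest

    from google.auth import credentials
    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
    from google.cloud.aiplatform_v1beta1.types import custom_job
    from google.cloud.aiplatform_v1beta1.types import job_service

    client = JobServiceClient(credentials=credentials.AnonymousCredentials())

    with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call:
        call.return_value = custom_job.CustomJob()

        client.get_custom_job(job_service.GetCustomJobRequest(name="n"))  # request object
        client.get_custom_job(name="n")  # flattened keyword argument

        # Both at once raises before any RPC is attempted.
        with pytest.raises(ValueError):
            client.get_custom_job(job_service.GetCustomJobRequest(), name="n")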
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_get_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetCustomJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_custom_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob() - ) - - await client.get_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = custom_job.CustomJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_get_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_custom_job( - job_service.GetCustomJobRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_get_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = custom_job.CustomJob() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - custom_job.CustomJob() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.get_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_get_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
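On the async side the only twist is that the patched stub has to return an awaitable; api_core ships `grpc_helpers_async.FakeUnaryUnaryCall` for exactly this purpose. A minimal sketch:

    import asyncio
    from unittest import mock

    from google.api_core import grpc_helpers_async
    from google.auth import credentials
    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceAsyncClient
    from google.cloud.aiplatform_v1beta1.types import custom_job

    async def main():
        client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials())
        with mock.patch.object(
            type(client._client._transport.get_custom_job), "__call__"
        ) as call:
            # FakeUnaryUnaryCall is awaitable and resolves to the wrapped message.
            call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
                custom_job.CustomJob(name="name_value")
            )
            response = await client.get_custom_job(name="name_value")
            assert response.name == "name_value"

    asyncio.run(main())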
- with pytest.raises(ValueError): - await client.get_custom_job( - job_service.GetCustomJobRequest(), name="name_value", - ) - - -def test_list_custom_jobs( - transport: str = "grpc", request_type=job_service.ListCustomJobsRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListCustomJobsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_custom_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.ListCustomJobsRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListCustomJobsPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_custom_jobs_from_dict(): - test_list_custom_jobs(request_type=dict) - - -@pytest.mark.asyncio -async def test_list_custom_jobs_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.ListCustomJobsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_custom_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse(next_page_token="next_page_token_value",) - ) - - response = await client.list_custom_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListCustomJobsAsyncPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_custom_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListCustomJobsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - call.return_value = job_service.ListCustomJobsResponse() - - client.list_custom_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
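The `*_from_dict` variants re-run each test with `request_type=dict` because proto-plus request classes, and therefore the client methods, accept a plain mapping wherever a request proto is expected:

    from google.cloud.aiplatform_v1beta1.types import job_service

    request = job_service.GetCustomJobRequest(
        {"name": "projects/p/locations/us-central1/customJobs/123"}
    )
    assert request.name == "projects/p/locations/us-central1/customJobs/123"
    # client.get_custom_job(request={"name": ...}) goes through the same coercion.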
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_list_custom_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListCustomJobsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_custom_jobs), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse() - ) - - await client.list_custom_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_custom_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListCustomJobsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_custom_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -def test_list_custom_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_custom_jobs( - job_service.ListCustomJobsRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_custom_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_custom_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListCustomJobsResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListCustomJobsResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_custom_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_custom_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - await client.list_custom_jobs( - job_service.ListCustomJobsRequest(), parent="parent_value", - ) - - -def test_list_custom_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - custom_job.CustomJob(), - ], - next_page_token="abc", - ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", - ), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_custom_jobs(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all(isinstance(i, custom_job.CustomJob) for i in results) - - -def test_list_custom_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - custom_job.CustomJob(), - ], - next_page_token="abc", - ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", - ), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], - ), - RuntimeError, - ) - pages = list(client.list_custom_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_custom_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_custom_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. 
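What the pager tests are exercising: `list_custom_jobs` returns a `pagers.ListCustomJobsPager`, and iterating it re-invokes the stub whenever iteration crosses a `next_page_token` boundary. Compressed:

    from unittest import mock

    from google.auth import credentials
    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
    from google.cloud.aiplatform_v1beta1.types import custom_job
    from google.cloud.aiplatform_v1beta1.types import job_service

    client = JobServiceClient(credentials=credentials.AnonymousCredentials())

    with mock.patch.object(type(client._transport.list_custom_jobs), "__call__") as call:
        call.side_effect = (
            job_service.ListCustomJobsResponse(
                custom_jobs=[custom_job.CustomJob()], next_page_token="abc"
            ),
            job_service.ListCustomJobsResponse(custom_jobs=[custom_job.CustomJob()]),
        )
        # Crossing the page boundary transparently issues the second request.
        jobs = list(client.list_custom_jobs(request={}))

    assert len(jobs) == 2
    assert call.call_count == 2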
- call.side_effect = ( - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - custom_job.CustomJob(), - ], - next_page_token="abc", - ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", - ), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], - ), - RuntimeError, - ) - async_pager = await client.list_custom_jobs(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) - - assert len(responses) == 6 - assert all(isinstance(i, custom_job.CustomJob) for i in responses) - - -@pytest.mark.asyncio -async def test_list_custom_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_custom_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - custom_job.CustomJob(), - ], - next_page_token="abc", - ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", - ), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], - ), - RuntimeError, - ) - pages = [] - async for page_ in (await client.list_custom_jobs(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -def test_delete_custom_job( - transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.DeleteCustomJobRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_custom_job_from_dict(): - test_delete_custom_job(request_type=dict) - - -@pytest.mark.asyncio -async def test_delete_custom_job_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteCustomJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. 
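The async flavor has the same shape, except the stub is patched with `mock.AsyncMock` (Python 3.8+) and the pager is awaited into existence before being consumed with `async for`:

    import asyncio
    from unittest import mock

    from google.auth import credentials
    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceAsyncClient
    from google.cloud.aiplatform_v1beta1.types import custom_job
    from google.cloud.aiplatform_v1beta1.types import job_service

    async def main():
        client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials())
        with mock.patch.object(
            type(client._client._transport.list_custom_jobs),
            "__call__",
            new_callable=mock.AsyncMock,
        ) as call:
            call.side_effect = (
                job_service.ListCustomJobsResponse(
                    custom_jobs=[custom_job.CustomJob()], next_page_token="abc"
                ),
                job_service.ListCustomJobsResponse(custom_jobs=[custom_job.CustomJob()]),
            )
            pager = await client.list_custom_jobs(request={})
            assert len([job async for job in pager]) == 2

    asyncio.run(main())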
- with mock.patch.object( - type(client._client._transport.delete_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.delete_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.DeleteCustomJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_custom_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.delete_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_delete_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.DeleteCustomJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_custom_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.delete_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_delete_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.delete_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_delete_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
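`delete_custom_job` is a long-running operation: the transport hands back a raw `google.longrunning` `Operation` proto, and the client wraps it in an api_core future (`response.result()` would then poll through the transport's operations client). A sketch:

    from unittest import mock

    from google.api_core import future
    from google.auth import credentials
    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
    from google.longrunning import operations_pb2

    client = JobServiceClient(credentials=credentials.AnonymousCredentials())

    with mock.patch.object(
        type(client._transport.delete_custom_job), "__call__"
    ) as call:
        call.return_value = operations_pb2.Operation(name="operations/spam")
        response = client.delete_custom_job(name="name_value")

    # The raw Operation proto has been wrapped into a polling future.
    assert isinstance(response, future.Future)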
- with pytest.raises(ValueError): - client.delete_custom_job( - job_service.DeleteCustomJobRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_delete_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.delete_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_delete_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.delete_custom_job( - job_service.DeleteCustomJobRequest(), name="name_value", - ) - - -def test_cancel_custom_job( - transport: str = "grpc", request_type=job_service.CancelCustomJobRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.CancelCustomJobRequest() - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_custom_job_from_dict(): - test_cancel_custom_job(request_type=dict) - - -@pytest.mark.asyncio -async def test_cancel_custom_job_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelCustomJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - - response = await client.cancel_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert response is None - - -def test_cancel_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.CancelCustomJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_custom_job), "__call__" - ) as call: - call.return_value = None - - client.cancel_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_cancel_custom_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.CancelCustomJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_custom_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - - await client.cancel_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_cancel_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.cancel_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_cancel_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.cancel_custom_job( - job_service.CancelCustomJobRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_cancel_custom_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
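`cancel_custom_job` maps the service's `Empty` response to `None`, so the only observable effects are the outbound request and the missing return value:

    from unittest import mock

    from google.auth import credentials
    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient

    client = JobServiceClient(credentials=credentials.AnonymousCredentials())

    with mock.patch.object(
        type(client._transport.cancel_custom_job), "__call__"
    ) as call:
        call.return_value = None
        assert client.cancel_custom_job(name="name_value") is None

    # The flattened kwarg landed on the generated request proto.
    _, args, _ = call.mock_calls[0]
    assert args[0].name == "name_value"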
- response = await client.cancel_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_cancel_custom_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.cancel_custom_job( - job_service.CancelCustomJobRequest(), name="name_value", - ) - - -def test_create_data_labeling_job( - transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - - response = client.create_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.CreateDataLabelingJobRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.datasets == ["datasets_value"] - - assert response.labeler_count == 1375 - - assert response.instruction_uri == "instruction_uri_value" - - assert response.inputs_schema_uri == "inputs_schema_uri_value" - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - assert response.labeling_progress == 1810 - - assert response.specialist_pools == ["specialist_pools_value"] - - -def test_create_data_labeling_job_from_dict(): - test_create_data_labeling_job(request_type=dict) - - -@pytest.mark.asyncio -async def test_create_data_labeling_job_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateDataLabelingJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - ) - - response = await client.create_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.datasets == ["datasets_value"] - - assert response.labeler_count == 1375 - - assert response.instruction_uri == "instruction_uri_value" - - assert response.inputs_schema_uri == "inputs_schema_uri_value" - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - assert response.labeling_progress == 1810 - - assert response.specialist_pools == ["specialist_pools_value"] - - -def test_create_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.CreateDataLabelingJobRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_data_labeling_job), "__call__" - ) as call: - call.return_value = gca_data_labeling_job.DataLabelingJob() - - client.create_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_create_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.CreateDataLabelingJobRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_data_labeling_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob() - ) - - await client.create_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_create_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
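The long field-by-field assertions for `DataLabelingJob` just round-trip a proto-plus message; repeated and enum fields compare like ordinary Python lists and enum members:

    from google.cloud.aiplatform_v1beta1.types import data_labeling_job
    from google.cloud.aiplatform_v1beta1.types import job_state

    job = data_labeling_job.DataLabelingJob(
        display_name="display_name_value",
        datasets=["datasets_value"],  # repeated string field
        labeler_count=1375,
        state=job_state.JobState.JOB_STATE_QUEUED,  # enum field
    )
    assert job.datasets == ["datasets_value"]
    assert job.state == job_state.JobState.JOB_STATE_QUEUED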
- with mock.patch.object( - type(client._transport.create_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_data_labeling_job.DataLabelingJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.create_data_labeling_job( - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( - name="name_value" - ) - - -def test_create_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_data_labeling_job( - job_service.CreateDataLabelingJobRequest(), - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), - ) - - -@pytest.mark.asyncio -async def test_create_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_data_labeling_job.DataLabelingJob() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_data_labeling_job.DataLabelingJob() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.create_data_labeling_job( - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( - name="name_value" - ) - - -@pytest.mark.asyncio -async def test_create_data_labeling_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.create_data_labeling_job( - job_service.CreateDataLabelingJobRequest(), - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), - ) - - -def test_get_data_labeling_job( - transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - - response = client.get_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.GetDataLabelingJobRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, data_labeling_job.DataLabelingJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.datasets == ["datasets_value"] - - assert response.labeler_count == 1375 - - assert response.instruction_uri == "instruction_uri_value" - - assert response.inputs_schema_uri == "inputs_schema_uri_value" - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - assert response.labeling_progress == 1810 - - assert response.specialist_pools == ["specialist_pools_value"] - - -def test_get_data_labeling_job_from_dict(): - test_get_data_labeling_job(request_type=dict) - - -@pytest.mark.asyncio -async def test_get_data_labeling_job_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetDataLabelingJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - ) - - response = await client.get_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, data_labeling_job.DataLabelingJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.datasets == ["datasets_value"] - - assert response.labeler_count == 1375 - - assert response.instruction_uri == "instruction_uri_value" - - assert response.inputs_schema_uri == "inputs_schema_uri_value" - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - assert response.labeling_progress == 1810 - - assert response.specialist_pools == ["specialist_pools_value"] - - -def test_get_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. 
- request = job_service.GetDataLabelingJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_data_labeling_job), "__call__" - ) as call: - call.return_value = data_labeling_job.DataLabelingJob() - - client.get_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_get_data_labeling_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetDataLabelingJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_data_labeling_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob() - ) - - await client.get_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = data_labeling_job.DataLabelingJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_data_labeling_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_get_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_get_data_labeling_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = data_labeling_job.DataLabelingJob() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - data_labeling_job.DataLabelingJob() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
-
-
-@pytest.mark.asyncio
-async def test_get_data_labeling_job_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.get_data_labeling_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = data_labeling_job.DataLabelingJob()
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            data_labeling_job.DataLabelingJob()
-        )
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.get_data_labeling_job(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].name == "name_value"
-
-
-@pytest.mark.asyncio
-async def test_get_data_labeling_job_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.get_data_labeling_job(
-            job_service.GetDataLabelingJobRequest(), name="name_value",
-        )
-
-
-def test_list_data_labeling_jobs(
-    transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest
-):
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = request_type()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_data_labeling_jobs), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = job_service.ListDataLabelingJobsResponse(
-            next_page_token="next_page_token_value",
-        )
-
-        response = client.list_data_labeling_jobs(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == job_service.ListDataLabelingJobsRequest()
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, pagers.ListDataLabelingJobsPager)
-
-    assert response.next_page_token == "next_page_token_value"
-
-
-def test_list_data_labeling_jobs_from_dict():
-    test_list_data_labeling_jobs(request_type=dict)
-
-
-@pytest.mark.asyncio
-async def test_list_data_labeling_jobs_async(transport: str = "grpc_asyncio"):
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.ListDataLabelingJobsRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_data_labeling_jobs), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            job_service.ListDataLabelingJobsResponse(
-                next_page_token="next_page_token_value",
-            )
-        )
-
-        response = await client.list_data_labeling_jobs(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, pagers.ListDataLabelingJobsAsyncPager)
-
-    assert response.next_page_token == "next_page_token_value"
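The field-header tests that follow assert on the x-goog-request-params metadata entry, which mirrors routing fields such as name or parent into a header so the backend can route the request. The real header is built by gapic_v1.routing_header.to_grpc_metadata (visible in the pager tests below); a simplified sketch of the encoding:

from urllib.parse import quote

# Simplified sketch of how an x-goog-request-params entry is built; real
# clients use gapic_v1.routing_header.to_grpc_metadata for this.
def to_routing_metadata(params):
    value = "&".join("{}={}".format(k, quote(str(v), safe="/")) for k, v in params)
    return ("x-goog-request-params", value)

assert to_routing_metadata([("name", "name/value")]) == (
    "x-goog-request-params",
    "name=name/value",
)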
-
-
-def test_list_data_labeling_jobs_field_headers():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.ListDataLabelingJobsRequest()
-    request.parent = "parent/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_data_labeling_jobs), "__call__"
-    ) as call:
-        call.return_value = job_service.ListDataLabelingJobsResponse()
-
-        client.list_data_labeling_jobs(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-@pytest.mark.asyncio
-async def test_list_data_labeling_jobs_field_headers_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.ListDataLabelingJobsRequest()
-    request.parent = "parent/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_data_labeling_jobs), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            job_service.ListDataLabelingJobsResponse()
-        )
-
-        await client.list_data_labeling_jobs(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-def test_list_data_labeling_jobs_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_data_labeling_jobs), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = job_service.ListDataLabelingJobsResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.list_data_labeling_jobs(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].parent == "parent_value"
-
-
-def test_list_data_labeling_jobs_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_data_labeling_jobs(
-            job_service.ListDataLabelingJobsRequest(), parent="parent_value",
-        )
-
-
-@pytest.mark.asyncio
-async def test_list_data_labeling_jobs_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_data_labeling_jobs), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = job_service.ListDataLabelingJobsResponse()
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            job_service.ListDataLabelingJobsResponse()
-        )
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.list_data_labeling_jobs(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].parent == "parent_value"
-
-
-@pytest.mark.asyncio
-async def test_list_data_labeling_jobs_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.list_data_labeling_jobs(
-            job_service.ListDataLabelingJobsRequest(), parent="parent_value",
-        )
-
-
-def test_list_data_labeling_jobs_pager():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_data_labeling_jobs), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                ],
-                next_page_token="abc",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[], next_page_token="def",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
-                next_page_token="ghi",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                ],
-            ),
-            RuntimeError,
-        )
-
-        metadata = ()
-        metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
-        )
-        pager = client.list_data_labeling_jobs(request={})
-
-        assert pager._metadata == metadata
-
-        results = [i for i in pager]
-        assert len(results) == 6
-        assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results)
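The pager test above feeds four pages (3 + 0 + 1 + 2 items) through call.side_effect, so iterating the pager yields six items and the trailing RuntimeError is never reached: iteration stops once next_page_token is empty. A minimal sketch of that page-walking behaviour, with illustrative stand-in classes:

# Minimal sketch of the page-walking behaviour asserted above; Page and
# Pager are stand-ins for the generated response and pager types.
class Page:
    def __init__(self, items, next_page_token=""):
        self.items = items
        self.next_page_token = next_page_token

class Pager:
    def __init__(self, fetch, request):
        self._fetch, self._request = fetch, request

    def __iter__(self):
        token = None
        while True:
            page = self._fetch(self._request, page_token=token)
            yield from page.items
            token = page.next_page_token
            if not token:  # an empty token means the listing is exhausted
                return

pages = iter([Page([1, 2, 3], "abc"), Page([], "def"), Page([4], "ghi"), Page([5, 6])])
pager = Pager(lambda request, page_token=None: next(pages), request={})
assert list(pager) == [1, 2, 3, 4, 5, 6]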
-
-
-def test_list_data_labeling_jobs_pages():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_data_labeling_jobs), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                ],
-                next_page_token="abc",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[], next_page_token="def",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
-                next_page_token="ghi",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = list(client.list_data_labeling_jobs(request={}).pages)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page_.raw_page.next_page_token == token
-
-
-@pytest.mark.asyncio
-async def test_list_data_labeling_jobs_async_pager():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_data_labeling_jobs),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                ],
-                next_page_token="abc",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[], next_page_token="def",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
-                next_page_token="ghi",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                ],
-            ),
-            RuntimeError,
-        )
-        async_pager = await client.list_data_labeling_jobs(request={},)
-        assert async_pager.next_page_token == "abc"
-        responses = []
-        async for response in async_pager:
-            responses.append(response)
-
-        assert len(responses) == 6
-        assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in responses)
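The async pager above exposes the same traversal through `async for`. An illustrative sketch of that surface, using pre-fetched pages rather than real RPCs (not the generated pagers module):

import asyncio

# Illustrative sketch: `async for` drains items across pages.
class AsyncPager:
    def __init__(self, pages):
        self._pages = pages

    def __aiter__(self):
        async def _items():
            for page in self._pages:
                for item in page:
                    yield item

        return _items()

async def _drain():
    return [item async for item in AsyncPager([[1, 2, 3], [], [4], [5, 6]])]

assert asyncio.run(_drain()) == [1, 2, 3, 4, 5, 6]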
-
-
-@pytest.mark.asyncio
-async def test_list_data_labeling_jobs_async_pages():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_data_labeling_jobs),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                ],
-                next_page_token="abc",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[], next_page_token="def",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
-                next_page_token="ghi",
-            ),
-            job_service.ListDataLabelingJobsResponse(
-                data_labeling_jobs=[
-                    data_labeling_job.DataLabelingJob(),
-                    data_labeling_job.DataLabelingJob(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = []
-        async for page_ in (await client.list_data_labeling_jobs(request={})).pages:
-            pages.append(page_)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page_.raw_page.next_page_token == token
-
-
-def test_delete_data_labeling_job(
-    transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest
-):
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = request_type()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.delete_data_labeling_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/spam")
-
-        response = client.delete_data_labeling_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == job_service.DeleteDataLabelingJobRequest()
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, future.Future)
-
-
-def test_delete_data_labeling_job_from_dict():
-    test_delete_data_labeling_job(request_type=dict)
-
-
-@pytest.mark.asyncio
-async def test_delete_data_labeling_job_async(transport: str = "grpc_asyncio"):
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.DeleteDataLabelingJobRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.delete_data_labeling_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name="operations/spam")
-        )
-
-        response = await client.delete_data_labeling_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, future.Future)
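The delete tests stub the transport with a raw operations_pb2.Operation yet assert that the client returns a future.Future: the generated client wraps the operation proto in a google.api_core operation object that exposes a concurrent-futures interface and polls until the server-side work finishes. A conceptual stand-in showing only the shape those asserts rely on:

from concurrent import futures

# Conceptual stand-in only, not google.api_core's operation module: it
# shows just the Future shape that `isinstance(response, future.Future)`
# checks in the delete tests.
class FakeOperationFuture(futures.Future):
    def __init__(self, operation_name):
        super().__init__()
        self.operation_name = operation_name

lro = FakeOperationFuture("operations/spam")
lro.set_result(None)  # pretend the deletion finished
assert isinstance(lro, futures.Future) and lro.result() is None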
-
-
-def test_delete_data_labeling_job_field_headers():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.DeleteDataLabelingJobRequest()
-    request.name = "name/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.delete_data_labeling_job), "__call__"
-    ) as call:
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        client.delete_data_labeling_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-@pytest.mark.asyncio
-async def test_delete_data_labeling_job_field_headers_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.DeleteDataLabelingJobRequest()
-    request.name = "name/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.delete_data_labeling_job), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name="operations/op")
-        )
-
-        await client.delete_data_labeling_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-def test_delete_data_labeling_job_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.delete_data_labeling_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.delete_data_labeling_job(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].name == "name_value"
-
-
-def test_delete_data_labeling_job_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.delete_data_labeling_job(
-            job_service.DeleteDataLabelingJobRequest(), name="name_value",
-        )
-
-
-@pytest.mark.asyncio
-async def test_delete_data_labeling_job_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.delete_data_labeling_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name="operations/spam")
-        )
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.delete_data_labeling_job(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].name == "name_value"
-
-
-@pytest.mark.asyncio
-async def test_delete_data_labeling_job_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.delete_data_labeling_job(
-            job_service.DeleteDataLabelingJobRequest(), name="name_value",
-        )
-
-
-def test_cancel_data_labeling_job(
-    transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest
-):
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = request_type()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.cancel_data_labeling_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = None
-
-        response = client.cancel_data_labeling_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == job_service.CancelDataLabelingJobRequest()
-
-    # Establish that the response is the type that we expect.
-    assert response is None
-
-
-def test_cancel_data_labeling_job_from_dict():
-    test_cancel_data_labeling_job(request_type=dict)
-
-
-@pytest.mark.asyncio
-async def test_cancel_data_labeling_job_async(transport: str = "grpc_asyncio"):
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.CancelDataLabelingJobRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.cancel_data_labeling_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
-
-        response = await client.cancel_data_labeling_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert response is None
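Cancel RPCs return google.protobuf.Empty, which the generated Python surface exposes as None; that is why these tests set the mocked return value to None and assert `response is None`. A sketch of the convention with a hypothetical helper:

# Illustrative: an Empty-returning RPC surfaces as None to callers.
# `cancel_job` is a hypothetical helper, not the generated method.
def cancel_job(stub_call, request):
    stub_call(request)  # the wire response is google.protobuf.Empty
    return None

calls = []
assert cancel_job(calls.append, {"name": "job/123"}) is None
assert calls == [{"name": "job/123"}]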
-
-
-def test_cancel_data_labeling_job_field_headers():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.CancelDataLabelingJobRequest()
-    request.name = "name/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.cancel_data_labeling_job), "__call__"
-    ) as call:
-        call.return_value = None
-
-        client.cancel_data_labeling_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-@pytest.mark.asyncio
-async def test_cancel_data_labeling_job_field_headers_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.CancelDataLabelingJobRequest()
-    request.name = "name/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.cancel_data_labeling_job), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
-
-        await client.cancel_data_labeling_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-def test_cancel_data_labeling_job_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.cancel_data_labeling_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = None
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.cancel_data_labeling_job(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].name == "name_value"
-
-
-def test_cancel_data_labeling_job_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.cancel_data_labeling_job(
-            job_service.CancelDataLabelingJobRequest(), name="name_value",
-        )
-
-
-@pytest.mark.asyncio
-async def test_cancel_data_labeling_job_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.cancel_data_labeling_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = None
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.cancel_data_labeling_job(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].name == "name_value"
-
-
-@pytest.mark.asyncio
-async def test_cancel_data_labeling_job_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.cancel_data_labeling_job(
-            job_service.CancelDataLabelingJobRequest(), name="name_value",
-        )
-
-
-def test_create_hyperparameter_tuning_job(
-    transport: str = "grpc",
-    request_type=job_service.CreateHyperparameterTuningJobRequest,
-):
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = request_type()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.create_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob(
-            name="name_value",
-            display_name="display_name_value",
-            max_trial_count=1609,
-            parallel_trial_count=2128,
-            max_failed_trial_count=2317,
-            state=job_state.JobState.JOB_STATE_QUEUED,
-        )
-
-        response = client.create_hyperparameter_tuning_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == job_service.CreateHyperparameterTuningJobRequest()
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob)
-
-    assert response.name == "name_value"
-
-    assert response.display_name == "display_name_value"
-
-    assert response.max_trial_count == 1609
-
-    assert response.parallel_trial_count == 2128
-
-    assert response.max_failed_trial_count == 2317
-
-    assert response.state == job_state.JobState.JOB_STATE_QUEUED
-
-
-def test_create_hyperparameter_tuning_job_from_dict():
-    test_create_hyperparameter_tuning_job(request_type=dict)
-
-
-@pytest.mark.asyncio
-async def test_create_hyperparameter_tuning_job_async(transport: str = "grpc_asyncio"):
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.CreateHyperparameterTuningJobRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.create_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            gca_hyperparameter_tuning_job.HyperparameterTuningJob(
-                name="name_value",
-                display_name="display_name_value",
-                max_trial_count=1609,
-                parallel_trial_count=2128,
-                max_failed_trial_count=2317,
-                state=job_state.JobState.JOB_STATE_QUEUED,
-            )
-        )
-
-        response = await client.create_hyperparameter_tuning_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob)
-
-    assert response.name == "name_value"
-
-    assert response.display_name == "display_name_value"
-
-    assert response.max_trial_count == 1609
-
-    assert response.parallel_trial_count == 2128
-
-    assert response.max_failed_trial_count == 2317
-
-    assert response.state == job_state.JobState.JOB_STATE_QUEUED
-
-
-def test_create_hyperparameter_tuning_job_field_headers():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.CreateHyperparameterTuningJobRequest()
-    request.parent = "parent/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.create_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob()
-
-        client.create_hyperparameter_tuning_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-@pytest.mark.asyncio
-async def test_create_hyperparameter_tuning_job_field_headers_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.CreateHyperparameterTuningJobRequest()
-    request.parent = "parent/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.create_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            gca_hyperparameter_tuning_job.HyperparameterTuningJob()
-        )
-
-        await client.create_hyperparameter_tuning_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-def test_create_hyperparameter_tuning_job_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.create_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.create_hyperparameter_tuning_job(
-            parent="parent_value",
-            hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(
-                name="name_value"
-            ),
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].parent == "parent_value"
-
-        assert args[
-            0
-        ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(
-            name="name_value"
-        )
-
-
-def test_create_hyperparameter_tuning_job_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.create_hyperparameter_tuning_job(
-            job_service.CreateHyperparameterTuningJobRequest(),
-            parent="parent_value",
-            hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(
-                name="name_value"
-            ),
-        )
-
-
-@pytest.mark.asyncio
-async def test_create_hyperparameter_tuning_job_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.create_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob()
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            gca_hyperparameter_tuning_job.HyperparameterTuningJob()
-        )
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.create_hyperparameter_tuning_job(
-            parent="parent_value",
-            hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(
-                name="name_value"
-            ),
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].parent == "parent_value"
-
-        assert args[
-            0
-        ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(
-            name="name_value"
-        )
-
-
-@pytest.mark.asyncio
-async def test_create_hyperparameter_tuning_job_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.create_hyperparameter_tuning_job(
-            job_service.CreateHyperparameterTuningJobRequest(),
-            parent="parent_value",
-            hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(
-                name="name_value"
-            ),
-        )
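Unlike the name/parent cases, create_hyperparameter_tuning_job flattens a message-typed field: the HyperparameterTuningJob passed as a keyword argument is copied into the corresponding request field. A sketch with stand-in classes rather than the proto-plus types:

# Sketch of a message-valued flattened field; both classes are stand-ins.
class HyperparameterTuningJob:
    def __init__(self, name=""):
        self.name = name

class CreateHyperparameterTuningJobRequest:
    def __init__(self, parent="", hyperparameter_tuning_job=None):
        self.parent = parent
        self.hyperparameter_tuning_job = hyperparameter_tuning_job

def create_hyperparameter_tuning_job(*, parent, hyperparameter_tuning_job):
    # The flattened message is placed into the assembled request as-is.
    return CreateHyperparameterTuningJobRequest(
        parent=parent, hyperparameter_tuning_job=hyperparameter_tuning_job
    )

req = create_hyperparameter_tuning_job(
    parent="parent_value",
    hyperparameter_tuning_job=HyperparameterTuningJob(name="name_value"),
)
assert req.hyperparameter_tuning_job.name == "name_value"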
-
-
-def test_get_hyperparameter_tuning_job(
-    transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest
-):
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = request_type()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob(
-            name="name_value",
-            display_name="display_name_value",
-            max_trial_count=1609,
-            parallel_trial_count=2128,
-            max_failed_trial_count=2317,
-            state=job_state.JobState.JOB_STATE_QUEUED,
-        )
-
-        response = client.get_hyperparameter_tuning_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == job_service.GetHyperparameterTuningJobRequest()
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob)
-
-    assert response.name == "name_value"
-
-    assert response.display_name == "display_name_value"
-
-    assert response.max_trial_count == 1609
-
-    assert response.parallel_trial_count == 2128
-
-    assert response.max_failed_trial_count == 2317
-
-    assert response.state == job_state.JobState.JOB_STATE_QUEUED
-
-
-def test_get_hyperparameter_tuning_job_from_dict():
-    test_get_hyperparameter_tuning_job(request_type=dict)
-
-
-@pytest.mark.asyncio
-async def test_get_hyperparameter_tuning_job_async(transport: str = "grpc_asyncio"):
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.GetHyperparameterTuningJobRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.get_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            hyperparameter_tuning_job.HyperparameterTuningJob(
-                name="name_value",
-                display_name="display_name_value",
-                max_trial_count=1609,
-                parallel_trial_count=2128,
-                max_failed_trial_count=2317,
-                state=job_state.JobState.JOB_STATE_QUEUED,
-            )
-        )
-
-        response = await client.get_hyperparameter_tuning_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob)
-
-    assert response.name == "name_value"
-
-    assert response.display_name == "display_name_value"
-
-    assert response.max_trial_count == 1609
-
-    assert response.parallel_trial_count == 2128
-
-    assert response.max_failed_trial_count == 2317
-
-    assert response.state == job_state.JobState.JOB_STATE_QUEUED
-
-
-def test_get_hyperparameter_tuning_job_field_headers():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.GetHyperparameterTuningJobRequest()
-    request.name = "name/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob()
-
-        client.get_hyperparameter_tuning_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-@pytest.mark.asyncio
-async def test_get_hyperparameter_tuning_job_field_headers_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.GetHyperparameterTuningJobRequest()
-    request.name = "name/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.get_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            hyperparameter_tuning_job.HyperparameterTuningJob()
-        )
-
-        await client.get_hyperparameter_tuning_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-def test_get_hyperparameter_tuning_job_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.get_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.get_hyperparameter_tuning_job(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].name == "name_value"
-
-
-def test_get_hyperparameter_tuning_job_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.get_hyperparameter_tuning_job(
-            job_service.GetHyperparameterTuningJobRequest(), name="name_value",
-        )
-
-
-@pytest.mark.asyncio
-async def test_get_hyperparameter_tuning_job_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.get_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob()
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            hyperparameter_tuning_job.HyperparameterTuningJob()
-        )
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.get_hyperparameter_tuning_job(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].name == "name_value"
-
-
-@pytest.mark.asyncio
-async def test_get_hyperparameter_tuning_job_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.get_hyperparameter_tuning_job(
-            job_service.GetHyperparameterTuningJobRequest(), name="name_value",
-        )
-
-
-def test_list_hyperparameter_tuning_jobs(
-    transport: str = "grpc",
-    request_type=job_service.ListHyperparameterTuningJobsRequest,
-):
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = request_type()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_hyperparameter_tuning_jobs), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = job_service.ListHyperparameterTuningJobsResponse(
-            next_page_token="next_page_token_value",
-        )
-
-        response = client.list_hyperparameter_tuning_jobs(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == job_service.ListHyperparameterTuningJobsRequest()
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, pagers.ListHyperparameterTuningJobsPager)
-
-    assert response.next_page_token == "next_page_token_value"
-
-
-def test_list_hyperparameter_tuning_jobs_from_dict():
-    test_list_hyperparameter_tuning_jobs(request_type=dict)
-
-
-@pytest.mark.asyncio
-async def test_list_hyperparameter_tuning_jobs_async(transport: str = "grpc_asyncio"):
-    client = JobServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = job_service.ListHyperparameterTuningJobsRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_hyperparameter_tuning_jobs), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            job_service.ListHyperparameterTuningJobsResponse(
-                next_page_token="next_page_token_value",
-            )
-        )
-
-        response = await client.list_hyperparameter_tuning_jobs(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, pagers.ListHyperparameterTuningJobsAsyncPager)
-
-    assert response.next_page_token == "next_page_token_value"
-
-
-def test_list_hyperparameter_tuning_jobs_field_headers():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.ListHyperparameterTuningJobsRequest()
-    request.parent = "parent/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_hyperparameter_tuning_jobs), "__call__"
-    ) as call:
-        call.return_value = job_service.ListHyperparameterTuningJobsResponse()
-
-        client.list_hyperparameter_tuning_jobs(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-@pytest.mark.asyncio
-async def test_list_hyperparameter_tuning_jobs_field_headers_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = job_service.ListHyperparameterTuningJobsRequest()
-    request.parent = "parent/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_hyperparameter_tuning_jobs), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            job_service.ListHyperparameterTuningJobsResponse()
-        )
-
-        await client.list_hyperparameter_tuning_jobs(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-def test_list_hyperparameter_tuning_jobs_flattened():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_hyperparameter_tuning_jobs), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = job_service.ListHyperparameterTuningJobsResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.list_hyperparameter_tuning_jobs(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].parent == "parent_value"
-
-
-def test_list_hyperparameter_tuning_jobs_flattened_error():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_hyperparameter_tuning_jobs(
-            job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value",
-        )
-
-
-@pytest.mark.asyncio
-async def test_list_hyperparameter_tuning_jobs_flattened_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_hyperparameter_tuning_jobs), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = job_service.ListHyperparameterTuningJobsResponse()
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            job_service.ListHyperparameterTuningJobsResponse()
-        )
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.list_hyperparameter_tuning_jobs(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].parent == "parent_value"
-
-
-@pytest.mark.asyncio
-async def test_list_hyperparameter_tuning_jobs_flattened_error_async():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.list_hyperparameter_tuning_jobs(
-            job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value",
-        )
-
-
-def test_list_hyperparameter_tuning_jobs_pager():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_hyperparameter_tuning_jobs), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-                next_page_token="abc",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[], next_page_token="def",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-                next_page_token="ghi",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-            ),
-            RuntimeError,
-        )
-
-        metadata = ()
-        metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
-        )
-        pager = client.list_hyperparameter_tuning_jobs(request={})
-
-        assert pager._metadata == metadata
-
-        results = [i for i in pager]
-        assert len(results) == 6
-        assert all(
-            isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob)
-            for i in results
-        )
-
-
-def test_list_hyperparameter_tuning_jobs_pages():
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_hyperparameter_tuning_jobs), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-                next_page_token="abc",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[], next_page_token="def",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-                next_page_token="ghi",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page_.raw_page.next_page_token == token
-
-
-@pytest.mark.asyncio
-async def test_list_hyperparameter_tuning_jobs_async_pager():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_hyperparameter_tuning_jobs),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-                next_page_token="abc",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[], next_page_token="def",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-                next_page_token="ghi",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-            ),
-            RuntimeError,
-        )
-        async_pager = await client.list_hyperparameter_tuning_jobs(request={},)
-        assert async_pager.next_page_token == "abc"
-        responses = []
-        async for response in async_pager:
-            responses.append(response)
-
-        assert len(responses) == 6
-        assert all(
-            isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob)
-            for i in responses
-        )
-
-
-@pytest.mark.asyncio
-async def test_list_hyperparameter_tuning_jobs_async_pages():
-    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.list_hyperparameter_tuning_jobs),
-        "__call__",
-        new_callable=mock.AsyncMock,
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-                next_page_token="abc",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[], next_page_token="def",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-                next_page_token="ghi",
-            ),
-            job_service.ListHyperparameterTuningJobsResponse(
-                hyperparameter_tuning_jobs=[
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                    hyperparameter_tuning_job.HyperparameterTuningJob(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = []
-        async for page_ in (
-            await client.list_hyperparameter_tuning_jobs(request={})
-        ).pages:
-            pages.append(page_)
-        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page_.raw_page.next_page_token == token
-
-
-def test_delete_hyperparameter_tuning_job(
-    transport: str = "grpc",
-    request_type=job_service.DeleteHyperparameterTuningJobRequest,
-):
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = request_type()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.delete_hyperparameter_tuning_job), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/spam")
-
-        response = client.delete_hyperparameter_tuning_job(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == job_service.DeleteHyperparameterTuningJobRequest()
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, future.Future)
-
-
-def test_delete_hyperparameter_tuning_job_from_dict():
-    test_delete_hyperparameter_tuning_job(request_type=dict)
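Each *_from_dict test simply re-runs its companion test with request_type=dict: proto-plus request messages can be constructed from a plain mapping, so the client accepts a dict anywhere it accepts a request object. A sketch of that coercion with stand-ins (not the proto-plus machinery):

# `coerce_request` and the class below are illustrative stand-ins.
class DeleteHyperparameterTuningJobRequest:
    def __init__(self, name=""):
        self.name = name

def coerce_request(request, message_cls):
    # A plain dict is converted into the request message type.
    if isinstance(request, dict):
        request = message_cls(**request)
    return request

req = coerce_request({"name": "job/123"}, DeleteHyperparameterTuningJobRequest)
assert isinstance(req, DeleteHyperparameterTuningJobRequest) and req.name == "job/123"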
- assert isinstance(response, future.Future) - - -def test_delete_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.delete_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_delete_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.DeleteHyperparameterTuningJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.delete_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_delete_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.delete_hyperparameter_tuning_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_delete_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_delete_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
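# How the "x-goog-request-params" entries asserted in the field-header
# tests above are built; a sketch using api-core's routing_header helper
# (the same helper these tests use for pager metadata further down):
from google.api_core.gapic_v1 import routing_header

md = routing_header.to_grpc_metadata([("name", "name/value")])
assert md == ("x-goog-request-params", "name=name/value")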
- with mock.patch.object( - type(client._client._transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.delete_hyperparameter_tuning_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_delete_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", - ) - - -def test_cancel_hyperparameter_tuning_job( - transport: str = "grpc", - request_type=job_service.CancelHyperparameterTuningJobRequest, -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.CancelHyperparameterTuningJobRequest() - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_hyperparameter_tuning_job_from_dict(): - test_cancel_hyperparameter_tuning_job(request_type=dict) - - -@pytest.mark.asyncio -async def test_cancel_hyperparameter_tuning_job_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelHyperparameterTuningJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - - response = await client.cancel_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
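# Note on the assertion just below: Cancel* RPCs return google.protobuf.Empty
# on the wire, which the generated client surfaces as None; hence the stub's
# return_value of None (wrapped in FakeUnaryUnaryCall for the async variant).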
- assert response is None - - -def test_cancel_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.CancelHyperparameterTuningJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = None - - client.cancel_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_cancel_hyperparameter_tuning_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.CancelHyperparameterTuningJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - - await client.cancel_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_cancel_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.cancel_hyperparameter_tuning_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_cancel_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_cancel_hyperparameter_tuning_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
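# The *_flattened_error tests above all pin down one rule: a request object
# and flattened keyword fields are mutually exclusive. A minimal sketch of
# such a guard (illustrative name only, not the generated client's internals):
def _reject_mixed_args(request, **flattened):
    if request is not None and any(v is not None for v in flattened.values()):
        raise ValueError("Cannot pass both a request object and flattened fields.")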
- with mock.patch.object( - type(client._client._transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.cancel_hyperparameter_tuning_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_cancel_hyperparameter_tuning_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), name="name_value", - ) - - -def test_create_batch_prediction_job( - transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.create_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.CreateBatchPredictionJobRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.model == "model_value" - - assert response.generate_explanation is True - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_create_batch_prediction_job_from_dict(): - test_create_batch_prediction_job(request_type=dict) - - -@pytest.mark.asyncio -async def test_create_batch_prediction_job_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
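# Why the wrapper on the next line: methods on the async transport return
# awaitable call objects rather than bare responses, so the async tests wrap
# each response in grpc_helpers_async.FakeUnaryUnaryCall, whose awaited
# result is the wrapped proto.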
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) - - response = await client.create_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.model == "model_value" - - assert response.generate_explanation is True - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_create_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.CreateBatchPredictionJobRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_batch_prediction_job), "__call__" - ) as call: - call.return_value = gca_batch_prediction_job.BatchPredictionJob() - - client.create_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_create_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.CreateBatchPredictionJobRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_batch_prediction_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_batch_prediction_job.BatchPredictionJob() - ) - - await client.create_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_create_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_batch_prediction_job.BatchPredictionJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
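# Flattened keyword arguments are copied field by field into a fresh request
# proto before the RPC is issued, which is why the assertions below inspect
# args[0].parent and args[0].batch_prediction_job rather than raw kwargs.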
- client.create_batch_prediction_job( - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[ - 0 - ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ) - - -def test_create_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_batch_prediction_job( - job_service.CreateBatchPredictionJobRequest(), - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), - ) - - -@pytest.mark.asyncio -async def test_create_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_batch_prediction_job.BatchPredictionJob() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_batch_prediction_job.BatchPredictionJob() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.create_batch_prediction_job( - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[ - 0 - ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ) - - -@pytest.mark.asyncio -async def test_create_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.create_batch_prediction_job( - job_service.CreateBatchPredictionJobRequest(), - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), - ) - - -def test_get_batch_prediction_job( - transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
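# Aside on the equality assertions in these tests: proto-plus messages
# compare by field value, so the empty request the test sends matches the
# empty request the client constructs internally, and populated responses
# can be verified field by field.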
- call.return_value = batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.get_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.GetBatchPredictionJobRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, batch_prediction_job.BatchPredictionJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.model == "model_value" - - assert response.generate_explanation is True - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_get_batch_prediction_job_from_dict(): - test_get_batch_prediction_job(request_type=dict) - - -@pytest.mark.asyncio -async def test_get_batch_prediction_job_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - ) - - response = await client.get_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, batch_prediction_job.BatchPredictionJob) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.model == "model_value" - - assert response.generate_explanation is True - - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_get_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetBatchPredictionJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_batch_prediction_job), "__call__" - ) as call: - call.return_value = batch_prediction_job.BatchPredictionJob() - - client.get_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_get_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetBatchPredictionJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_batch_prediction_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - batch_prediction_job.BatchPredictionJob() - ) - - await client.get_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = batch_prediction_job.BatchPredictionJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_get_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_get_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = batch_prediction_job.BatchPredictionJob() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - batch_prediction_job.BatchPredictionJob() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.get_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_get_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - await client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), name="name_value", - ) - - -def test_list_batch_prediction_jobs( - transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListBatchPredictionJobsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_batch_prediction_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.ListBatchPredictionJobsRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListBatchPredictionJobsPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_batch_prediction_jobs_from_dict(): - test_list_batch_prediction_jobs(request_type=dict) - - -@pytest.mark.asyncio -async def test_list_batch_prediction_jobs_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.ListBatchPredictionJobsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListBatchPredictionJobsResponse( - next_page_token="next_page_token_value", - ) - ) - - response = await client.list_batch_prediction_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListBatchPredictionJobsAsyncPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_batch_prediction_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListBatchPredictionJobsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - call.return_value = job_service.ListBatchPredictionJobsResponse() - - client.list_batch_prediction_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_list_batch_prediction_jobs_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListBatchPredictionJobsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListBatchPredictionJobsResponse() - ) - - await client.list_batch_prediction_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_batch_prediction_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListBatchPredictionJobsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_batch_prediction_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -def test_list_batch_prediction_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_batch_prediction_jobs_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListBatchPredictionJobsResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - job_service.ListBatchPredictionJobsResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_batch_prediction_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_batch_prediction_jobs_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), parent="parent_value", - ) - - -def test_list_batch_prediction_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token="abc", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_batch_prediction_jobs(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all( - isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results - ) - - -def test_list_batch_prediction_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token="abc", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - ), - RuntimeError, - ) - pages = list(client.list_batch_prediction_jobs(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_batch_prediction_jobs_async_pager(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. 
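# A self-contained sketch of the `async for` contract that the async pager
# tests below exercise (pure asyncio, no gRPC involved):
import asyncio

async def fake_pager():
    for item in ("a", "b", "c"):
        yield item

async def _demo():
    return [item async for item in fake_pager()]

assert asyncio.run(_demo()) == ["a", "b", "c"]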
- with mock.patch.object( - type(client._client._transport.list_batch_prediction_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token="abc", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - ), - RuntimeError, - ) - async_pager = await client.list_batch_prediction_jobs(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) - - assert len(responses) == 6 - assert all( - isinstance(i, batch_prediction_job.BatchPredictionJob) for i in responses - ) - - -@pytest.mark.asyncio -async def test_list_batch_prediction_jobs_async_pages(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_batch_prediction_jobs), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token="abc", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - ), - RuntimeError, - ) - pages = [] - async for page_ in (await client.list_batch_prediction_jobs(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -def test_delete_batch_prediction_job( - transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.DeleteBatchPredictionJobRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_batch_prediction_job_from_dict(): - test_delete_batch_prediction_job(request_type=dict) - - -@pytest.mark.asyncio -async def test_delete_batch_prediction_job_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.delete_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.DeleteBatchPredictionJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_batch_prediction_job), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.delete_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_delete_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.DeleteBatchPredictionJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_batch_prediction_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.delete_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_delete_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.delete_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_delete_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_delete_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.delete_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_delete_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), name="name_value", - ) - - -def test_cancel_batch_prediction_job( - transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest -): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == job_service.CancelBatchPredictionJobRequest() - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_batch_prediction_job_from_dict(): - test_cancel_batch_prediction_job(request_type=dict) - - -@pytest.mark.asyncio -async def test_cancel_batch_prediction_job_async(transport: str = "grpc_asyncio"): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - - response = await client.cancel_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.CancelBatchPredictionJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_batch_prediction_job), "__call__" - ) as call: - call.return_value = None - - client.cancel_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_cancel_batch_prediction_job_field_headers_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.CancelBatchPredictionJobRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_batch_prediction_job), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - - await client.cancel_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_cancel_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._transport.cancel_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.cancel_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_cancel_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_cancel_batch_prediction_job_flattened_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.cancel_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.cancel_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_cancel_batch_prediction_job_flattened_error_async(): - client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), name="name_value", - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.JobServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # It is an error to provide a credentials file and a transport instance. - transport = transports.JobServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = JobServiceClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, - ) - - # It is an error to provide scopes and a transport instance. - transport = transports.JobServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = JobServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. 
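# Rationale for test_credentials_transport_error above: a pre-built transport
# already carries its own credentials and channel, so also passing
# credentials, a credentials file, or scopes to the client would be
# ambiguous, and the client raises ValueError rather than guess which wins.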
- transport = transports.JobServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = JobServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_get_channel(): - # A client may be instantiated with a custom transport instance. - transport = transports.JobServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - - transport = transports.JobServiceGrpcAsyncIOTransport( - credentials=credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - - -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], -) -def test_transport_adc(transport_class): - # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - transport_class() - adc.assert_called_once() - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.JobServiceGrpcTransport,) - - -def test_job_service_base_transport_error(): - # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(exceptions.DuplicateCredentialArgs): - transport = transports.JobServiceTransport( - credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", - ) - - -def test_job_service_base_transport(): - # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__" - ) as Transport: - Transport.return_value = None - transport = transports.JobServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. 
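# A sketch of the abstract-transport shape being verified here (an
# illustrative stand-in, not the generated JobServiceTransport itself):
class _AbstractTransport:
    def create_custom_job(self, request):
        raise NotImplementedError()

    @property
    def operations_client(self):
        raise NotImplementedError()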
- methods = ( - "create_custom_job", - "get_custom_job", - "list_custom_jobs", - "delete_custom_job", - "cancel_custom_job", - "create_data_labeling_job", - "get_data_labeling_job", - "list_data_labeling_jobs", - "delete_data_labeling_job", - "cancel_data_labeling_job", - "create_hyperparameter_tuning_job", - "get_hyperparameter_tuning_job", - "list_hyperparameter_tuning_jobs", - "delete_hyperparameter_tuning_job", - "cancel_hyperparameter_tuning_job", - "create_batch_prediction_job", - "get_batch_prediction_job", - "list_batch_prediction_jobs", - "delete_batch_prediction_job", - "cancel_batch_prediction_job", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_job_service_base_transport_with_credentials_file(): - # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" - ) as Transport: - Transport.return_value = None - load_creds.return_value = (credentials.AnonymousCredentials(), None) - transport = transports.JobServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", - ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id="octopus", - ) - - -def test_job_service_base_transport_with_adc(): - # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" - ) as Transport: - Transport.return_value = None - adc.return_value = (credentials.AnonymousCredentials(), None) - transport = transports.JobServiceTransport() - adc.assert_called_once() - - -def test_job_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - JobServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id=None, - ) - - -def test_job_service_transport_auth_adc(): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
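# Application Default Credentials in one runnable sketch (mocked so it needs
# no real environment; google.auth.default returns a (credentials, project)
# pair, which is exactly what these tests patch):
from unittest import mock
import google.auth

with mock.patch.object(google.auth, "default") as adc:
    adc.return_value = (object(), "demo-project")
    creds, project = google.auth.default()
assert project == "demo-project"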
- with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - transports.JobServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id="octopus", - ) - - -def test_job_service_host_no_port(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), - ) - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_job_service_host_with_port(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), - ) - assert client._transport._host == "aiplatform.googleapis.com:8000" - - -def test_job_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - - # Check that channel is used if provided. - transport = transports.JobServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - - -def test_job_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") - - # Check that channel is used if provided. - transport = transports.JobServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - - -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], -) -def test_job_service_transport_channel_mtls_with_client_cert_source(transport_class): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - cred = credentials.AnonymousCredentials() - with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: - adc.return_value = (cred, None) - transport = transport_class( - host="squid.clam.whelk", - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - adc.assert_called_once() - - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -@pytest.mark.parametrize( - "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], -) -def test_job_service_transport_channel_mtls_with_adc(transport_class): - mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: - mock_grpc_channel = 
mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - mock_cred = mock.Mock() - - with pytest.warns(DeprecationWarning): - transport = transport_class( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=None, - ) - - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -def test_job_service_grpc_lro_client(): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", - ) - transport = client._transport - - # Ensure that we have an api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) - - # Ensure that subsequent calls to the property return the exact same object. - assert transport.operations_client is transport.operations_client - - -def test_job_service_grpc_lro_async_client(): - client = JobServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", - ) - transport = client._client._transport - - # Ensure that we have an api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) - - # Ensure that subsequent calls to the property return the exact same object. - assert transport.operations_client is transport.operations_client - - -def test_batch_prediction_job_path(): - project = "squid" - location = "clam" - batch_prediction_job = "whelk" - - expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( - project=project, location=location, batch_prediction_job=batch_prediction_job, - ) - actual = JobServiceClient.batch_prediction_job_path( - project, location, batch_prediction_job - ) - assert expected == actual - - -def test_parse_batch_prediction_job_path(): - expected = { - "project": "octopus", - "location": "oyster", - "batch_prediction_job": "nudibranch", - } - path = JobServiceClient.batch_prediction_job_path(**expected) - - # Check that the path construction is reversible. - actual = JobServiceClient.parse_batch_prediction_job_path(path) - assert expected == actual - - -def test_custom_job_path(): - project = "squid" - location = "clam" - custom_job = "whelk" - - expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( - project=project, location=location, custom_job=custom_job, - ) - actual = JobServiceClient.custom_job_path(project, location, custom_job) - assert expected == actual - - -def test_parse_custom_job_path(): - expected = { - "project": "octopus", - "location": "oyster", - "custom_job": "nudibranch", - } - path = JobServiceClient.custom_job_path(**expected) - - # Check that the path construction is reversible.
- actual = JobServiceClient.parse_custom_job_path(path) - assert expected == actual - - -def test_data_labeling_job_path(): - project = "squid" - location = "clam" - data_labeling_job = "whelk" - - expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( - project=project, location=location, data_labeling_job=data_labeling_job, - ) - actual = JobServiceClient.data_labeling_job_path( - project, location, data_labeling_job - ) - assert expected == actual - - -def test_parse_data_labeling_job_path(): - expected = { - "project": "octopus", - "location": "oyster", - "data_labeling_job": "nudibranch", - } - path = JobServiceClient.data_labeling_job_path(**expected) - - # Check that the path construction is reversible. - actual = JobServiceClient.parse_data_labeling_job_path(path) - assert expected == actual - - -def test_hyperparameter_tuning_job_path(): - project = "squid" - location = "clam" - hyperparameter_tuning_job = "whelk" - - expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( - project=project, - location=location, - hyperparameter_tuning_job=hyperparameter_tuning_job, - ) - actual = JobServiceClient.hyperparameter_tuning_job_path( - project, location, hyperparameter_tuning_job - ) - assert expected == actual - - -def test_parse_hyperparameter_tuning_job_path(): - expected = { - "project": "octopus", - "location": "oyster", - "hyperparameter_tuning_job": "nudibranch", - } - path = JobServiceClient.hyperparameter_tuning_job_path(**expected) - - # Check that the path construction is reversible. - actual = JobServiceClient.parse_hyperparameter_tuning_job_path(path) - assert expected == actual - - -def test_client_withDEFAULT_CLIENT_INFO(): - client_info = gapic_v1.client_info.ClientInfo() - - with mock.patch.object( - transports.JobServiceTransport, "_prep_wrapped_messages" - ) as prep: - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, - ) - prep.assert_called_once_with(client_info) - - with mock.patch.object( - transports.JobServiceTransport, "_prep_wrapped_messages" - ) as prep: - transport_class = JobServiceClient.get_transport_class() - transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, - ) - prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py deleted file mode 100644 index 897ecb0c59..0000000000 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py +++ /dev/null @@ -1,3415 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import os -import mock - -import grpc -from grpc.experimental import aio -import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule - -from google import auth -from google.api_core import client_options -from google.api_core import exceptions -from google.api_core import future -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import operation_async # type: ignore -from google.api_core import operations_v1 -from google.auth import credentials -from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.model_service import ( - ModelServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient -from google.cloud.aiplatform_v1beta1.services.model_service import pagers -from google.cloud.aiplatform_v1beta1.services.model_service import transports -from google.cloud.aiplatform_v1beta1.types import deployed_model_ref -from google.cloud.aiplatform_v1beta1.types import env_var -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import explanation_metadata -from google.cloud.aiplatform_v1beta1.types import io -from google.cloud.aiplatform_v1beta1.types import model -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import model_evaluation -from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice -from google.cloud.aiplatform_v1beta1.types import model_service -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - - -def client_cert_source_callback(): - return b"cert bytes", b"key bytes" - - -# If default endpoint is localhost, then default mtls endpoint will be the same. -# This method modifies the default endpoint so the client can produce a different -# mtls endpoint for endpoint testing purposes. 
-def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) - - -def test__get_default_mtls_endpoint(): - api_endpoint = "example.googleapis.com" - api_mtls_endpoint = "example.mtls.googleapis.com" - sandbox_endpoint = "example.sandbox.googleapis.com" - sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" - non_googleapi = "api.example.com" - - assert ModelServiceClient._get_default_mtls_endpoint(None) is None - assert ( - ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - - -@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient]) -def test_model_service_client_from_service_account_file(client_class): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = client_class.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_model_service_client_get_transport_class(): - transport = ModelServiceClient.get_transport_class() - assert transport == transports.ModelServiceGrpcTransport - - transport = ModelServiceClient.get_transport_class("grpc") - assert transport == transports.ModelServiceGrpcTransport - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) -) -@mock.patch.object( - ModelServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(ModelServiceAsyncClient), -) -def test_model_service_client_client_options( - client_class, transport_class, transport_name -): - # Check that if channel is provided we won't create a new one. - with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) - client = client_class(transport=transport) - gtc.assert_not_called() - - # Check that if channel is provided via str we will create a new one. - with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: - client = client_class(transport=transport_name) - gtc.assert_called() - - # Check the case api_endpoint is provided. 
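- # An api_endpoint supplied through client_options should be passed to the transport host unchanged.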
- options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has - # unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class() - - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError): - client = client_class() - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id="octopus", - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) -) -@mock.patch.object( - ModelServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(ModelServiceAsyncClient), -) -@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_model_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): - # This tests the endpoint autoswitch behavior. 
Endpoint is autoswitched to the default - # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. - - # Check the case client_cert_source is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) - - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT - - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case ADC client cert is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_model_service_client_client_options_scopes( - client_class, transport_class, transport_name -): - # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=["1", "2"], - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - ( - ModelServiceAsyncClient, - transports.ModelServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_model_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): - # Check the case credentials file is provided. - options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_model_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__" - ) as grpc_transport: - grpc_transport.return_value = None - client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) - grpc_transport.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_upload_model( - transport: str = "grpc", request_type=model_service.UploadModelRequest -): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. 
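- # Patching the transport's stub method intercepts the RPC at the client boundary, so no network traffic is generated.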
- with mock.patch.object(type(client._transport.upload_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.upload_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == model_service.UploadModelRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_upload_model_from_dict(): - test_upload_model(request_type=dict) - - -@pytest.mark.asyncio -async def test_upload_model_async(transport: str = "grpc_asyncio"): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.UploadModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.upload_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.upload_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_upload_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.UploadModelRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.upload_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.upload_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_upload_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.UploadModelRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.upload_model), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.upload_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_upload_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.upload_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.upload_model( - parent="parent_value", model=gca_model.Model(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].model == gca_model.Model(name="name_value") - - -def test_upload_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.upload_model( - model_service.UploadModelRequest(), - parent="parent_value", - model=gca_model.Model(name="name_value"), - ) - - -@pytest.mark.asyncio -async def test_upload_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.upload_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.upload_model( - parent="parent_value", model=gca_model.Model(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].model == gca_model.Model(name="name_value") - - -@pytest.mark.asyncio -async def test_upload_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.upload_model( - model_service.UploadModelRequest(), - parent="parent_value", - model=gca_model.Model(name="name_value"), - ) - - -def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_model), "__call__") as call: - # Designate an appropriate return value for the call. 
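- # Populate every scalar and repeated field so the assertions below can verify the response round-trips intact.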
- call.return_value = model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=["supported_input_storage_formats_value"], - supported_output_storage_formats=["supported_output_storage_formats_value"], - etag="etag_value", - ) - - response = client.get_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == model_service.GetModelRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, model.Model) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.description == "description_value" - - assert response.metadata_schema_uri == "metadata_schema_uri_value" - - assert response.training_pipeline == "training_pipeline_value" - - assert response.artifact_uri == "artifact_uri_value" - - assert response.supported_deployment_resources_types == [ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] - - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] - - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] - - assert response.etag == "etag_value" - - -def test_get_model_from_dict(): - test_get_model(request_type=dict) - - -@pytest.mark.asyncio -async def test_get_model_async(transport: str = "grpc_asyncio"): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.GetModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=[ - "supported_input_storage_formats_value" - ], - supported_output_storage_formats=[ - "supported_output_storage_formats_value" - ], - etag="etag_value", - ) - ) - - response = await client.get_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, model.Model) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.description == "description_value" - - assert response.metadata_schema_uri == "metadata_schema_uri_value" - - assert response.training_pipeline == "training_pipeline_value" - - assert response.artifact_uri == "artifact_uri_value" - - assert response.supported_deployment_resources_types == [ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] - - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] - - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] - - assert response.etag == "etag_value" - - -def test_get_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.GetModelRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_model), "__call__") as call: - call.return_value = model.Model() - - client.get_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_get_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.GetModelRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_model), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) - - await client.get_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = model.Model() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_model(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_get_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - client.get_model( - model_service.GetModelRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_get_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model.Model() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.get_model(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_get_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.get_model( - model_service.GetModelRequest(), name="name_value", - ) - - -def test_list_models( - transport: str = "grpc", request_type=model_service.ListModelsRequest -): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_models), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_models(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == model_service.ListModelsRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelsPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_models_from_dict(): - test_list_models(request_type=dict) - - -@pytest.mark.asyncio -async def test_list_models_async(transport: str = "grpc_asyncio"): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.ListModelsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_models), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse(next_page_token="next_page_token_value",) - ) - - response = await client.list_models(request) - - # Establish that the underlying gRPC stub method was called. 
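- # The mock records the call made through the async client; a non-empty mock_calls list proves the stub was invoked.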
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelsAsyncPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_models_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.ListModelsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_models), "__call__") as call: - call.return_value = model_service.ListModelsResponse() - - client.list_models(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_list_models_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.ListModelsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_models), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse() - ) - - await client.list_models(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_models_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_models), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_models(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -def test_list_models_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_models( - model_service.ListModelsRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_models_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
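- # FakeUnaryUnaryCall wraps the response so the mocked stub can be awaited like a real grpc.aio unary-unary call.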
- with mock.patch.object( - type(client._client._transport.list_models), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelsResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelsResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_models(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_models_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.list_models( - model_service.ListModelsRequest(), parent="parent_value", - ) - - -def test_list_models_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_models), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", - ), - model_service.ListModelsResponse(models=[], next_page_token="def",), - model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", - ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_models(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all(isinstance(i, model.Model) for i in results) - - -def test_list_models_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_models), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", - ), - model_service.ListModelsResponse(models=[], next_page_token="def",), - model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", - ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), - RuntimeError, - ) - pages = list(client.list_models(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_models_async_pager(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_models), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages.
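- # side_effect yields one response per page request; the trailing RuntimeError fails the test if the pager fetches past the last page.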
- call.side_effect = ( - model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", - ), - model_service.ListModelsResponse(models=[], next_page_token="def",), - model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", - ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), - RuntimeError, - ) - async_pager = await client.list_models(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) - - assert len(responses) == 6 - assert all(isinstance(i, model.Model) for i in responses) - - -@pytest.mark.asyncio -async def test_list_models_async_pages(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_models), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", - ), - model_service.ListModelsResponse(models=[], next_page_token="def",), - model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", - ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), - RuntimeError, - ) - pages = [] - async for page_ in (await client.list_models(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -def test_update_model( - transport: str = "grpc", request_type=model_service.UpdateModelRequest -): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=["supported_input_storage_formats_value"], - supported_output_storage_formats=["supported_output_storage_formats_value"], - etag="etag_value", - ) - - response = client.update_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == model_service.UpdateModelRequest() - - # Establish that the response is the type that we expect.
- assert isinstance(response, gca_model.Model) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.description == "description_value" - - assert response.metadata_schema_uri == "metadata_schema_uri_value" - - assert response.training_pipeline == "training_pipeline_value" - - assert response.artifact_uri == "artifact_uri_value" - - assert response.supported_deployment_resources_types == [ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] - - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] - - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] - - assert response.etag == "etag_value" - - -def test_update_model_from_dict(): - test_update_model(request_type=dict) - - -@pytest.mark.asyncio -async def test_update_model_async(transport: str = "grpc_asyncio"): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.UpdateModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=[ - "supported_input_storage_formats_value" - ], - supported_output_storage_formats=[ - "supported_output_storage_formats_value" - ], - etag="etag_value", - ) - ) - - response = await client.update_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_model.Model) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.description == "description_value" - - assert response.metadata_schema_uri == "metadata_schema_uri_value" - - assert response.training_pipeline == "training_pipeline_value" - - assert response.artifact_uri == "artifact_uri_value" - - assert response.supported_deployment_resources_types == [ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] - - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] - - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] - - assert response.etag == "etag_value" - - -def test_update_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.UpdateModelRequest() - request.model.name = "model.name/value" - - # Mock the actual call within the gRPC stub, and fake the request. 
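- # The routing header is derived from request.model.name, hence the nested 'model.name' key asserted below.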
- with mock.patch.object(type(client._transport.update_model), "__call__") as call: - call.return_value = gca_model.Model() - - client.update_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_update_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.UpdateModelRequest() - request.model.name = "model.name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_model), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model()) - - await client.update_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"] - - -def test_update_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_model.Model() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.update_model( - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].model == gca_model.Model(name="name_value") - - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -def test_update_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.update_model( - model_service.UpdateModelRequest(), - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -@pytest.mark.asyncio -async def test_update_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_model.Model() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model()) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
- response = await client.update_model( - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].model == gca_model.Model(name="name_value") - - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -@pytest.mark.asyncio -async def test_update_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.update_model( - model_service.UpdateModelRequest(), - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -def test_delete_model( - transport: str = "grpc", request_type=model_service.DeleteModelRequest -): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == model_service.DeleteModelRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_model_from_dict(): - test_delete_model(request_type=dict) - - -@pytest.mark.asyncio -async def test_delete_model_async(transport: str = "grpc_asyncio"): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.DeleteModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.delete_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.DeleteModelRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client._transport.delete_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.delete_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_delete_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.DeleteModelRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_model), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.delete_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_delete_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.delete_model(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_delete_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_model( - model_service.DeleteModelRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_delete_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.delete_model(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_delete_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.delete_model( - model_service.DeleteModelRequest(), name="name_value", - ) - - -def test_export_model( - transport: str = "grpc", request_type=model_service.ExportModelRequest -): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.export_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == model_service.ExportModelRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_export_model_from_dict(): - test_export_model(request_type=dict) - - -@pytest.mark.asyncio -async def test_export_model_async(transport: str = "grpc_asyncio"): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.ExportModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.export_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.export_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_export_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.ExportModelRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_model), "__call__") as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.export_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_export_model_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.ExportModelRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.export_model), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.export_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_export_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.export_model( - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ) - - -def test_export_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.export_model( - model_service.ExportModelRequest(), - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), - ) - - -@pytest.mark.asyncio -async def test_export_model_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.export_model), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.export_model( - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ) - - -@pytest.mark.asyncio -async def test_export_model_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.export_model( - model_service.ExportModelRequest(), - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), - ) - - -def test_get_model_evaluation( - transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest -): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_model_evaluation), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_evaluation.ModelEvaluation( - name="name_value", - metrics_schema_uri="metrics_schema_uri_value", - slice_dimensions=["slice_dimensions_value"], - ) - - response = client.get_model_evaluation(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == model_service.GetModelEvaluationRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, model_evaluation.ModelEvaluation) - - assert response.name == "name_value" - - assert response.metrics_schema_uri == "metrics_schema_uri_value" - - assert response.slice_dimensions == ["slice_dimensions_value"] - - -def test_get_model_evaluation_from_dict(): - test_get_model_evaluation(request_type=dict) - - -@pytest.mark.asyncio -async def test_get_model_evaluation_async(transport: str = "grpc_asyncio"): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.GetModelEvaluationRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_model_evaluation), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation.ModelEvaluation( - name="name_value", - metrics_schema_uri="metrics_schema_uri_value", - slice_dimensions=["slice_dimensions_value"], - ) - ) - - response = await client.get_model_evaluation(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, model_evaluation.ModelEvaluation) - - assert response.name == "name_value" - - assert response.metrics_schema_uri == "metrics_schema_uri_value" - - assert response.slice_dimensions == ["slice_dimensions_value"] - - -def test_get_model_evaluation_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.GetModelEvaluationRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_model_evaluation), "__call__" - ) as call: - call.return_value = model_evaluation.ModelEvaluation() - - client.get_model_evaluation(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_get_model_evaluation_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.GetModelEvaluationRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_model_evaluation), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation.ModelEvaluation() - ) - - await client.get_model_evaluation(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_model_evaluation_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_model_evaluation), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_evaluation.ModelEvaluation() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_model_evaluation(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_get_model_evaluation_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_get_model_evaluation_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._client._transport.get_model_evaluation), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_evaluation.ModelEvaluation() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation.ModelEvaluation() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.get_model_evaluation(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_get_model_evaluation_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), name="name_value", - ) - - -def test_list_model_evaluations( - transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest -): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluations), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelEvaluationsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_model_evaluations(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == model_service.ListModelEvaluationsRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelEvaluationsPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_model_evaluations_from_dict(): - test_list_model_evaluations(request_type=dict) - - -@pytest.mark.asyncio -async def test_list_model_evaluations_async(transport: str = "grpc_asyncio"): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.ListModelEvaluationsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_model_evaluations), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationsResponse( - next_page_token="next_page_token_value", - ) - ) - - response = await client.list_model_evaluations(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListModelEvaluationsAsyncPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_model_evaluations_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.ListModelEvaluationsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluations), "__call__" - ) as call: - call.return_value = model_service.ListModelEvaluationsResponse() - - client.list_model_evaluations(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_list_model_evaluations_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.ListModelEvaluationsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_model_evaluations), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationsResponse() - ) - - await client.list_model_evaluations(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_model_evaluations_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluations), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelEvaluationsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_model_evaluations(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -def test_list_model_evaluations_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_model_evaluations_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._client._transport.list_model_evaluations), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelEvaluationsResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationsResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_model_evaluations(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_model_evaluations_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), parent="parent_value", - ) - - -def test_list_model_evaluations_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluations), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_model_evaluations(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results) - - -def test_list_model_evaluations_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluations), "__call__" - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - ), - RuntimeError, - ) - pages = list(client.list_model_evaluations(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_model_evaluations_async_pager(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_model_evaluations), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - ), - RuntimeError, - ) - async_pager = await client.list_model_evaluations(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) - - assert len(responses) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in responses) - - -@pytest.mark.asyncio -async def test_list_model_evaluations_async_pages(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_model_evaluations), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - ), - RuntimeError, - ) - pages = [] - async for page_ in (await client.list_model_evaluations(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -def test_get_model_evaluation_slice( - transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest -): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_model_evaluation_slice), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_evaluation_slice.ModelEvaluationSlice( - name="name_value", metrics_schema_uri="metrics_schema_uri_value", - ) - - response = client.get_model_evaluation_slice(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == model_service.GetModelEvaluationSliceRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - - assert response.name == "name_value" - - assert response.metrics_schema_uri == "metrics_schema_uri_value" - - -def test_get_model_evaluation_slice_from_dict(): - test_get_model_evaluation_slice(request_type=dict) - - -@pytest.mark.asyncio -async def test_get_model_evaluation_slice_async(transport: str = "grpc_asyncio"): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.GetModelEvaluationSliceRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_model_evaluation_slice), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation_slice.ModelEvaluationSlice( - name="name_value", metrics_schema_uri="metrics_schema_uri_value", - ) - ) - - response = await client.get_model_evaluation_slice(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - - assert response.name == "name_value" - - assert response.metrics_schema_uri == "metrics_schema_uri_value" - - -def test_get_model_evaluation_slice_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.GetModelEvaluationSliceRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_model_evaluation_slice), "__call__" - ) as call: - call.return_value = model_evaluation_slice.ModelEvaluationSlice() - - client.get_model_evaluation_slice(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_get_model_evaluation_slice_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.GetModelEvaluationSliceRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_model_evaluation_slice), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation_slice.ModelEvaluationSlice() - ) - - await client.get_model_evaluation_slice(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_model_evaluation_slice_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_model_evaluation_slice), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_evaluation_slice.ModelEvaluationSlice() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_model_evaluation_slice(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_get_model_evaluation_slice_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_get_model_evaluation_slice_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_model_evaluation_slice), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_evaluation_slice.ModelEvaluationSlice() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_evaluation_slice.ModelEvaluationSlice() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.get_model_evaluation_slice(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_get_model_evaluation_slice_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), name="name_value", - ) - - -def test_list_model_evaluation_slices( - transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest -): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluation_slices), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelEvaluationSlicesResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_model_evaluation_slices(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == model_service.ListModelEvaluationSlicesRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelEvaluationSlicesPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_model_evaluation_slices_from_dict(): - test_list_model_evaluation_slices(request_type=dict) - - -@pytest.mark.asyncio -async def test_list_model_evaluation_slices_async(transport: str = "grpc_asyncio"): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.ListModelEvaluationSlicesRequest() - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._client._transport.list_model_evaluation_slices), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationSlicesResponse( - next_page_token="next_page_token_value", - ) - ) - - response = await client.list_model_evaluation_slices(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelEvaluationSlicesAsyncPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_model_evaluation_slices_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.ListModelEvaluationSlicesRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluation_slices), "__call__" - ) as call: - call.return_value = model_service.ListModelEvaluationSlicesResponse() - - client.list_model_evaluation_slices(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_list_model_evaluation_slices_field_headers_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.ListModelEvaluationSlicesRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_model_evaluation_slices), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationSlicesResponse() - ) - - await client.list_model_evaluation_slices(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_model_evaluation_slices_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluation_slices), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelEvaluationSlicesResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_model_evaluation_slices(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -def test_list_model_evaluation_slices_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_model_evaluation_slices_flattened_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_model_evaluation_slices), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelEvaluationSlicesResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - model_service.ListModelEvaluationSlicesResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_model_evaluation_slices(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_model_evaluation_slices_flattened_error_async(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", - ) - - -def test_list_model_evaluation_slices_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluation_slices), "__call__" - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="ghi", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_model_evaluation_slices(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all( - isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results - ) - - -def test_list_model_evaluation_slices_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluation_slices), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="ghi", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - ), - RuntimeError, - ) - pages = list(client.list_model_evaluation_slices(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_model_evaluation_slices_async_pager(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_model_evaluation_slices), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="ghi", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - ), - RuntimeError, - ) - async_pager = await client.list_model_evaluation_slices(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) - - assert len(responses) == 6 - assert all( - isinstance(i, model_evaluation_slice.ModelEvaluationSlice) - for i in responses - ) - - -@pytest.mark.asyncio -async def test_list_model_evaluation_slices_async_pages(): - client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_model_evaluation_slices), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="ghi", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - ), - RuntimeError, - ) - pages = [] - async for page_ in ( - await client.list_model_evaluation_slices(request={}) - ).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.ModelServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # It is an error to provide a credentials file and a transport instance. - transport = transports.ModelServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = ModelServiceClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, - ) - - # It is an error to provide scopes and a transport instance. 
- transport = transports.ModelServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = ModelServiceClient( - client_options={"scopes": ["1", "2"]}, transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.ModelServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = ModelServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_get_channel(): - # A client may be instantiated with a custom transport instance. - transport = transports.ModelServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - - transport = transports.ModelServiceGrpcAsyncIOTransport( - credentials=credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - - -@pytest.mark.parametrize( - "transport_class", - [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], -) -def test_transport_adc(transport_class): - # Test default credentials are used if not provided. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - transport_class() - adc.assert_called_once() - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.ModelServiceGrpcTransport,) - - -def test_model_service_base_transport_error(): - # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(exceptions.DuplicateCredentialArgs): - transport = transports.ModelServiceTransport( - credentials=credentials.AnonymousCredentials(), - credentials_file="credentials.json", - ) - - -def test_model_service_base_transport(): - # Instantiate the base transport. - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__" - ) as Transport: - Transport.return_value = None - transport = transports.ModelServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. 
- methods = ( - "upload_model", - "get_model", - "list_models", - "update_model", - "delete_model", - "export_model", - "get_model_evaluation", - "list_model_evaluations", - "get_model_evaluation_slice", - "list_model_evaluation_slices", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_model_service_base_transport_with_credentials_file(): - # Instantiate the base transport with a credentials file - with mock.patch.object( - auth, "load_credentials_from_file" - ) as load_creds, mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" - ) as Transport: - Transport.return_value = None - load_creds.return_value = (credentials.AnonymousCredentials(), None) - transport = transports.ModelServiceTransport( - credentials_file="credentials.json", quota_project_id="octopus", - ) - load_creds.assert_called_once_with( - "credentials.json", - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id="octopus", - ) - - -def test_model_service_base_transport_with_adc(): - # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(auth, "default") as adc, mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" - ) as Transport: - Transport.return_value = None - adc.return_value = (credentials.AnonymousCredentials(), None) - transport = transports.ModelServiceTransport() - adc.assert_called_once() - - -def test_model_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - ModelServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id=None, - ) - - -def test_model_service_transport_auth_adc(): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - transports.ModelServiceGrpcTransport( - host="squid.clam.whelk", quota_project_id="octopus" - ) - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",), - quota_project_id="octopus", - ) - - -def test_model_service_host_no_port(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), - ) - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_model_service_host_with_port(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), - ) - assert client._transport._host == "aiplatform.googleapis.com:8000" - - -def test_model_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - - # Check that channel is used if provided. 
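# None of the ADC tests above touch real credentials: google.auth.default is
# patched to hand back AnonymousCredentials, and the assertions only concern
# the scopes and quota_project_id it was asked for. A distilled sketch, with
# _build_with_adc standing in for constructing a transport without explicit
# credentials (the scope string is the one asserted above):
from unittest import mock
from google import auth
from google.auth import credentials as ga_credentials

def _build_with_adc():
    # the real transport calls google.auth.default() when given no credentials
    creds, _ = auth.default(
        scopes=("https://www.googleapis.com/auth/cloud-platform",)
    )
    return creds

with mock.patch.object(auth, "default") as adc:
    adc.return_value = (ga_credentials.AnonymousCredentials(), None)
    _build_with_adc()
    adc.assert_called_once_with(
        scopes=("https://www.googleapis.com/auth/cloud-platform",)
    )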
- transport = transports.ModelServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - - -def test_model_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") - - # Check that channel is used if provided. - transport = transports.ModelServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - - -@pytest.mark.parametrize( - "transport_class", - [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], -) -def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - cred = credentials.AnonymousCredentials() - with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: - adc.return_value = (cred, None) - transport = transport_class( - host="squid.clam.whelk", - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - adc.assert_called_once() - - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -@pytest.mark.parametrize( - "transport_class", - [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], -) -def test_model_service_transport_channel_mtls_with_adc(transport_class): - mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - mock_cred = mock.Mock() - - with pytest.warns(DeprecationWarning): - transport = transport_class( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=None, - ) - - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -def test_model_service_grpc_lro_client(): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", - ) - transport = client._transport - - # Ensure that we have an api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) - - # Ensure that subsequent calls to the property send the exact same object.
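# The identity check below passes because the property memoizes its client on
# first access; a sketch of that caching shape (the real property builds an
# operations_v1.OperationsClient instead of a bare object):
class _SketchTransport:
    _operations_client = None

    @property
    def operations_client(self):
        if self._operations_client is None:
            self._operations_client = object()  # built once, then reused
        return self._operations_client

_t = _SketchTransport()
assert _t.operations_client is _t.operations_client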
- assert transport.operations_client is transport.operations_client - - -def test_model_service_grpc_lro_async_client(): - client = ModelServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", - ) - transport = client._client._transport - - # Ensure that we have an api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) - - # Ensure that subsequent calls to the property send the exact same object. - assert transport.operations_client is transport.operations_client - - -def test_model_path(): - project = "squid" - location = "clam" - model = "whelk" - - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) - actual = ModelServiceClient.model_path(project, location, model) - assert expected == actual - - -def test_parse_model_path(): - expected = { - "project": "octopus", - "location": "oyster", - "model": "nudibranch", - } - path = ModelServiceClient.model_path(**expected) - - # Check that the path construction is reversible. - actual = ModelServiceClient.parse_model_path(path) - assert expected == actual - - -def test_client_withDEFAULT_CLIENT_INFO(): - client_info = gapic_v1.client_info.ClientInfo() - - with mock.patch.object( - transports.ModelServiceTransport, "_prep_wrapped_messages" - ) as prep: - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, - ) - prep.assert_called_once_with(client_info) - - with mock.patch.object( - transports.ModelServiceTransport, "_prep_wrapped_messages" - ) as prep: - transport_class = ModelServiceClient.get_transport_class() - transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, - ) - prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py deleted file mode 100644 index c03017b2cc..0000000000 --- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py +++ /dev/null @@ -1,2065 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
-# - -import os -import mock - -import grpc -from grpc.experimental import aio -import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule - -from google import auth -from google.api_core import client_options -from google.api_core import exceptions -from google.api_core import future -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import operation_async # type: ignore -from google.api_core import operations_v1 -from google.auth import credentials -from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( - PipelineServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( - PipelineServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers -from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports -from google.cloud.aiplatform_v1beta1.types import deployed_model_ref -from google.cloud.aiplatform_v1beta1.types import env_var -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import explanation_metadata -from google.cloud.aiplatform_v1beta1.types import io -from google.cloud.aiplatform_v1beta1.types import model -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.cloud.aiplatform_v1beta1.types import pipeline_service -from google.cloud.aiplatform_v1beta1.types import pipeline_state -from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import any_pb2 as gp_any # type: ignore -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from google.rpc import status_pb2 as status # type: ignore - - -def client_cert_source_callback(): - return b"cert bytes", b"key bytes" - - -# If default endpoint is localhost, then default mtls endpoint will be the same. -# This method modifies the default endpoint so the client can produce a different -# mtls endpoint for endpoint testing purposes. 
-def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) - - -def test__get_default_mtls_endpoint(): - api_endpoint = "example.googleapis.com" - api_mtls_endpoint = "example.mtls.googleapis.com" - sandbox_endpoint = "example.sandbox.googleapis.com" - sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" - non_googleapi = "api.example.com" - - assert PipelineServiceClient._get_default_mtls_endpoint(None) is None - assert ( - PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - ) - - -@pytest.mark.parametrize( - "client_class", [PipelineServiceClient, PipelineServiceAsyncClient] -) -def test_pipeline_service_client_from_service_account_file(client_class): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = client_class.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_pipeline_service_client_get_transport_class(): - transport = PipelineServiceClient.get_transport_class() - assert transport == transports.PipelineServiceGrpcTransport - - transport = PipelineServiceClient.get_transport_class("grpc") - assert transport == transports.PipelineServiceGrpcTransport - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - PipelineServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceClient), -) -@mock.patch.object( - PipelineServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceAsyncClient), -) -def test_pipeline_service_client_client_options( - client_class, transport_class, transport_name -): - # Check that if channel is provided we won't create a new one. - with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) - client = client_class(transport=transport) - gtc.assert_not_called() - - # Check that if channel is provided via str we will create a new one. - with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: - client = client_class(transport=transport_name) - gtc.assert_called() - - # Check the case api_endpoint is provided. 
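# The mapping asserted in test__get_default_mtls_endpoint above amounts to
# inserting "mtls." ahead of (sandbox.)googleapis.com hosts and passing every
# other endpoint through unchanged. A behavioral sketch consistent with those
# assertions (the generated helper uses a regex; this is not its code):
def _sketch_mtls_endpoint(api_endpoint):
    if not api_endpoint:
        return api_endpoint
    if api_endpoint.endswith(".mtls.googleapis.com") or api_endpoint.endswith(
        ".mtls.sandbox.googleapis.com"
    ):
        return api_endpoint  # already an mTLS endpoint
    if api_endpoint.endswith(".sandbox.googleapis.com"):
        return (
            api_endpoint[: -len(".sandbox.googleapis.com")]
            + ".mtls.sandbox.googleapis.com"
        )
    if api_endpoint.endswith(".googleapis.com"):
        return api_endpoint[: -len(".googleapis.com")] + ".mtls.googleapis.com"
    return api_endpoint  # non-Google endpoints pass through

assert _sketch_mtls_endpoint(None) is None
assert _sketch_mtls_endpoint("example.googleapis.com") == "example.mtls.googleapis.com"
assert (
    _sketch_mtls_endpoint("example.sandbox.googleapis.com")
    == "example.mtls.sandbox.googleapis.com"
)
assert _sketch_mtls_endpoint("api.example.com") == "api.example.com"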
- options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has - # unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class() - - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
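# The GOOGLE_API_USE_MTLS_ENDPOINT branches above ("never", "always", "auto",
# unsupported) reduce to a small selection rule; a hedged sketch with both
# endpoints passed in as parameters:
from google.auth.exceptions import MutualTLSChannelError

def _sketch_pick_endpoint(use_mtls_env, cert_available, default_ep, mtls_ep):
    if use_mtls_env == "never":
        return default_ep
    if use_mtls_env == "always":
        return mtls_ep
    if use_mtls_env == "auto":
        return mtls_ep if cert_available else default_ep
    raise MutualTLSChannelError("unsupported GOOGLE_API_USE_MTLS_ENDPOINT value")

assert _sketch_pick_endpoint("never", True, "default", "mtls") == "default"
assert _sketch_pick_endpoint("always", False, "default", "mtls") == "mtls"
assert _sketch_pick_endpoint("auto", True, "default", "mtls") == "mtls"
assert _sketch_pick_endpoint("auto", False, "default", "mtls") == "default"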
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError): - client = client_class() - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id="octopus", - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - PipelineServiceClient, - transports.PipelineServiceGrpcTransport, - "grpc", - "true", - ), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - PipelineServiceClient, - transports.PipelineServiceGrpcTransport, - "grpc", - "false", - ), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - PipelineServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceClient), -) -@mock.patch.object( - PipelineServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(PipelineServiceAsyncClient), -) -@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_pipeline_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): - # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default - # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. - - # Check the case client_cert_source is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) - - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT - - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case ADC client cert is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_pipeline_service_client_client_options_scopes( - client_class, transport_class, transport_name -): - # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=["1", "2"], - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), - ( - PipelineServiceAsyncClient, - transports.PipelineServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_pipeline_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): - # Check the case credentials file is provided. 
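# Distilled from the mtls_env_auto branches above: a client certificate is
# used only when GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" AND a cert is
# actually available (from client_cert_source or from ADC); any other env
# value is rejected with ValueError. Sketch:
def _sketch_use_client_cert(env_value, cert_available):
    if env_value not in ("true", "false"):
        raise ValueError("unsupported GOOGLE_API_USE_CLIENT_CERTIFICATE value")
    return env_value == "true" and cert_available

assert _sketch_use_client_cert("true", True) is True
assert _sketch_use_client_cert("true", False) is False
assert _sketch_use_client_cert("false", True) is False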
- options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_pipeline_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__" - ) as grpc_transport: - grpc_transport.return_value = None - client = PipelineServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - grpc_transport.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_create_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.CreateTrainingPipelineRequest -): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - - response = client.create_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == pipeline_service.CreateTrainingPipelineRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_training_pipeline.TrainingPipeline) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.training_task_definition == "training_task_definition_value" - - assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED - - -def test_create_training_pipeline_from_dict(): - test_create_training_pipeline(request_type=dict) - - -@pytest.mark.asyncio -async def test_create_training_pipeline_async(transport: str = "grpc_asyncio"): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.CreateTrainingPipelineRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
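# The async variants below cannot hand the proto straight back from a Mock,
# because the client awaits the stub's return value. Wrapping the canned
# response in grpc_helpers_async.FakeUnaryUnaryCall makes it awaitable, like a
# real grpc.aio unary-unary call. Minimal demo of that wrapper:
import asyncio
from unittest import mock
from google.api_core import grpc_helpers_async

async def _demo():
    stub = mock.Mock()
    stub.return_value = grpc_helpers_async.FakeUnaryUnaryCall("payload")
    assert await stub() == "payload"  # awaiting the fake yields the payload

asyncio.run(_demo())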
- call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - ) - - response = await client.create_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_training_pipeline.TrainingPipeline) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.training_task_definition == "training_task_definition_value" - - assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED - - -def test_create_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_training_pipeline), "__call__" - ) as call: - call.return_value = gca_training_pipeline.TrainingPipeline() - - client.create_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_create_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = pipeline_service.CreateTrainingPipelineRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_training_pipeline), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline() - ) - - await client.create_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_create_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_training_pipeline.TrainingPipeline() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
- client.create_training_pipeline( - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( - name="name_value" - ) - - -def test_create_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_training_pipeline( - pipeline_service.CreateTrainingPipelineRequest(), - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), - ) - - -@pytest.mark.asyncio -async def test_create_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_training_pipeline.TrainingPipeline() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - gca_training_pipeline.TrainingPipeline() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.create_training_pipeline( - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( - name="name_value" - ) - - -@pytest.mark.asyncio -async def test_create_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.create_training_pipeline( - pipeline_service.CreateTrainingPipelineRequest(), - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), - ) - - -def test_get_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest -): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - - response = client.get_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == pipeline_service.GetTrainingPipelineRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, training_pipeline.TrainingPipeline) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.training_task_definition == "training_task_definition_value" - - assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED - - -def test_get_training_pipeline_from_dict(): - test_get_training_pipeline(request_type=dict) - - -@pytest.mark.asyncio -async def test_get_training_pipeline_async(transport: str = "grpc_asyncio"): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.GetTrainingPipelineRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - ) - - response = await client.get_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, training_pipeline.TrainingPipeline) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.training_task_definition == "training_task_definition_value" - - assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED - - -def test_get_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = pipeline_service.GetTrainingPipelineRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_training_pipeline), "__call__" - ) as call: - call.return_value = training_pipeline.TrainingPipeline() - - client.get_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
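# The header asserted just below is the request-routing one: the resource name
# from the request is echoed into x-goog-request-params so the backend can
# route by resource. A sketch of constructing that metadata entry (the pager
# tests further down use gapic_v1.routing_header.to_grpc_metadata for this;
# this helper is illustrative only):
def _sketch_routing_metadata(**fields):
    params = "&".join("{}={}".format(k, v) for k, v in fields.items())
    return ("x-goog-request-params", params)

assert _sketch_routing_metadata(name="name/value") == (
    "x-goog-request-params",
    "name=name/value",
)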
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_get_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = pipeline_service.GetTrainingPipelineRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_training_pipeline), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline() - ) - - await client.get_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = training_pipeline.TrainingPipeline() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_training_pipeline(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_get_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_get_training_pipeline_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = training_pipeline.TrainingPipeline() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - training_pipeline.TrainingPipeline() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.get_training_pipeline(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_get_training_pipeline_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
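# The guard exercised below: callers pass either a fully-formed request proto
# or the flattened convenience fields, never both. A sketch of that rule with
# illustrative names (not the generated client code):
import pytest

def _sketch_get(request=None, *, name=None):
    if request is not None and name is not None:
        raise ValueError("pass a request object or flattened fields, not both")
    return request if request is not None else {"name": name}

assert _sketch_get(name="name_value") == {"name": "name_value"}
with pytest.raises(ValueError):
    _sketch_get({"name": "a"}, name="b")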
- with pytest.raises(ValueError): - await client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), name="name_value", - ) - - -def test_list_training_pipelines( - transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest -): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_training_pipelines), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = pipeline_service.ListTrainingPipelinesResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_training_pipelines(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == pipeline_service.ListTrainingPipelinesRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListTrainingPipelinesPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_training_pipelines_from_dict(): - test_list_training_pipelines(request_type=dict) - - -@pytest.mark.asyncio -async def test_list_training_pipelines_async(transport: str = "grpc_asyncio"): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.ListTrainingPipelinesRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_training_pipelines), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - pipeline_service.ListTrainingPipelinesResponse( - next_page_token="next_page_token_value", - ) - ) - - response = await client.list_training_pipelines(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListTrainingPipelinesAsyncPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_training_pipelines_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_training_pipelines), "__call__" - ) as call: - call.return_value = pipeline_service.ListTrainingPipelinesResponse() - - client.list_training_pipelines(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_list_training_pipelines_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = pipeline_service.ListTrainingPipelinesRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_training_pipelines), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - pipeline_service.ListTrainingPipelinesResponse() - ) - - await client.list_training_pipelines(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_training_pipelines_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_training_pipelines), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = pipeline_service.ListTrainingPipelinesResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_training_pipelines(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -def test_list_training_pipelines_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_training_pipelines_flattened_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_training_pipelines), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = pipeline_service.ListTrainingPipelinesResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - pipeline_service.ListTrainingPipelinesResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_training_pipelines(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_training_pipelines_flattened_error_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.list_training_pipelines( - pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", - ) - - -def test_list_training_pipelines_pager(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_training_pipelines), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - ], - next_page_token="abc", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - ], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_training_pipelines(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results) - - -def test_list_training_pipelines_pages(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_training_pipelines), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - ], - next_page_token="abc", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - ], - ), - RuntimeError, - ) - pages = list(client.list_training_pipelines(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_training_pipelines_async_pager(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_training_pipelines), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. 
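# These pager tests assert two views over the same canned data: iterating the
# pager yields the 6 individual pipelines across pages, while .pages yields
# the 4 raw responses, paired with next_page_token values "abc", "def", "ghi",
# and "". Sketch with plain lists standing in for the response protos:
pages = [["a", "b", "c"], [], ["d"], ["e", "f"]]
tokens = ["abc", "def", "ghi", ""]

items = [item for page in pages for item in page]
assert len(items) == 6            # the item-level count asserted above
assert len(pages) == len(tokens)  # one raw page per token, "" ends iteration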
- call.side_effect = ( - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - ], - next_page_token="abc", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - ], - ), - RuntimeError, - ) - async_pager = await client.list_training_pipelines(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) - - assert len(responses) == 6 - assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in responses) - - -@pytest.mark.asyncio -async def test_list_training_pipelines_async_pages(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_training_pipelines), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - ], - next_page_token="abc", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[], next_page_token="def", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[training_pipeline.TrainingPipeline(),], - next_page_token="ghi", - ), - pipeline_service.ListTrainingPipelinesResponse( - training_pipelines=[ - training_pipeline.TrainingPipeline(), - training_pipeline.TrainingPipeline(), - ], - ), - RuntimeError, - ) - pages = [] - async for page_ in (await client.list_training_pipelines(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -def test_delete_training_pipeline( - transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest -): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() - - # Establish that the response is the type that we expect. 
- assert isinstance(response, future.Future) - - -def test_delete_training_pipeline_from_dict(): - test_delete_training_pipeline(request_type=dict) - - -@pytest.mark.asyncio -async def test_delete_training_pipeline_async(transport: str = "grpc_asyncio"): - client = PipelineServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.DeleteTrainingPipelineRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.delete_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_training_pipeline), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.delete_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_delete_training_pipeline_field_headers_async(): - client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = pipeline_service.DeleteTrainingPipelineRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_training_pipeline), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.delete_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_delete_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
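# delete_training_pipeline is long-running: the stub returns a raw
# operations_pb2.Operation and the client wraps it in an api-core future,
# which is why the tests above assert isinstance(response, future.Future)
# instead of inspecting a payload. A loose sketch of that relationship, with
# concurrent.futures standing in for google.api_core.operation.Operation:
from concurrent.futures import Future

lro = Future()                     # stand-in for the operation future
lro.set_result("operations/spam")  # name mirrors the mocked Operation above
assert isinstance(lro, Future)
assert lro.result() == "operations/spam"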
-
-
-def test_delete_training_pipeline_flattened():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.delete_training_pipeline), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.delete_training_pipeline(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].name == "name_value"
-
-
-def test_delete_training_pipeline_flattened_error():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.delete_training_pipeline(
-            pipeline_service.DeleteTrainingPipelineRequest(), name="name_value",
-        )
-
-
-@pytest.mark.asyncio
-async def test_delete_training_pipeline_flattened_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.delete_training_pipeline), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name="operations/spam")
-        )
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.delete_training_pipeline(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].name == "name_value"
-
-
-@pytest.mark.asyncio
-async def test_delete_training_pipeline_flattened_error_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.delete_training_pipeline(
-            pipeline_service.DeleteTrainingPipelineRequest(), name="name_value",
-        )
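The flattened-error tests pin down one contract: a populated request object and flattened keyword arguments are mutually exclusive. A hedged sketch of that guard; the function name and shape are illustrative, not the generated client code:

    def delete_training_pipeline(request=None, name=None):
        # GAPIC-style guard: flattened kwargs may only replace, never augment,
        # an explicit request object.
        if request is not None and name is not None:
            raise ValueError(
                "If the `request` argument is set, then none of the individual "
                "field arguments should be set."
            )
        return request if request is not None else {"name": name}

    assert delete_training_pipeline(name="name_value") == {"name": "name_value"}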
-
-
-def test_cancel_training_pipeline(
-    transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest
-):
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = request_type()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.cancel_training_pipeline), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = None
-
-        response = client.cancel_training_pipeline(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == pipeline_service.CancelTrainingPipelineRequest()
-
-    # Establish that the response is the type that we expect.
-    assert response is None
-
-
-def test_cancel_training_pipeline_from_dict():
-    test_cancel_training_pipeline(request_type=dict)
-
-
-@pytest.mark.asyncio
-async def test_cancel_training_pipeline_async(transport: str = "grpc_asyncio"):
-    client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = pipeline_service.CancelTrainingPipelineRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.cancel_training_pipeline), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
-
-        response = await client.cancel_training_pipeline(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert response is None
-
-
-def test_cancel_training_pipeline_field_headers():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = pipeline_service.CancelTrainingPipelineRequest()
-    request.name = "name/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.cancel_training_pipeline), "__call__"
-    ) as call:
-        call.return_value = None
-
-        client.cancel_training_pipeline(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
-
-
-@pytest.mark.asyncio
-async def test_cancel_training_pipeline_field_headers_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = pipeline_service.CancelTrainingPipelineRequest()
-    request.name = "name/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.cancel_training_pipeline), "__call__"
-    ) as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
-
-        await client.cancel_training_pipeline(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
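Both field-header tests assert that the resource name lands in the `x-goog-request-params` metadata entry sent alongside the RPC. A simplified illustration of how such an entry can be assembled; this helper is an assumption for illustration, not the generated client's routing code:

    from urllib.parse import quote

    def routing_metadata(**fields):
        # Keep "/" unescaped, matching the "name=name/value" shape asserted above.
        params = "&".join(
            "{}={}".format(key, quote(str(value), safe="/"))
            for key, value in fields.items()
        )
        return [("x-goog-request-params", params)]

    assert routing_metadata(name="name/value") == [
        ("x-goog-request-params", "name=name/value")
    ]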
-
-
-def test_cancel_training_pipeline_flattened():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.cancel_training_pipeline), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = None
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.cancel_training_pipeline(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].name == "name_value"
-
-
-def test_cancel_training_pipeline_flattened_error():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.cancel_training_pipeline(
-            pipeline_service.CancelTrainingPipelineRequest(), name="name_value",
-        )
-
-
-@pytest.mark.asyncio
-async def test_cancel_training_pipeline_flattened_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.cancel_training_pipeline), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = None
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.cancel_training_pipeline(name="name_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].name == "name_value"
-
-
-@pytest.mark.asyncio
-async def test_cancel_training_pipeline_flattened_error_async():
-    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.cancel_training_pipeline(
-            pipeline_service.CancelTrainingPipelineRequest(), name="name_value",
-        )
-
-
-def test_credentials_transport_error():
-    # It is an error to provide credentials and a transport instance.
-    transport = transports.PipelineServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = PipelineServiceClient(
-            credentials=credentials.AnonymousCredentials(), transport=transport,
-        )
-
-    # It is an error to provide a credentials file and a transport instance.
-    transport = transports.PipelineServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = PipelineServiceClient(
-            client_options={"credentials_file": "credentials.json"},
-            transport=transport,
-        )
-
-    # It is an error to provide scopes and a transport instance.
-    transport = transports.PipelineServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = PipelineServiceClient(
-            client_options={"scopes": ["1", "2"]}, transport=transport,
-        )
-
-
-def test_transport_instance():
-    # A client may be instantiated with a custom transport instance.
-    transport = transports.PipelineServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    client = PipelineServiceClient(transport=transport)
-    assert client._transport is transport
-
-
-def test_transport_get_channel():
-    # A client may be instantiated with a custom transport instance.
-    transport = transports.PipelineServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    channel = transport.grpc_channel
-    assert channel
-
-    transport = transports.PipelineServiceGrpcAsyncIOTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    channel = transport.grpc_channel
-    assert channel
-
-
-@pytest.mark.parametrize(
-    "transport_class",
-    [
-        transports.PipelineServiceGrpcTransport,
-        transports.PipelineServiceGrpcAsyncIOTransport,
-    ],
-)
-def test_transport_adc(transport_class):
-    # Test default credentials are used if not provided.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        transport_class()
-        adc.assert_called_once()
-
-
-def test_transport_grpc_default():
-    # A client should use the gRPC transport by default.
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-    assert isinstance(client._transport, transports.PipelineServiceGrpcTransport,)
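test_credentials_transport_error enumerates three configurations that must be rejected once a ready-made transport is supplied, since the transport already owns its credentials. A sketch of that validation; the names are illustrative, not the generated constructor:

    def make_client(transport=None, credentials=None, client_options=None):
        options = client_options or {}
        if transport is not None and (
            credentials or "credentials_file" in options or "scopes" in options
        ):
            raise ValueError(
                "A transport instance is mutually exclusive with credentials, "
                "a credentials file, and scopes."
            )
        return transport or "default-grpc-transport"

    assert make_client() == "default-grpc-transport"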
-
-
-def test_pipeline_service_base_transport_error():
-    # Passing both a credentials object and credentials_file should raise an error
-    with pytest.raises(exceptions.DuplicateCredentialArgs):
-        transport = transports.PipelineServiceTransport(
-            credentials=credentials.AnonymousCredentials(),
-            credentials_file="credentials.json",
-        )
-
-
-def test_pipeline_service_base_transport():
-    # Instantiate the base transport.
-    with mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__"
-    ) as Transport:
-        Transport.return_value = None
-        transport = transports.PipelineServiceTransport(
-            credentials=credentials.AnonymousCredentials(),
-        )
-
-    # Every method on the transport should just blindly
-    # raise NotImplementedError.
-    methods = (
-        "create_training_pipeline",
-        "get_training_pipeline",
-        "list_training_pipelines",
-        "delete_training_pipeline",
-        "cancel_training_pipeline",
-    )
-    for method in methods:
-        with pytest.raises(NotImplementedError):
-            getattr(transport, method)(request=object())
-
-    # Additionally, the LRO client (a property) should
-    # also raise NotImplementedError
-    with pytest.raises(NotImplementedError):
-        transport.operations_client
-
-
-def test_pipeline_service_base_transport_with_credentials_file():
-    # Instantiate the base transport with a credentials file
-    with mock.patch.object(
-        auth, "load_credentials_from_file"
-    ) as load_creds, mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages"
-    ) as Transport:
-        Transport.return_value = None
-        load_creds.return_value = (credentials.AnonymousCredentials(), None)
-        transport = transports.PipelineServiceTransport(
-            credentials_file="credentials.json", quota_project_id="octopus",
-        )
-        load_creds.assert_called_once_with(
-            "credentials.json",
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            quota_project_id="octopus",
-        )
-
-
-def test_pipeline_service_base_transport_with_adc():
-    # Test the default credentials are used if credentials and credentials_file are None.
-    with mock.patch.object(auth, "default") as adc, mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages"
-    ) as Transport:
-        Transport.return_value = None
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        transport = transports.PipelineServiceTransport()
-        adc.assert_called_once()
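test_pipeline_service_base_transport verifies the abstract-transport contract: every RPC stub on the base class raises NotImplementedError until a concrete transport (gRPC, gRPC-asyncio) overrides it. Reduced to its essentials, assuming nothing beyond the standard library:

    class BaseTransportSketch:
        # Each RPC is a stub; concrete transports must override it.
        def create_training_pipeline(self, request):
            raise NotImplementedError()

        @property
        def operations_client(self):
            # The LRO client is likewise left to concrete transports.
            raise NotImplementedError()

    transport = BaseTransportSketch()
    try:
        transport.create_training_pipeline(request=object())
    except NotImplementedError:
        print("base transport correctly refuses to serve RPCs")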
-
-
-def test_pipeline_service_auth_adc():
-    # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        PipelineServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            quota_project_id=None,
-        )
-
-
-def test_pipeline_service_transport_auth_adc():
-    # If credentials and host are not provided, the transport class should use
-    # ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        transports.PipelineServiceGrpcTransport(
-            host="squid.clam.whelk", quota_project_id="octopus"
-        )
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            quota_project_id="octopus",
-        )
-
-
-def test_pipeline_service_host_no_port():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_pipeline_service_host_with_port():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:8000"
-
-
-def test_pipeline_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
-
-    # Check that channel is used if provided.
-    transport = transports.PipelineServiceGrpcTransport(
-        host="squid.clam.whelk", channel=channel,
-    )
-    assert transport.grpc_channel == channel
-    assert transport._host == "squid.clam.whelk:443"
-
-
-def test_pipeline_service_grpc_asyncio_transport_channel():
-    channel = aio.insecure_channel("http://localhost/")
-
-    # Check that channel is used if provided.
-    transport = transports.PipelineServiceGrpcAsyncIOTransport(
-        host="squid.clam.whelk", channel=channel,
-    )
-    assert transport.grpc_channel == channel
-    assert transport._host == "squid.clam.whelk:443"
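The two host tests encode a small normalization rule: an api_endpoint without a port gains gRPC's default :443, while an explicit port is preserved. As a sketch:

    def normalize_host(api_endpoint, default_port=443):
        # Append the default port only when none was given.
        if ":" in api_endpoint:
            return api_endpoint
        return "{}:{}".format(api_endpoint, default_port)

    assert normalize_host("aiplatform.googleapis.com") == "aiplatform.googleapis.com:443"
    assert normalize_host("aiplatform.googleapis.com:8000") == "aiplatform.googleapis.com:8000"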
-
-
-@pytest.mark.parametrize(
-    "transport_class",
-    [
-        transports.PipelineServiceGrpcTransport,
-        transports.PipelineServiceGrpcAsyncIOTransport,
-    ],
-)
-def test_pipeline_service_transport_channel_mtls_with_client_cert_source(
-    transport_class,
-):
-    with mock.patch(
-        "grpc.ssl_channel_credentials", autospec=True
-    ) as grpc_ssl_channel_cred:
-        with mock.patch.object(
-            transport_class, "create_channel", autospec=True
-        ) as grpc_create_channel:
-            mock_ssl_cred = mock.Mock()
-            grpc_ssl_channel_cred.return_value = mock_ssl_cred
-
-            mock_grpc_channel = mock.Mock()
-            grpc_create_channel.return_value = mock_grpc_channel
-
-            cred = credentials.AnonymousCredentials()
-            with pytest.warns(DeprecationWarning):
-                with mock.patch.object(auth, "default") as adc:
-                    adc.return_value = (cred, None)
-                    transport = transport_class(
-                        host="squid.clam.whelk",
-                        api_mtls_endpoint="mtls.squid.clam.whelk",
-                        client_cert_source=client_cert_source_callback,
-                    )
-                    adc.assert_called_once()
-
-            grpc_ssl_channel_cred.assert_called_once_with(
-                certificate_chain=b"cert bytes", private_key=b"key bytes"
-            )
-            grpc_create_channel.assert_called_once_with(
-                "mtls.squid.clam.whelk:443",
-                credentials=cred,
-                credentials_file=None,
-                scopes=("https://www.googleapis.com/auth/cloud-platform",),
-                ssl_credentials=mock_ssl_cred,
-                quota_project_id=None,
-            )
-            assert transport.grpc_channel == mock_grpc_channel
-
-
-@pytest.mark.parametrize(
-    "transport_class",
-    [
-        transports.PipelineServiceGrpcTransport,
-        transports.PipelineServiceGrpcAsyncIOTransport,
-    ],
-)
-def test_pipeline_service_transport_channel_mtls_with_adc(transport_class):
-    mock_ssl_cred = mock.Mock()
-    with mock.patch.multiple(
-        "google.auth.transport.grpc.SslCredentials",
-        __init__=mock.Mock(return_value=None),
-        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
-    ):
-        with mock.patch.object(
-            transport_class, "create_channel", autospec=True
-        ) as grpc_create_channel:
-            mock_grpc_channel = mock.Mock()
-            grpc_create_channel.return_value = mock_grpc_channel
-            mock_cred = mock.Mock()
-
-            with pytest.warns(DeprecationWarning):
-                transport = transport_class(
-                    host="squid.clam.whelk",
-                    credentials=mock_cred,
-                    api_mtls_endpoint="mtls.squid.clam.whelk",
-                    client_cert_source=None,
-                )
-
-            grpc_create_channel.assert_called_once_with(
-                "mtls.squid.clam.whelk:443",
-                credentials=mock_cred,
-                credentials_file=None,
-                scopes=("https://www.googleapis.com/auth/cloud-platform",),
-                ssl_credentials=mock_ssl_cred,
-                quota_project_id=None,
-            )
-            assert transport.grpc_channel == mock_grpc_channel
-
-
-def test_pipeline_service_grpc_lro_client():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc",
-    )
-    transport = client._transport
-
-    # Ensure that we have an api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
-    assert transport.operations_client is transport.operations_client
-
-
-def test_pipeline_service_grpc_lro_async_client():
-    client = PipelineServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
-    )
-    transport = client._client._transport
-
-    # Ensure that we have an api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
-    assert transport.operations_client is transport.operations_client
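The "exact same object" assertions in the LRO tests above imply that operations_client is created lazily and cached on first access. A minimal sketch of that property pattern:

    class TransportSketch:
        _operations_client = None

        @property
        def operations_client(self):
            # Build once, then hand back the cached instance.
            if self._operations_client is None:
                self._operations_client = object()  # stand-in for OperationsClient
            return self._operations_client

    t = TransportSketch()
    assert t.operations_client is t.operations_client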
-
-
-def test_model_path():
-    project = "squid"
-    location = "clam"
-    model = "whelk"
-
-    expected = "projects/{project}/locations/{location}/models/{model}".format(
-        project=project, location=location, model=model,
-    )
-    actual = PipelineServiceClient.model_path(project, location, model)
-    assert expected == actual
-
-
-def test_parse_model_path():
-    expected = {
-        "project": "octopus",
-        "location": "oyster",
-        "model": "nudibranch",
-    }
-    path = PipelineServiceClient.model_path(**expected)
-
-    # Check that the path construction is reversible.
-    actual = PipelineServiceClient.parse_model_path(path)
-    assert expected == actual
-
-
-def test_training_pipeline_path():
-    project = "squid"
-    location = "clam"
-    training_pipeline = "whelk"
-
-    expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(
-        project=project, location=location, training_pipeline=training_pipeline,
-    )
-    actual = PipelineServiceClient.training_pipeline_path(
-        project, location, training_pipeline
-    )
-    assert expected == actual
-
-
-def test_parse_training_pipeline_path():
-    expected = {
-        "project": "octopus",
-        "location": "oyster",
-        "training_pipeline": "nudibranch",
-    }
-    path = PipelineServiceClient.training_pipeline_path(**expected)
-
-    # Check that the path construction is reversible.
-    actual = PipelineServiceClient.parse_training_pipeline_path(path)
-    assert expected == actual
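These path tests assert that resource-name formatting and parsing are inverses. A standalone round-trip sketch; the regular expression is an assumption mirroring what generated parse helpers typically use:

    import re

    def model_path(project, location, model):
        return "projects/{}/locations/{}/models/{}".format(project, location, model)

    def parse_model_path(path):
        m = re.match(
            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$",
            path,
        )
        return m.groupdict() if m else {}

    expected = {"project": "octopus", "location": "oyster", "model": "nudibranch"}
    assert parse_model_path(model_path(**expected)) == expected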
-
-
-def test_client_withDEFAULT_CLIENT_INFO():
-    client_info = gapic_v1.client_info.ClientInfo()
-
-    with mock.patch.object(
-        transports.PipelineServiceTransport, "_prep_wrapped_messages"
-    ) as prep:
-        client = PipelineServiceClient(
-            credentials=credentials.AnonymousCredentials(), client_info=client_info,
-        )
-        prep.assert_called_once_with(client_info)
-
-    with mock.patch.object(
-        transports.PipelineServiceTransport, "_prep_wrapped_messages"
-    ) as prep:
-        transport_class = PipelineServiceClient.get_transport_class()
-        transport = transport_class(
-            credentials=credentials.AnonymousCredentials(), client_info=client_info,
-        )
-        prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py
deleted file mode 100644
index f4ac41a0f0..0000000000
--- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py
+++ /dev/null
@@ -1,1217 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import os
-import mock
-
-import grpc
-from grpc.experimental import aio
-import math
-import pytest
-from proto.marshal.rules.dates import DurationRule, TimestampRule
-
-from google import auth
-from google.api_core import client_options
-from google.api_core import exceptions
-from google.api_core import gapic_v1
-from google.api_core import grpc_helpers
-from google.api_core import grpc_helpers_async
-from google.auth import credentials
-from google.auth.exceptions import MutualTLSChannelError
-from google.cloud.aiplatform_v1beta1.services.prediction_service import (
-    PredictionServiceAsyncClient,
-)
-from google.cloud.aiplatform_v1beta1.services.prediction_service import (
-    PredictionServiceClient,
-)
-from google.cloud.aiplatform_v1beta1.services.prediction_service import transports
-from google.cloud.aiplatform_v1beta1.types import explanation
-from google.cloud.aiplatform_v1beta1.types import prediction_service
-from google.oauth2 import service_account
-from google.protobuf import struct_pb2 as struct  # type: ignore
-
-
-def client_cert_source_callback():
-    return b"cert bytes", b"key bytes"
-
-
-# If default endpoint is localhost, then default mtls endpoint will be the same.
-# This method modifies the default endpoint so the client can produce a different
-# mtls endpoint for endpoint testing purposes.
-def modify_default_endpoint(client):
-    return (
-        "foo.googleapis.com"
-        if ("localhost" in client.DEFAULT_ENDPOINT)
-        else client.DEFAULT_ENDPOINT
-    )
-
-
-def test__get_default_mtls_endpoint():
-    api_endpoint = "example.googleapis.com"
-    api_mtls_endpoint = "example.mtls.googleapis.com"
-    sandbox_endpoint = "example.sandbox.googleapis.com"
-    sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com"
-    non_googleapi = "api.example.com"
-
-    assert PredictionServiceClient._get_default_mtls_endpoint(None) is None
-    assert (
-        PredictionServiceClient._get_default_mtls_endpoint(api_endpoint)
-        == api_mtls_endpoint
-    )
-    assert (
-        PredictionServiceClient._get_default_mtls_endpoint(api_mtls_endpoint)
-        == api_mtls_endpoint
-    )
-    assert (
-        PredictionServiceClient._get_default_mtls_endpoint(sandbox_endpoint)
-        == sandbox_mtls_endpoint
-    )
-    assert (
-        PredictionServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint)
-        == sandbox_mtls_endpoint
-    )
-    assert (
-        PredictionServiceClient._get_default_mtls_endpoint(non_googleapi)
-        == non_googleapi
-    )
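test__get_default_mtls_endpoint pins down a pure string transformation. A hedged sketch of the rule; the real classmethod is regex-based, so treat this as an illustration of the expected mapping, not the generated code:

    def get_default_mtls_endpoint(api_endpoint):
        # Already-mtls and non-googleapis hosts pass through unchanged.
        if not api_endpoint or ".mtls." in api_endpoint:
            return api_endpoint
        if api_endpoint.endswith("sandbox.googleapis.com"):
            return api_endpoint.replace(
                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
            )
        if api_endpoint.endswith(".googleapis.com"):
            name, _, _ = api_endpoint.partition(".googleapis.com")
            return name + ".mtls.googleapis.com"
        return api_endpoint

    assert get_default_mtls_endpoint("example.googleapis.com") == "example.mtls.googleapis.com"
    assert get_default_mtls_endpoint("example.sandbox.googleapis.com") == (
        "example.mtls.sandbox.googleapis.com"
    )
    assert get_default_mtls_endpoint("api.example.com") == "api.example.com"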
-
-
-@pytest.mark.parametrize(
-    "client_class", [PredictionServiceClient, PredictionServiceAsyncClient]
-)
-def test_prediction_service_client_from_service_account_file(client_class):
-    creds = credentials.AnonymousCredentials()
-    with mock.patch.object(
-        service_account.Credentials, "from_service_account_file"
-    ) as factory:
-        factory.return_value = creds
-        client = client_class.from_service_account_file("dummy/file/path.json")
-        assert client._transport._credentials == creds
-
-        client = client_class.from_service_account_json("dummy/file/path.json")
-        assert client._transport._credentials == creds
-
-        assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_prediction_service_client_get_transport_class():
-    transport = PredictionServiceClient.get_transport_class()
-    assert transport == transports.PredictionServiceGrpcTransport
-
-    transport = PredictionServiceClient.get_transport_class("grpc")
-    assert transport == transports.PredictionServiceGrpcTransport
-
-
-@pytest.mark.parametrize(
-    "client_class,transport_class,transport_name",
-    [
-        (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"),
-        (
-            PredictionServiceAsyncClient,
-            transports.PredictionServiceGrpcAsyncIOTransport,
-            "grpc_asyncio",
-        ),
-    ],
-)
-@mock.patch.object(
-    PredictionServiceClient,
-    "DEFAULT_ENDPOINT",
-    modify_default_endpoint(PredictionServiceClient),
-)
-@mock.patch.object(
-    PredictionServiceAsyncClient,
-    "DEFAULT_ENDPOINT",
-    modify_default_endpoint(PredictionServiceAsyncClient),
-)
-def test_prediction_service_client_client_options(
-    client_class, transport_class, transport_name
-):
-    # Check that if channel is provided we won't create a new one.
-    with mock.patch.object(PredictionServiceClient, "get_transport_class") as gtc:
-        transport = transport_class(credentials=credentials.AnonymousCredentials())
-        client = client_class(transport=transport)
-        gtc.assert_not_called()
-
-    # Check that if channel is provided via str we will create a new one.
-    with mock.patch.object(PredictionServiceClient, "get_transport_class") as gtc:
-        client = client_class(transport=transport_name)
-        gtc.assert_called()
-
-    # Check the case api_endpoint is provided.
-    options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
-    with mock.patch.object(transport_class, "__init__") as patched:
-        patched.return_value = None
-        client = client_class(client_options=options)
-        patched.assert_called_once_with(
-            credentials=None,
-            credentials_file=None,
-            host="squid.clam.whelk",
-            scopes=None,
-            ssl_channel_credentials=None,
-            quota_project_id=None,
-            client_info=transports.base.DEFAULT_CLIENT_INFO,
-        )
-
-    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
-    # "never".
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
-        with mock.patch.object(transport_class, "__init__") as patched:
-            patched.return_value = None
-            client = client_class()
-            patched.assert_called_once_with(
-                credentials=None,
-                credentials_file=None,
-                host=client.DEFAULT_ENDPOINT,
-                scopes=None,
-                ssl_channel_credentials=None,
-                quota_project_id=None,
-                client_info=transports.base.DEFAULT_CLIENT_INFO,
-            )
-
-    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
-    # "always".
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
-        with mock.patch.object(transport_class, "__init__") as patched:
-            patched.return_value = None
-            client = client_class()
-            patched.assert_called_once_with(
-                credentials=None,
-                credentials_file=None,
-                host=client.DEFAULT_MTLS_ENDPOINT,
-                scopes=None,
-                ssl_channel_credentials=None,
-                quota_project_id=None,
-                client_info=transports.base.DEFAULT_CLIENT_INFO,
-            )
-
-    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
-    # unsupported value.
-    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
-        with pytest.raises(MutualTLSChannelError):
-            client = client_class()
-
-    # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
-    with mock.patch.dict(
-        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}
-    ):
-        with pytest.raises(ValueError):
-            client = client_class()
-
-    # Check the case quota_project_id is provided
-    options = client_options.ClientOptions(quota_project_id="octopus")
-    with mock.patch.object(transport_class, "__init__") as patched:
-        patched.return_value = None
-        client = client_class(client_options=options)
-        patched.assert_called_once_with(
-            credentials=None,
-            credentials_file=None,
-            host=client.DEFAULT_ENDPOINT,
-            scopes=None,
-            ssl_channel_credentials=None,
-            quota_project_id="octopus",
-            client_info=transports.base.DEFAULT_CLIENT_INFO,
-        )
-
-
-@pytest.mark.parametrize(
-    "client_class,transport_class,transport_name,use_client_cert_env",
-    [
-        (
-            PredictionServiceClient,
-            transports.PredictionServiceGrpcTransport,
-            "grpc",
-            "true",
-        ),
-        (
-            PredictionServiceAsyncClient,
-            transports.PredictionServiceGrpcAsyncIOTransport,
-            "grpc_asyncio",
-            "true",
-        ),
-        (
-            PredictionServiceClient,
-            transports.PredictionServiceGrpcTransport,
-            "grpc",
-            "false",
-        ),
-        (
-            PredictionServiceAsyncClient,
-            transports.PredictionServiceGrpcAsyncIOTransport,
-            "grpc_asyncio",
-            "false",
-        ),
-    ],
-)
-@mock.patch.object(
-    PredictionServiceClient,
-    "DEFAULT_ENDPOINT",
-    modify_default_endpoint(PredictionServiceClient),
-)
-@mock.patch.object(
-    PredictionServiceAsyncClient,
-    "DEFAULT_ENDPOINT",
-    modify_default_endpoint(PredictionServiceAsyncClient),
-)
-@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
-def test_prediction_service_client_mtls_env_auto(
-    client_class, transport_class, transport_name, use_client_cert_env
-):
-    # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
-    # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.
-
-    # Check the case client_cert_source is provided. Whether client cert is used depends on
-    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
-    with mock.patch.dict(
-        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
-    ):
-        options = client_options.ClientOptions(
-            client_cert_source=client_cert_source_callback
-        )
-        with mock.patch.object(transport_class, "__init__") as patched:
-            ssl_channel_creds = mock.Mock()
-            with mock.patch(
-                "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
-            ):
-                patched.return_value = None
-                client = client_class(client_options=options)
-
-                if use_client_cert_env == "false":
-                    expected_ssl_channel_creds = None
-                    expected_host = client.DEFAULT_ENDPOINT
-                else:
-                    expected_ssl_channel_creds = ssl_channel_creds
-                    expected_host = client.DEFAULT_MTLS_ENDPOINT
-
-                patched.assert_called_once_with(
-                    credentials=None,
-                    credentials_file=None,
-                    host=expected_host,
-                    scopes=None,
-                    ssl_channel_credentials=expected_ssl_channel_creds,
-                    quota_project_id=None,
-                    client_info=transports.base.DEFAULT_CLIENT_INFO,
-                )
-
-    # Check the case ADC client cert is provided. Whether client cert is used depends on
-    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
-    with mock.patch.dict(
-        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
-    ):
-        with mock.patch.object(transport_class, "__init__") as patched:
-            with mock.patch(
-                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
-            ):
-                with mock.patch(
-                    "google.auth.transport.grpc.SslCredentials.is_mtls",
-                    new_callable=mock.PropertyMock,
-                ) as is_mtls_mock:
-                    with mock.patch(
-                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
-                        new_callable=mock.PropertyMock,
-                    ) as ssl_credentials_mock:
-                        if use_client_cert_env == "false":
-                            is_mtls_mock.return_value = False
-                            ssl_credentials_mock.return_value = None
-                            expected_host = client.DEFAULT_ENDPOINT
-                            expected_ssl_channel_creds = None
-                        else:
-                            is_mtls_mock.return_value = True
-                            ssl_credentials_mock.return_value = mock.Mock()
-                            expected_host = client.DEFAULT_MTLS_ENDPOINT
-                            expected_ssl_channel_creds = (
-                                ssl_credentials_mock.return_value
-                            )
-
-                        patched.return_value = None
-                        client = client_class()
-                        patched.assert_called_once_with(
-                            credentials=None,
-                            credentials_file=None,
-                            host=expected_host,
-                            scopes=None,
-                            ssl_channel_credentials=expected_ssl_channel_creds,
-                            quota_project_id=None,
-                            client_info=transports.base.DEFAULT_CLIENT_INFO,
-                        )
-
-    # Check the case client_cert_source and ADC client cert are not provided.
-    with mock.patch.dict(
-        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
-    ):
-        with mock.patch.object(transport_class, "__init__") as patched:
-            with mock.patch(
-                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
-            ):
-                with mock.patch(
-                    "google.auth.transport.grpc.SslCredentials.is_mtls",
-                    new_callable=mock.PropertyMock,
-                ) as is_mtls_mock:
-                    is_mtls_mock.return_value = False
-                    patched.return_value = None
-                    client = client_class()
-                    patched.assert_called_once_with(
-                        credentials=None,
-                        credentials_file=None,
-                        host=client.DEFAULT_ENDPOINT,
-                        scopes=None,
-                        ssl_channel_credentials=None,
-                        quota_project_id=None,
-                        client_info=transports.base.DEFAULT_CLIENT_INFO,
-                    )
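The three cases in test_prediction_service_client_mtls_env_auto reduce to a small decision table over GOOGLE_API_USE_MTLS_ENDPOINT and client-certificate availability. A compact sketch of that table; function and argument names are illustrative:

    def pick_endpoint(use_mtls_env, have_client_cert, endpoint, mtls_endpoint):
        if use_mtls_env == "never":
            return endpoint
        if use_mtls_env == "always":
            return mtls_endpoint
        if use_mtls_env == "auto":
            # Auto-switch to mtls only when a client certificate is available.
            return mtls_endpoint if have_client_cert else endpoint
        raise ValueError("Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value")

    assert pick_endpoint("auto", True, "api.x.com", "api.mtls.x.com") == "api.mtls.x.com"
    assert pick_endpoint("auto", False, "api.x.com", "api.mtls.x.com") == "api.x.com"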
-
-
-@pytest.mark.parametrize(
-    "client_class,transport_class,transport_name",
-    [
-        (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"),
-        (
-            PredictionServiceAsyncClient,
-            transports.PredictionServiceGrpcAsyncIOTransport,
-            "grpc_asyncio",
-        ),
-    ],
-)
-def test_prediction_service_client_client_options_scopes(
-    client_class, transport_class, transport_name
-):
-    # Check the case scopes are provided.
-    options = client_options.ClientOptions(scopes=["1", "2"],)
-    with mock.patch.object(transport_class, "__init__") as patched:
-        patched.return_value = None
-        client = client_class(client_options=options)
-        patched.assert_called_once_with(
-            credentials=None,
-            credentials_file=None,
-            host=client.DEFAULT_ENDPOINT,
-            scopes=["1", "2"],
-            ssl_channel_credentials=None,
-            quota_project_id=None,
-            client_info=transports.base.DEFAULT_CLIENT_INFO,
-        )
-
-
-@pytest.mark.parametrize(
-    "client_class,transport_class,transport_name",
-    [
-        (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"),
-        (
-            PredictionServiceAsyncClient,
-            transports.PredictionServiceGrpcAsyncIOTransport,
-            "grpc_asyncio",
-        ),
-    ],
-)
-def test_prediction_service_client_client_options_credentials_file(
-    client_class, transport_class, transport_name
-):
-    # Check the case credentials file is provided.
-    options = client_options.ClientOptions(credentials_file="credentials.json")
-    with mock.patch.object(transport_class, "__init__") as patched:
-        patched.return_value = None
-        client = client_class(client_options=options)
-        patched.assert_called_once_with(
-            credentials=None,
-            credentials_file="credentials.json",
-            host=client.DEFAULT_ENDPOINT,
-            scopes=None,
-            ssl_channel_credentials=None,
-            quota_project_id=None,
-            client_info=transports.base.DEFAULT_CLIENT_INFO,
-        )
-
-
-def test_prediction_service_client_client_options_from_dict():
-    with mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceGrpcTransport.__init__"
-    ) as grpc_transport:
-        grpc_transport.return_value = None
-        client = PredictionServiceClient(
-            client_options={"api_endpoint": "squid.clam.whelk"}
-        )
-        grpc_transport.assert_called_once_with(
-            credentials=None,
-            credentials_file=None,
-            host="squid.clam.whelk",
-            scopes=None,
-            ssl_channel_credentials=None,
-            quota_project_id=None,
-            client_info=transports.base.DEFAULT_CLIENT_INFO,
-        )
-
-
-def test_predict(
-    transport: str = "grpc", request_type=prediction_service.PredictRequest
-):
-    client = PredictionServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = request_type()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.predict), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = prediction_service.PredictResponse(
-            deployed_model_id="deployed_model_id_value",
-        )
-
-        response = client.predict(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == prediction_service.PredictRequest()
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, prediction_service.PredictResponse)
-
-    assert response.deployed_model_id == "deployed_model_id_value"
-
-
-def test_predict_from_dict():
-    test_predict(request_type=dict)
-
-
-@pytest.mark.asyncio
-async def test_predict_async(transport: str = "grpc_asyncio"):
-    client = PredictionServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = prediction_service.PredictRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._client._transport.predict), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            prediction_service.PredictResponse(
-                deployed_model_id="deployed_model_id_value",
-            )
-        )
-
-        response = await client.predict(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, prediction_service.PredictResponse)
-
-    assert response.deployed_model_id == "deployed_model_id_value"
-
-
-def test_predict_field_headers():
-    client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = prediction_service.PredictRequest()
-    request.endpoint = "endpoint/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.predict), "__call__") as call:
-        call.return_value = prediction_service.PredictResponse()
-
-        client.predict(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"]
-
-
-@pytest.mark.asyncio
-async def test_predict_field_headers_async():
-    client = PredictionServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = prediction_service.PredictRequest()
-    request.endpoint = "endpoint/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._client._transport.predict), "__call__") as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            prediction_service.PredictResponse()
-        )
-
-        await client.predict(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"]
-
-
-def test_predict_flattened():
-    client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.predict), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = prediction_service.PredictResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.predict(
-            endpoint="endpoint_value",
-            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
-            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].endpoint == "endpoint_value"
-
-        assert args[0].instances == [
-            struct.Value(null_value=struct.NullValue.NULL_VALUE)
-        ]
-
-        assert args[0].parameters == struct.Value(
-            null_value=struct.NullValue.NULL_VALUE
-        )
-
-
-def test_predict_flattened_error():
-    client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.predict(
-            prediction_service.PredictRequest(),
-            endpoint="endpoint_value",
-            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
-            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
-        )
-
-
-@pytest.mark.asyncio
-async def test_predict_flattened_async():
-    client = PredictionServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._client._transport.predict), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = prediction_service.PredictResponse()
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            prediction_service.PredictResponse()
-        )
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.predict(
-            endpoint="endpoint_value",
-            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
-            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].endpoint == "endpoint_value"
-
-        assert args[0].instances == [
-            struct.Value(null_value=struct.NullValue.NULL_VALUE)
-        ]
-
-        assert args[0].parameters == struct.Value(
-            null_value=struct.NullValue.NULL_VALUE
-        )
-
-
-@pytest.mark.asyncio
-async def test_predict_flattened_error_async():
-    client = PredictionServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.predict(
-            prediction_service.PredictRequest(),
-            endpoint="endpoint_value",
-            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
-            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
-        )
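The flattened predict calls pass protobuf struct.Value objects for instances and parameters. A small sketch of building the same payloads with google.protobuf, which this test suite already imports as `struct`:

    from google.protobuf import struct_pb2

    parameters = struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)
    instances = [struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)]

    # Each Value is a oneof; here the populated branch is "null_value".
    assert instances[0].WhichOneof("kind") == "null_value"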
-
-
-def test_explain(
-    transport: str = "grpc", request_type=prediction_service.ExplainRequest
-):
-    client = PredictionServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = request_type()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.explain), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = prediction_service.ExplainResponse(
-            deployed_model_id="deployed_model_id_value",
-        )
-
-        response = client.explain(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == prediction_service.ExplainRequest()
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, prediction_service.ExplainResponse)
-
-    assert response.deployed_model_id == "deployed_model_id_value"
-
-
-def test_explain_from_dict():
-    test_explain(request_type=dict)
-
-
-@pytest.mark.asyncio
-async def test_explain_async(transport: str = "grpc_asyncio"):
-    client = PredictionServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = prediction_service.ExplainRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._client._transport.explain), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            prediction_service.ExplainResponse(
-                deployed_model_id="deployed_model_id_value",
-            )
-        )
-
-        response = await client.explain(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, prediction_service.ExplainResponse)
-
-    assert response.deployed_model_id == "deployed_model_id_value"
-
-
-def test_explain_field_headers():
-    client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = prediction_service.ExplainRequest()
-    request.endpoint = "endpoint/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.explain), "__call__") as call:
-        call.return_value = prediction_service.ExplainResponse()
-
-        client.explain(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"]
-
-
-@pytest.mark.asyncio
-async def test_explain_field_headers_async():
-    client = PredictionServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = prediction_service.ExplainRequest()
-    request.endpoint = "endpoint/value"
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._client._transport.explain), "__call__") as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            prediction_service.ExplainResponse()
-        )
-
-        await client.explain(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"]
-
-
-def test_explain_flattened():
-    client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._transport.explain), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = prediction_service.ExplainResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.explain(
-            endpoint="endpoint_value",
-            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
-            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
-            deployed_model_id="deployed_model_id_value",
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].endpoint == "endpoint_value"
-
-        assert args[0].instances == [
-            struct.Value(null_value=struct.NullValue.NULL_VALUE)
-        ]
-
-        assert args[0].parameters == struct.Value(
-            null_value=struct.NullValue.NULL_VALUE
-        )
-
-        assert args[0].deployed_model_id == "deployed_model_id_value"
-
-
-def test_explain_flattened_error():
-    client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.explain(
-            prediction_service.ExplainRequest(),
-            endpoint="endpoint_value",
-            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
-            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
-            deployed_model_id="deployed_model_id_value",
-        )
-
-
-@pytest.mark.asyncio
-async def test_explain_flattened_async():
-    client = PredictionServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(type(client._client._transport.explain), "__call__") as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = prediction_service.ExplainResponse()
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            prediction_service.ExplainResponse()
-        )
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.explain(
-            endpoint="endpoint_value",
-            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
-            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
-            deployed_model_id="deployed_model_id_value",
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].endpoint == "endpoint_value"
-
-        assert args[0].instances == [
-            struct.Value(null_value=struct.NullValue.NULL_VALUE)
-        ]
-
-        assert args[0].parameters == struct.Value(
-            null_value=struct.NullValue.NULL_VALUE
-        )
-
-        assert args[0].deployed_model_id == "deployed_model_id_value"
-
-
-@pytest.mark.asyncio
-async def test_explain_flattened_error_async():
-    client = PredictionServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.explain(
-            prediction_service.ExplainRequest(),
-            endpoint="endpoint_value",
-            instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)],
-            parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE),
-            deployed_model_id="deployed_model_id_value",
-        )
-
-
-def test_credentials_transport_error():
-    # It is an error to provide credentials and a transport instance.
-    transport = transports.PredictionServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = PredictionServiceClient(
-            credentials=credentials.AnonymousCredentials(), transport=transport,
-        )
-
-    # It is an error to provide a credentials file and a transport instance.
-    transport = transports.PredictionServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = PredictionServiceClient(
-            client_options={"credentials_file": "credentials.json"},
-            transport=transport,
-        )
-
-    # It is an error to provide scopes and a transport instance.
-    transport = transports.PredictionServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = PredictionServiceClient(
-            client_options={"scopes": ["1", "2"]}, transport=transport,
-        )
-
-
-def test_transport_instance():
-    # A client may be instantiated with a custom transport instance.
-    transport = transports.PredictionServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    client = PredictionServiceClient(transport=transport)
-    assert client._transport is transport
-
-
-def test_transport_get_channel():
-    # A client may be instantiated with a custom transport instance.
-    transport = transports.PredictionServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    channel = transport.grpc_channel
-    assert channel
-
-    transport = transports.PredictionServiceGrpcAsyncIOTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    channel = transport.grpc_channel
-    assert channel
-
-
-@pytest.mark.parametrize(
-    "transport_class",
-    [
-        transports.PredictionServiceGrpcTransport,
-        transports.PredictionServiceGrpcAsyncIOTransport,
-    ],
-)
-def test_transport_adc(transport_class):
-    # Test default credentials are used if not provided.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        transport_class()
-        adc.assert_called_once()
-
-
-def test_transport_grpc_default():
-    # A client should use the gRPC transport by default.
-    client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),)
-    assert isinstance(client._transport, transports.PredictionServiceGrpcTransport,)
-
-
-def test_prediction_service_base_transport_error():
-    # Passing both a credentials object and credentials_file should raise an error
-    with pytest.raises(exceptions.DuplicateCredentialArgs):
-        transport = transports.PredictionServiceTransport(
-            credentials=credentials.AnonymousCredentials(),
-            credentials_file="credentials.json",
-        )
-
-
-def test_prediction_service_base_transport():
-    # Instantiate the base transport.
-    with mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport.__init__"
-    ) as Transport:
-        Transport.return_value = None
-        transport = transports.PredictionServiceTransport(
-            credentials=credentials.AnonymousCredentials(),
-        )
-
-    # Every method on the transport should just blindly
-    # raise NotImplementedError.
-    methods = (
-        "predict",
-        "explain",
-    )
-    for method in methods:
-        with pytest.raises(NotImplementedError):
-            getattr(transport, method)(request=object())
-
-
-def test_prediction_service_base_transport_with_credentials_file():
-    # Instantiate the base transport with a credentials file
-    with mock.patch.object(
-        auth, "load_credentials_from_file"
-    ) as load_creds, mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages"
-    ) as Transport:
-        Transport.return_value = None
-        load_creds.return_value = (credentials.AnonymousCredentials(), None)
-        transport = transports.PredictionServiceTransport(
-            credentials_file="credentials.json", quota_project_id="octopus",
-        )
-        load_creds.assert_called_once_with(
-            "credentials.json",
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            quota_project_id="octopus",
-        )
-
-
-def test_prediction_service_base_transport_with_adc():
-    # Test the default credentials are used if credentials and credentials_file are None.
-    with mock.patch.object(auth, "default") as adc, mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages"
-    ) as Transport:
-        Transport.return_value = None
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        transport = transports.PredictionServiceTransport()
-        adc.assert_called_once()
-
-
-def test_prediction_service_auth_adc():
-    # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        PredictionServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            quota_project_id=None,
-        )
-
-
-def test_prediction_service_transport_auth_adc():
-    # If credentials and host are not provided, the transport class should use
-    # ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        transports.PredictionServiceGrpcTransport(
-            host="squid.clam.whelk", quota_project_id="octopus"
-        )
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            quota_project_id="octopus",
-        )
-
-
-def test_prediction_service_host_no_port():
-    client = PredictionServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_prediction_service_host_with_port():
-    client = PredictionServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:8000"
- transport = transports.PredictionServiceGrpcTransport( - host="squid.clam.whelk", channel=channel, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - - -def test_prediction_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") - - # Check that channel is used if provided. - transport = transports.PredictionServiceGrpcAsyncIOTransport( - host="squid.clam.whelk", channel=channel, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.PredictionServiceGrpcTransport, - transports.PredictionServiceGrpcAsyncIOTransport, - ], -) -def test_prediction_service_transport_channel_mtls_with_client_cert_source( - transport_class, -): - with mock.patch( - "grpc.ssl_channel_credentials", autospec=True - ) as grpc_ssl_channel_cred: - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: - mock_ssl_cred = mock.Mock() - grpc_ssl_channel_cred.return_value = mock_ssl_cred - - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - - cred = credentials.AnonymousCredentials() - with pytest.warns(DeprecationWarning): - with mock.patch.object(auth, "default") as adc: - adc.return_value = (cred, None) - transport = transport_class( - host="squid.clam.whelk", - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=client_cert_source_callback, - ) - adc.assert_called_once() - - grpc_ssl_channel_cred.assert_called_once_with( - certificate_chain=b"cert bytes", private_key=b"key bytes" - ) - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.PredictionServiceGrpcTransport, - transports.PredictionServiceGrpcAsyncIOTransport, - ], -) -def test_prediction_service_transport_channel_mtls_with_adc(transport_class): - mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - with mock.patch.object( - transport_class, "create_channel", autospec=True - ) as grpc_create_channel: - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - mock_cred = mock.Mock() - - with pytest.warns(DeprecationWarning): - transport = transport_class( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=None, - ) - - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=("https://www.googleapis.com/auth/cloud-platform",), - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - ) - assert transport.grpc_channel == mock_grpc_channel - - -def test_client_withDEFAULT_CLIENT_INFO(): - client_info = gapic_v1.client_info.ClientInfo() - - with mock.patch.object( - transports.PredictionServiceTransport, "_prep_wrapped_messages" - ) as prep: - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), client_info=client_info, - ) - prep.assert_called_once_with(client_info) - - with 
mock.patch.object( - transports.PredictionServiceTransport, "_prep_wrapped_messages" - ) as prep: - transport_class = PredictionServiceClient.get_transport_class() - transport = transport_class( - credentials=credentials.AnonymousCredentials(), client_info=client_info, - ) - prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py deleted file mode 100644 index d17ee90484..0000000000 --- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py +++ /dev/null @@ -1,2121 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import os -import mock - -import grpc -from grpc.experimental import aio -import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule - -from google import auth -from google.api_core import client_options -from google.api_core import exceptions -from google.api_core import future -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import operation_async # type: ignore -from google.api_core import operations_v1 -from google.auth import credentials -from google.auth.exceptions import MutualTLSChannelError -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( - SpecialistPoolServiceAsyncClient, -) -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( - SpecialistPoolServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.cloud.aiplatform_v1beta1.types import specialist_pool -from google.cloud.aiplatform_v1beta1.types import specialist_pool as gca_specialist_pool -from google.cloud.aiplatform_v1beta1.types import specialist_pool_service -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore - - -def client_cert_source_callback(): - return b"cert bytes", b"key bytes" - - -# If default endpoint is localhost, then default mtls endpoint will be the same. -# This method modifies the default endpoint so the client can produce a different -# mtls endpoint for endpoint testing purposes. 
-def modify_default_endpoint(client): - return ( - "foo.googleapis.com" - if ("localhost" in client.DEFAULT_ENDPOINT) - else client.DEFAULT_ENDPOINT - ) - - -def test__get_default_mtls_endpoint(): - api_endpoint = "example.googleapis.com" - api_mtls_endpoint = "example.mtls.googleapis.com" - sandbox_endpoint = "example.sandbox.googleapis.com" - sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" - non_googleapi = "api.example.com" - - assert SpecialistPoolServiceClient._get_default_mtls_endpoint(None) is None - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) - == api_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) - == api_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) - == sandbox_mtls_endpoint - ) - assert ( - SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) - == non_googleapi - ) - - -@pytest.mark.parametrize( - "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient] -) -def test_specialist_pool_service_client_from_service_account_file(client_class): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = client_class.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_specialist_pool_service_client_get_transport_class(): - transport = SpecialistPoolServiceClient.get_transport_class() - assert transport == transports.SpecialistPoolServiceGrpcTransport - - transport = SpecialistPoolServiceClient.get_transport_class("grpc") - assert transport == transports.SpecialistPoolServiceGrpcTransport - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -@mock.patch.object( - SpecialistPoolServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceClient), -) -@mock.patch.object( - SpecialistPoolServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceAsyncClient), -) -def test_specialist_pool_service_client_client_options( - client_class, transport_class, transport_name -): - # Check that if channel is provided we won't create a new one. - with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: - transport = transport_class(credentials=credentials.AnonymousCredentials()) - client = client_class(transport=transport) - gtc.assert_not_called() - - # Check that if channel is provided via str we will create a new one. - with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: - client = client_class(transport=transport_name) - gtc.assert_called() - - # Check the case api_endpoint is provided. 
- options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has - # unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class() - - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} - ): - with pytest.raises(ValueError): - client = client_class() - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id="octopus", - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name,use_client_cert_env", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - "true", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "true", - ), - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - "false", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - "false", - ), - ], -) -@mock.patch.object( - SpecialistPoolServiceClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceClient), -) -@mock.patch.object( - SpecialistPoolServiceAsyncClient, - "DEFAULT_ENDPOINT", - modify_default_endpoint(SpecialistPoolServiceAsyncClient), -) -@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_specialist_pool_service_client_mtls_env_auto( - client_class, transport_class, transport_name, use_client_cert_env -): - # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default - # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. - - # Check the case client_cert_source is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - options = client_options.ClientOptions( - client_cert_source=client_cert_source_callback - ) - with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) - - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT - - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case ADC client cert is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_specialist_pool_service_client_client_options_scopes( - client_class, transport_class, transport_name -): - # Check the case scopes are provided. - options = client_options.ClientOptions(scopes=["1", "2"],) - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=["1", "2"], - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -@pytest.mark.parametrize( - "client_class,transport_class,transport_name", - [ - ( - SpecialistPoolServiceClient, - transports.SpecialistPoolServiceGrpcTransport, - "grpc", - ), - ( - SpecialistPoolServiceAsyncClient, - transports.SpecialistPoolServiceGrpcAsyncIOTransport, - "grpc_asyncio", - ), - ], -) -def test_specialist_pool_service_client_client_options_credentials_file( - client_class, transport_class, transport_name -): - # Check the case credentials file is provided. 
- options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_specialist_pool_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__" - ) as grpc_transport: - grpc_transport.return_value = None - client = SpecialistPoolServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - grpc_transport.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - ssl_channel_credentials=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - -def test_create_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.CreateSpecialistPoolRequest, -): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.create_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_create_specialist_pool_from_dict(): - test_create_specialist_pool(request_type=dict) - - -@pytest.mark.asyncio -async def test_create_specialist_pool_async(transport: str = "grpc_asyncio"): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.CreateSpecialistPoolRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.create_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, future.Future) - - -def test_create_specialist_pool_field_headers(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.create_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_create_specialist_pool_field_headers_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = specialist_pool_service.CreateSpecialistPoolRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.create_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_create_specialist_pool_flattened(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.create_specialist_pool( - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) - - -def test_create_specialist_pool_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - client.create_specialist_pool( - specialist_pool_service.CreateSpecialistPoolRequest(), - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - ) - - -@pytest.mark.asyncio -async def test_create_specialist_pool_flattened_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.create_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.create_specialist_pool( - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) - - -@pytest.mark.asyncio -async def test_create_specialist_pool_flattened_error_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.create_specialist_pool( - specialist_pool_service.CreateSpecialistPoolRequest(), - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - ) - - -def test_get_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.GetSpecialistPoolRequest, -): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool.SpecialistPool( - name="name_value", - display_name="display_name_value", - specialist_managers_count=2662, - specialist_manager_emails=["specialist_manager_emails_value"], - pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], - ) - - response = client.get_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() - - # Establish that the response is the type that we expect. 
- assert isinstance(response, specialist_pool.SpecialistPool) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.specialist_managers_count == 2662 - - assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - - assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] - - -def test_get_specialist_pool_from_dict(): - test_get_specialist_pool(request_type=dict) - - -@pytest.mark.asyncio -async def test_get_specialist_pool_async(transport: str = "grpc_asyncio"): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.GetSpecialistPoolRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool( - name="name_value", - display_name="display_name_value", - specialist_managers_count=2662, - specialist_manager_emails=["specialist_manager_emails_value"], - pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], - ) - ) - - response = await client.get_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, specialist_pool.SpecialistPool) - - assert response.name == "name_value" - - assert response.display_name == "display_name_value" - - assert response.specialist_managers_count == 2662 - - assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - - assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] - - -def test_get_specialist_pool_field_headers(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_specialist_pool), "__call__" - ) as call: - call.return_value = specialist_pool.SpecialistPool() - - client.get_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_get_specialist_pool_field_headers_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = specialist_pool_service.GetSpecialistPoolRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._client._transport.get_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool() - ) - - await client.get_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_specialist_pool_flattened(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool.SpecialistPool() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.get_specialist_pool(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_get_specialist_pool_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_get_specialist_pool_flattened_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.get_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool.SpecialistPool() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool.SpecialistPool() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.get_specialist_pool(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_get_specialist_pool_flattened_error_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", - ) - - -def test_list_specialist_pools( - transport: str = "grpc", - request_type=specialist_pool_service.ListSpecialistPoolsRequest, -): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
- request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_specialist_pools), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_specialist_pools(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListSpecialistPoolsPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_specialist_pools_from_dict(): - test_list_specialist_pools(request_type=dict) - - -@pytest.mark.asyncio -async def test_list_specialist_pools_async(transport: str = "grpc_asyncio"): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.ListSpecialistPoolsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_specialist_pools), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token="next_page_token_value", - ) - ) - - response = await client.list_specialist_pools(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListSpecialistPoolsAsyncPager) - - assert response.next_page_token == "next_page_token_value" - - -def test_list_specialist_pools_field_headers(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_specialist_pools), "__call__" - ) as call: - call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() - - client.list_specialist_pools(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_list_specialist_pools_field_headers_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. 
- request = specialist_pool_service.ListSpecialistPoolsRequest() - request.parent = "parent/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_specialist_pools), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse() - ) - - await client.list_specialist_pools(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_specialist_pools_flattened(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_specialist_pools), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_specialist_pools(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -def test_list_specialist_pools_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", - ) - - -@pytest.mark.asyncio -async def test_list_specialist_pools_flattened_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_specialist_pools), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - specialist_pool_service.ListSpecialistPoolsResponse() - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.list_specialist_pools(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].parent == "parent_value" - - -@pytest.mark.asyncio -async def test_list_specialist_pools_flattened_error_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - await client.list_specialist_pools( - specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", - ) - - -def test_list_specialist_pools_pager(): - client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_specialist_pools), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - ], - next_page_token="abc", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - ], - ), - RuntimeError, - ) - - metadata = () - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), - ) - pager = client.list_specialist_pools(request={}) - - assert pager._metadata == metadata - - results = [i for i in pager] - assert len(results) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results) - - -def test_list_specialist_pools_pages(): - client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_specialist_pools), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - ], - next_page_token="abc", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - ], - ), - RuntimeError, - ) - pages = list(client.list_specialist_pools(request={}).pages) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -@pytest.mark.asyncio -async def test_list_specialist_pools_async_pager(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_specialist_pools), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - ], - next_page_token="abc", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - ], - ), - RuntimeError, - ) - async_pager = await client.list_specialist_pools(request={},) - assert async_pager.next_page_token == "abc" - responses = [] - async for response in async_pager: - responses.append(response) - - assert len(responses) == 6 - assert all(isinstance(i, specialist_pool.SpecialistPool) for i in responses) - - -@pytest.mark.asyncio -async def test_list_specialist_pools_async_pages(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials, - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.list_specialist_pools), - "__call__", - new_callable=mock.AsyncMock, - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - ], - next_page_token="abc", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[], next_page_token="def", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[specialist_pool.SpecialistPool(),], - next_page_token="ghi", - ), - specialist_pool_service.ListSpecialistPoolsResponse( - specialist_pools=[ - specialist_pool.SpecialistPool(), - specialist_pool.SpecialistPool(), - ], - ), - RuntimeError, - ) - pages = [] - async for page_ in (await client.list_specialist_pools(request={})).pages: - pages.append(page_) - for page_, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page_.raw_page.next_page_token == token - - -def test_delete_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.DeleteSpecialistPoolRequest, -): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() - - # Establish that the response is the type that we expect. 
- assert isinstance(response, future.Future) - - -def test_delete_specialist_pool_from_dict(): - test_delete_specialist_pool(request_type=dict) - - -@pytest.mark.asyncio -async def test_delete_specialist_pool_async(transport: str = "grpc_asyncio"): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.DeleteSpecialistPoolRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.delete_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_specialist_pool_field_headers(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.delete_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_delete_specialist_pool_field_headers_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = specialist_pool_service.DeleteSpecialistPoolRequest() - request.name = "name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.delete_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_delete_specialist_pool_flattened(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._transport.delete_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.delete_specialist_pool(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -def test_delete_specialist_pool_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", - ) - - -@pytest.mark.asyncio -async def test_delete_specialist_pool_flattened_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.delete_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.delete_specialist_pool(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0].name == "name_value" - - -@pytest.mark.asyncio -async def test_delete_specialist_pool_flattened_error_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", - ) - - -def test_update_specialist_pool( - transport: str = "grpc", - request_type=specialist_pool_service.UpdateSpecialistPoolRequest, -): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.update_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.update_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() - - # Establish that the response is the type that we expect. 
- assert isinstance(response, future.Future) - - -def test_update_specialist_pool_from_dict(): - test_update_specialist_pool(request_type=dict) - - -@pytest.mark.asyncio -async def test_update_specialist_pool_async(transport: str = "grpc_asyncio"): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.UpdateSpecialistPoolRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/spam") - ) - - response = await client.update_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_update_specialist_pool_field_headers(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = "specialist_pool.name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.update_specialist_pool), "__call__" - ) as call: - call.return_value = operations_pb2.Operation(name="operations/op") - - client.update_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - "x-goog-request-params", - "specialist_pool.name=specialist_pool.name/value", - ) in kw["metadata"] - - -@pytest.mark.asyncio -async def test_update_specialist_pool_field_headers_async(): - client = SpecialistPoolServiceAsyncClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = specialist_pool_service.UpdateSpecialistPoolRequest() - request.specialist_pool.name = "specialist_pool.name/value" - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._client._transport.update_specialist_pool), "__call__" - ) as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name="operations/op") - ) - - await client.update_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
-    _, _, kw = call.mock_calls[0]
-    assert (
-        "x-goog-request-params",
-        "specialist_pool.name=specialist_pool.name/value",
-    ) in kw["metadata"]
-
-
-def test_update_specialist_pool_flattened():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.update_specialist_pool), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.update_specialist_pool(
-            specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"),
-            update_mask=field_mask.FieldMask(paths=["paths_value"]),
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(
-            name="name_value"
-        )
-
-        assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"])
-
-
-def test_update_specialist_pool_flattened_error():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.update_specialist_pool(
-            specialist_pool_service.UpdateSpecialistPoolRequest(),
-            specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"),
-            update_mask=field_mask.FieldMask(paths=["paths_value"]),
-        )
-
-
-@pytest.mark.asyncio
-async def test_update_specialist_pool_flattened_async():
-    client = SpecialistPoolServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._client._transport.update_specialist_pool), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/op")
-
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
-            operations_pb2.Operation(name="operations/spam")
-        )
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.update_specialist_pool(
-            specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"),
-            update_mask=field_mask.FieldMask(paths=["paths_value"]),
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(
-            name="name_value"
-        )
-
-        assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"])
-
-
-@pytest.mark.asyncio
-async def test_update_specialist_pool_flattened_error_async():
-    client = SpecialistPoolServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.update_specialist_pool(
-            specialist_pool_service.UpdateSpecialistPoolRequest(),
-            specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"),
-            update_mask=field_mask.FieldMask(paths=["paths_value"]),
-        )
-
-
-def test_credentials_transport_error():
-    # It is an error to provide credentials and a transport instance.
-    transport = transports.SpecialistPoolServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = SpecialistPoolServiceClient(
-            credentials=credentials.AnonymousCredentials(), transport=transport,
-        )
-
-    # It is an error to provide a credentials file and a transport instance.
-    transport = transports.SpecialistPoolServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = SpecialistPoolServiceClient(
-            client_options={"credentials_file": "credentials.json"},
-            transport=transport,
-        )
-
-    # It is an error to provide scopes and a transport instance.
-    transport = transports.SpecialistPoolServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = SpecialistPoolServiceClient(
-            client_options={"scopes": ["1", "2"]}, transport=transport,
-        )
-
-
-def test_transport_instance():
-    # A client may be instantiated with a custom transport instance.
-    transport = transports.SpecialistPoolServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    client = SpecialistPoolServiceClient(transport=transport)
-    assert client._transport is transport
-
-
-def test_transport_get_channel():
-    # A client may be instantiated with a custom transport instance.
-    transport = transports.SpecialistPoolServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    channel = transport.grpc_channel
-    assert channel
-
-    transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    channel = transport.grpc_channel
-    assert channel
-
-
-@pytest.mark.parametrize(
-    "transport_class",
-    [
-        transports.SpecialistPoolServiceGrpcTransport,
-        transports.SpecialistPoolServiceGrpcAsyncIOTransport,
-    ],
-)
-def test_transport_adc(transport_class):
-    # Test default credentials are used if not provided.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        transport_class()
-        adc.assert_called_once()
-
-
-def test_transport_grpc_default():
-    # A client should use the gRPC transport by default.
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    assert isinstance(client._transport, transports.SpecialistPoolServiceGrpcTransport,)
-
-
-def test_specialist_pool_service_base_transport_error():
-    # Passing both a credentials object and credentials_file should raise an error
-    with pytest.raises(exceptions.DuplicateCredentialArgs):
-        transport = transports.SpecialistPoolServiceTransport(
-            credentials=credentials.AnonymousCredentials(),
-            credentials_file="credentials.json",
-        )
-
-
-def test_specialist_pool_service_base_transport():
-    # Instantiate the base transport.
-    with mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__"
-    ) as Transport:
-        Transport.return_value = None
-        transport = transports.SpecialistPoolServiceTransport(
-            credentials=credentials.AnonymousCredentials(),
-        )
-
-    # Every method on the transport should just blindly
-    # raise NotImplementedError.
-    methods = (
-        "create_specialist_pool",
-        "get_specialist_pool",
-        "list_specialist_pools",
-        "delete_specialist_pool",
-        "update_specialist_pool",
-    )
-    for method in methods:
-        with pytest.raises(NotImplementedError):
-            getattr(transport, method)(request=object())
-
-    # Additionally, the LRO client (a property) should
-    # also raise NotImplementedError
-    with pytest.raises(NotImplementedError):
-        transport.operations_client
-
-
-def test_specialist_pool_service_base_transport_with_credentials_file():
-    # Instantiate the base transport with a credentials file
-    with mock.patch.object(
-        auth, "load_credentials_from_file"
-    ) as load_creds, mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages"
-    ) as Transport:
-        Transport.return_value = None
-        load_creds.return_value = (credentials.AnonymousCredentials(), None)
-        transport = transports.SpecialistPoolServiceTransport(
-            credentials_file="credentials.json", quota_project_id="octopus",
-        )
-        load_creds.assert_called_once_with(
-            "credentials.json",
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            quota_project_id="octopus",
-        )
-
-
-def test_specialist_pool_service_base_transport_with_adc():
-    # Test the default credentials are used if credentials and credentials_file are None.
-    with mock.patch.object(auth, "default") as adc, mock.patch(
-        "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages"
-    ) as Transport:
-        Transport.return_value = None
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        transport = transports.SpecialistPoolServiceTransport()
-        adc.assert_called_once()
-
-
-def test_specialist_pool_service_auth_adc():
-    # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        SpecialistPoolServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            quota_project_id=None,
-        )
-
-
-def test_specialist_pool_service_transport_auth_adc():
-    # If credentials and host are not provided, the transport class should use
-    # ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        transports.SpecialistPoolServiceGrpcTransport(
-            host="squid.clam.whelk", quota_project_id="octopus"
-        )
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",),
-            quota_project_id="octopus",
-        )
-
-
-def test_specialist_pool_service_host_no_port():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_specialist_pool_service_host_with_port():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:8000"
-
-
-def test_specialist_pool_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
-
-    # Check that channel is used if provided.
-    transport = transports.SpecialistPoolServiceGrpcTransport(
-        host="squid.clam.whelk", channel=channel,
-    )
-    assert transport.grpc_channel == channel
-    assert transport._host == "squid.clam.whelk:443"
-
-
-def test_specialist_pool_service_grpc_asyncio_transport_channel():
-    channel = aio.insecure_channel("http://localhost/")
-
-    # Check that channel is used if provided.
-    transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport(
-        host="squid.clam.whelk", channel=channel,
-    )
-    assert transport.grpc_channel == channel
-    assert transport._host == "squid.clam.whelk:443"
-
-
-@pytest.mark.parametrize(
-    "transport_class",
-    [
-        transports.SpecialistPoolServiceGrpcTransport,
-        transports.SpecialistPoolServiceGrpcAsyncIOTransport,
-    ],
-)
-def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source(
-    transport_class,
-):
-    with mock.patch(
-        "grpc.ssl_channel_credentials", autospec=True
-    ) as grpc_ssl_channel_cred:
-        with mock.patch.object(
-            transport_class, "create_channel", autospec=True
-        ) as grpc_create_channel:
-            mock_ssl_cred = mock.Mock()
-            grpc_ssl_channel_cred.return_value = mock_ssl_cred
-
-            mock_grpc_channel = mock.Mock()
-            grpc_create_channel.return_value = mock_grpc_channel
-
-            cred = credentials.AnonymousCredentials()
-            with pytest.warns(DeprecationWarning):
-                with mock.patch.object(auth, "default") as adc:
-                    adc.return_value = (cred, None)
-                    transport = transport_class(
-                        host="squid.clam.whelk",
-                        api_mtls_endpoint="mtls.squid.clam.whelk",
-                        client_cert_source=client_cert_source_callback,
-                    )
-                    adc.assert_called_once()
-
-            grpc_ssl_channel_cred.assert_called_once_with(
-                certificate_chain=b"cert bytes", private_key=b"key bytes"
-            )
-            grpc_create_channel.assert_called_once_with(
-                "mtls.squid.clam.whelk:443",
-                credentials=cred,
-                credentials_file=None,
-                scopes=("https://www.googleapis.com/auth/cloud-platform",),
-                ssl_credentials=mock_ssl_cred,
-                quota_project_id=None,
-            )
-            assert transport.grpc_channel == mock_grpc_channel
-
-
-@pytest.mark.parametrize(
-    "transport_class",
-    [
-        transports.SpecialistPoolServiceGrpcTransport,
-        transports.SpecialistPoolServiceGrpcAsyncIOTransport,
-    ],
-)
-def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class):
-    mock_ssl_cred = mock.Mock()
-    with mock.patch.multiple(
-        "google.auth.transport.grpc.SslCredentials",
-        __init__=mock.Mock(return_value=None),
-        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
-    ):
-        with mock.patch.object(
-            transport_class, "create_channel", autospec=True
-        ) as grpc_create_channel:
-            mock_grpc_channel = mock.Mock()
-            grpc_create_channel.return_value = mock_grpc_channel
-            mock_cred = mock.Mock()
-
-            with pytest.warns(DeprecationWarning):
-                transport = transport_class(
-                    host="squid.clam.whelk",
-                    credentials=mock_cred,
-                    api_mtls_endpoint="mtls.squid.clam.whelk",
-                    client_cert_source=None,
-                )
-
-            grpc_create_channel.assert_called_once_with(
-                "mtls.squid.clam.whelk:443",
-                credentials=mock_cred,
-                credentials_file=None,
-                scopes=("https://www.googleapis.com/auth/cloud-platform",),
-                ssl_credentials=mock_ssl_cred,
-                quota_project_id=None,
-            )
-            assert transport.grpc_channel == mock_grpc_channel
-
-
-def test_specialist_pool_service_grpc_lro_client():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc",
-    )
-    transport = client._transport
-
-    # Ensure that we have a api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
-    assert transport.operations_client is transport.operations_client
-
-
-def test_specialist_pool_service_grpc_lro_async_client():
-    client = SpecialistPoolServiceAsyncClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
-    )
-    transport = client._client._transport
-
-    # Ensure that we have a api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
-    assert transport.operations_client is transport.operations_client
-
-
-def test_specialist_pool_path():
-    project = "squid"
-    location = "clam"
-    specialist_pool = "whelk"
-
-    expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(
-        project=project, location=location, specialist_pool=specialist_pool,
-    )
-    actual = SpecialistPoolServiceClient.specialist_pool_path(
-        project, location, specialist_pool
-    )
-    assert expected == actual
-
-
-def test_parse_specialist_pool_path():
-    expected = {
-        "project": "octopus",
-        "location": "oyster",
-        "specialist_pool": "nudibranch",
-    }
-    path = SpecialistPoolServiceClient.specialist_pool_path(**expected)
-
-    # Check that the path construction is reversible.
-    actual = SpecialistPoolServiceClient.parse_specialist_pool_path(path)
-    assert expected == actual
-
-
-def test_client_withDEFAULT_CLIENT_INFO():
-    client_info = gapic_v1.client_info.ClientInfo()
-
-    with mock.patch.object(
-        transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages"
-    ) as prep:
-        client = SpecialistPoolServiceClient(
-            credentials=credentials.AnonymousCredentials(), client_info=client_info,
-        )
-        prep.assert_called_once_with(client_info)
-
-    with mock.patch.object(
-        transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages"
-    ) as prep:
-        transport_class = SpecialistPoolServiceClient.get_transport_class()
-        transport = transport_class(
-            credentials=credentials.AnonymousCredentials(), client_info=client_info,
-        )
-        prep.assert_called_once_with(client_info)