diff --git a/.kokoro/docs/common.cfg b/.kokoro/docs/common.cfg index b940b1d53f..5adc161f36 100644 --- a/.kokoro/docs/common.cfg +++ b/.kokoro/docs/common.cfg @@ -30,7 +30,7 @@ env_vars: { env_vars: { key: "V2_STAGING_BUCKET" - value: "docs-staging-v2-staging" + value: "docs-staging-v2" } # It will upload the docker image after successful builds. diff --git a/.kokoro/test-samples.sh b/.kokoro/test-samples.sh index 419c1a7e7d..aed13be6d4 100755 --- a/.kokoro/test-samples.sh +++ b/.kokoro/test-samples.sh @@ -28,6 +28,12 @@ if [[ $KOKORO_BUILD_ARTIFACTS_SUBDIR = *"periodic"* ]]; then git checkout $LATEST_RELEASE fi +# Exit early if samples directory doesn't exist +if [ ! -d "./samples" ]; then + echo "No tests run. './samples' not found" + exit 0 +fi + # Disable buffering, so that the logs stream through. export PYTHONUNBUFFERED=1 @@ -101,4 +107,4 @@ cd "$ROOT" # Workaround for Kokoro permissions issue: delete secrets rm testing/{test-env.sh,client-secrets.json,service-account.json} -exit "$RTN" \ No newline at end of file +exit "$RTN" diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index b3d1f60298..039f436812 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,44 +1,95 @@ -# Contributor Code of Conduct +# Code of Conduct -As contributors and maintainers of this project, -and in the interest of fostering an open and welcoming community, -we pledge to respect all people who contribute through reporting issues, -posting feature requests, updating documentation, -submitting pull requests or patches, and other activities. +## Our Pledge -We are committed to making participation in this project -a harassment-free experience for everyone, -regardless of level of experience, gender, gender identity and expression, -sexual orientation, disability, personal appearance, -body size, race, ethnicity, age, religion, or nationality. +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, gender identity and expression, level of +experience, education, socio-economic status, nationality, personal appearance, +race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members Examples of unacceptable behavior by participants include: -* The use of sexualized language or imagery -* Personal attacks -* Trolling or insulting/derogatory comments -* Public or private harassment -* Publishing other's private information, -such as physical or electronic -addresses, without explicit permission -* Other unethical or unprofessional conduct.
+* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. Project maintainers have the right and responsibility to remove, edit, or reject -comments, commits, code, wiki edits, issues, and other contributions -that are not aligned to this Code of Conduct. -By adopting this Code of Conduct, -project maintainers commit themselves to fairly and consistently -applying these principles to every aspect of managing this project. -Project maintainers who do not follow or enforce the Code of Conduct -may be permanently removed from the project team. - -This code of conduct applies both within project spaces and in public spaces -when an individual is representing the project or its community. - -Instances of abusive, harassing, or otherwise unacceptable behavior -may be reported by opening an issue -or contacting one or more of the project maintainers. - -This Code of Conduct is adapted from the [Contributor Covenant](http://contributor-covenant.org), version 1.2.0, -available at [http://contributor-covenant.org/version/1/2/0/](http://contributor-covenant.org/version/1/2/0/) +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, or to ban temporarily or permanently any +contributor for other behaviors that they deem inappropriate, threatening, +offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when the Project +Steward has a reasonable belief that an individual's behavior may have a +negative impact on the project or its community. + +## Conflict Resolution + +We do not believe that all conflict is bad; healthy debate and disagreement +often yield positive results. However, it is never okay to be disrespectful or +to engage in behavior that violates the project’s code of conduct. + +If you see someone violating the code of conduct, you are encouraged to address +the behavior directly with those involved. Many issues can be resolved quickly +and easily, and this gives people more control over the outcome of their +dispute. If you are unable to resolve the matter for any reason, or if the +behavior is threatening or harassing, report it. We are dedicated to providing +an environment where participants feel welcome and safe. + + +Reports should be directed to *googleapis-stewards@google.com*, the +Project Steward(s) for *Google Cloud Client Libraries*. It is the Project Steward’s duty to +receive and address reported violations of the code of conduct. 
They will then +work with a committee consisting of representatives from the Open Source +Programs Office and the Google Open Source Strategy team. If for any reason you +are uncomfortable reaching out to the Project Steward, please email +opensource@google.com. + +We will investigate every complaint, but you may not receive a direct response. +We will use our discretion in determining when and how to follow up on reported +incidents, which may range from not taking action to permanent expulsion from +the project and project-sponsored spaces. We will notify the accused of the +report and provide them an opportunity to discuss it before any action is taken. +The identity of the reporter will be omitted from the details of the report +supplied to the accused. In potentially harmful situations, such as ongoing +harassment or threats to anyone's safety, we may take action without notice. + +## Attribution + +This Code of Conduct is adapted from the Contributor Covenant, version 1.4, +available at +https://www.contributor-covenant.org/version/1/4/code-of-conduct.html \ No newline at end of file diff --git a/docs/aiplatform_v1beta1/services.rst b/docs/aiplatform_v1beta1/services.rst index a23ba7675f..664c9df0a8 100644 --- a/docs/aiplatform_v1beta1/services.rst +++ b/docs/aiplatform_v1beta1/services.rst @@ -1,6 +1,27 @@ -Client for Google Cloud Aiplatform API -====================================== +Services for Google Cloud Aiplatform v1beta1 API +================================================ -.. automodule:: google.cloud.aiplatform_v1beta1 +.. automodule:: google.cloud.aiplatform_v1beta1.services.dataset_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.endpoint_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.job_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.migration_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.model_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.pipeline_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.prediction_service + :members: + :inherited-members: +.. automodule:: google.cloud.aiplatform_v1beta1.services.specialist_pool_service :members: :inherited-members: diff --git a/docs/aiplatform_v1beta1/types.rst b/docs/aiplatform_v1beta1/types.rst index df8cb24970..19bab68ada 100644 --- a/docs/aiplatform_v1beta1/types.rst +++ b/docs/aiplatform_v1beta1/types.rst @@ -1,5 +1,6 @@ -Types for Google Cloud Aiplatform API -===================================== +Types for Google Cloud Aiplatform v1beta1 API +============================================= .. 
automodule:: google.cloud.aiplatform_v1beta1.types :members: + :show-inheritance: diff --git a/docs/conf.py b/docs/conf.py index 5ac4507dba..effa4a8f1f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -349,6 +349,7 @@ "google-auth": ("https://google-auth.readthedocs.io/en/stable", None), "google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None,), "grpc": ("https://grpc.io/grpc/python/", None), + "proto-plus": ("https://proto-plus-python.readthedocs.io/en/latest/", None), } diff --git a/google/cloud/aiplatform/__init__.py b/google/cloud/aiplatform/__init__.py index 17172cbb56..ec30029286 100644 --- a/google/cloud/aiplatform/__init__.py +++ b/google/cloud/aiplatform/__init__.py @@ -15,4 +15,6 @@ # limitations under the License. # -__all__ = () +from google.cloud.aiplatform import gapic + +__all__ = ("gapic",) diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index b99a73164d..f49f90f5eb 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -15,10 +15,10 @@ # limitations under the License. # - from .services.dataset_service import DatasetServiceClient from .services.endpoint_service import EndpointServiceClient from .services.job_service import JobServiceClient +from .services.migration_service import MigrationServiceClient from .services.model_service import ModelServiceClient from .services.pipeline_service import PipelineServiceClient from .services.prediction_service import PredictionServiceClient @@ -81,8 +81,12 @@ from .types.explanation import Explanation from .types.explanation import ExplanationParameters from .types.explanation import ExplanationSpec +from .types.explanation import FeatureNoiseSigma +from .types.explanation import IntegratedGradientsAttribution from .types.explanation import ModelExplanation from .types.explanation import SampledShapleyAttribution +from .types.explanation import SmoothGradConfig +from .types.explanation import XraiAttribution from .types.explanation_metadata import ExplanationMetadata from .types.hyperparameter_tuning_job import HyperparameterTuningJob from .types.io import BigQueryDestination @@ -118,9 +122,18 @@ from .types.machine_resources import AutomaticResources from .types.machine_resources import BatchDedicatedResources from .types.machine_resources import DedicatedResources +from .types.machine_resources import DiskSpec from .types.machine_resources import MachineSpec from .types.machine_resources import ResourcesConsumed from .types.manual_batch_tuning_parameters import ManualBatchTuningParameters +from .types.migratable_resource import MigratableResource +from .types.migration_service import BatchMigrateResourcesOperationMetadata +from .types.migration_service import BatchMigrateResourcesRequest +from .types.migration_service import BatchMigrateResourcesResponse +from .types.migration_service import MigrateResourceRequest +from .types.migration_service import MigrateResourceResponse +from .types.migration_service import SearchMigratableResourcesRequest +from .types.migration_service import SearchMigratableResourcesResponse from .types.model import Model from .types.model import ModelContainerSpec from .types.model import Port @@ -186,6 +199,9 @@ "Attribution", "AutomaticResources", "BatchDedicatedResources", + "BatchMigrateResourcesOperationMetadata", + "BatchMigrateResourcesRequest", + "BatchMigrateResourcesResponse", "BatchPredictionJob", "BigQueryDestination", "BigQuerySource", @@ -229,6 +245,7 @@
"DeployModelResponse", "DeployedModel", "DeployedModelRef", + "DiskSpec", "Endpoint", "EndpointServiceClient", "EnvVar", @@ -245,6 +262,7 @@ "ExportModelOperationMetadata", "ExportModelRequest", "ExportModelResponse", + "FeatureNoiseSigma", "FilterSplit", "FractionSplit", "GcsDestination", @@ -268,6 +286,7 @@ "ImportDataRequest", "ImportDataResponse", "InputDataConfig", + "IntegratedGradientsAttribution", "JobServiceClient", "JobState", "ListAnnotationsRequest", @@ -299,6 +318,10 @@ "MachineSpec", "ManualBatchTuningParameters", "Measurement", + "MigratableResource", + "MigrateResourceRequest", + "MigrateResourceResponse", + "MigrationServiceClient", "Model", "ModelContainerSpec", "ModelEvaluation", @@ -318,6 +341,9 @@ "SampleConfig", "SampledShapleyAttribution", "Scheduling", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "SmoothGradConfig", "SpecialistPool", "SpecialistPoolServiceClient", "StudySpec", @@ -338,5 +364,6 @@ "UploadModelResponse", "UserActionReference", "WorkerPoolSpec", + "XraiAttribution", "DatasetServiceClient", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py index 8b973db167..597f654cb9 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/__init__.py @@ -16,5 +16,9 @@ # from .client import DatasetServiceClient +from .async_client import DatasetServiceAsyncClient -__all__ = ("DatasetServiceClient",) +__all__ = ( + "DatasetServiceClient", + "DatasetServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py new file mode 100644 index 0000000000..775558e3b1 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -0,0 +1,1032 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers +from google.cloud.aiplatform_v1beta1.types import annotation +from google.cloud.aiplatform_v1beta1.types import annotation_spec +from google.cloud.aiplatform_v1beta1.types import data_item +from google.cloud.aiplatform_v1beta1.types import dataset +from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1beta1.types import dataset_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import DatasetServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import DatasetServiceGrpcAsyncIOTransport +from .client import DatasetServiceClient + + +class DatasetServiceAsyncClient: + """The service that handles the CRUD of AI Platform Dataset and its child resources.""" + + _client: DatasetServiceClient + + DEFAULT_ENDPOINT = DatasetServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = DatasetServiceClient.DEFAULT_MTLS_ENDPOINT + + annotation_path = staticmethod(DatasetServiceClient.annotation_path) + parse_annotation_path = staticmethod(DatasetServiceClient.parse_annotation_path) + annotation_spec_path = staticmethod(DatasetServiceClient.annotation_spec_path) + parse_annotation_spec_path = staticmethod( + DatasetServiceClient.parse_annotation_spec_path + ) + data_item_path = staticmethod(DatasetServiceClient.data_item_path) + parse_data_item_path = staticmethod(DatasetServiceClient.parse_data_item_path) + dataset_path = staticmethod(DatasetServiceClient.dataset_path) + parse_dataset_path = staticmethod(DatasetServiceClient.parse_dataset_path) + + common_billing_account_path = staticmethod( + DatasetServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + DatasetServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(DatasetServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + DatasetServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + DatasetServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + DatasetServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(DatasetServiceClient.common_project_path) + parse_common_project_path = staticmethod( + DatasetServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(DatasetServiceClient.common_location_path) + parse_common_location_path = staticmethod( + DatasetServiceClient.parse_common_location_path + ) + + from_service_account_file = DatasetServiceClient.from_service_account_file + from_service_account_json = 
from_service_account_file + + @property + def transport(self) -> DatasetServiceTransport: + """Return the transport used by the client instance. + + Returns: + DatasetServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, DatasetServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the dataset service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.DatasetServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = DatasetServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def create_dataset( + self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a Dataset. + + Args: + request (:class:`~.dataset_service.CreateDatasetRequest`): + The request object. Request message for + ``DatasetService.CreateDataset``. + parent (:class:`str`): + Required. The resource name of the Location to create + the Dataset in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + dataset (:class:`~.gca_dataset.Dataset`): + Required. The Dataset to create. + This corresponds to the ``dataset`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.gca_dataset.Dataset`: A collection of + DataItems and Annotations on them. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, dataset]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.CreateDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if dataset is not None: + request.dataset = dataset + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_dataset.Dataset, + metadata_type=dataset_service.CreateDatasetOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_dataset( + self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: + r"""Gets a Dataset. + + Args: + request (:class:`~.dataset_service.GetDatasetRequest`): + The request object. Request message for + ``DatasetService.GetDataset``. + name (:class:`str`): + Required. The name of the Dataset + resource. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.dataset.Dataset: + A collection of DataItems and + Annotations on them. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.GetDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def update_dataset( + self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: + r"""Updates a Dataset. + + Args: + request (:class:`~.dataset_service.UpdateDatasetRequest`): + The request object. Request message for + ``DatasetService.UpdateDataset``. + dataset (:class:`~.gca_dataset.Dataset`): + Required. The Dataset which replaces + the resource on the server. + This corresponds to the ``dataset`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`~.field_mask.FieldMask`): + Required. The update mask applies to the resource. For + the ``FieldMask`` definition, see + + [FieldMask](https://tinyurl.com/dev-google-protobuf#google.protobuf.FieldMask). + Updatable fields: + + - ``display_name`` + - ``description`` + - ``labels`` + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_dataset.Dataset: + A collection of DataItems and + Annotations on them. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([dataset, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.UpdateDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if dataset is not None: + request.dataset = dataset + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. 
+ return response + + async def list_datasets( + self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsAsyncPager: + r"""Lists Datasets in a Location. + + Args: + request (:class:`~.dataset_service.ListDatasetsRequest`): + The request object. Request message for + ``DatasetService.ListDatasets``. + parent (:class:`str`): + Required. The name of the Dataset's parent resource. + Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListDatasetsAsyncPager: + Response message for + ``DatasetService.ListDatasets``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.ListDatasetsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_datasets, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListDatasetsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_dataset( + self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a Dataset. + + Args: + request (:class:`~.dataset_service.DeleteDatasetRequest`): + The request object. Request message for + ``DatasetService.DeleteDataset``. + name (:class:`str`): + Required. The resource name of the Dataset to delete. + Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.DeleteDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def import_data( + self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Imports data into a Dataset. + + Args: + request (:class:`~.dataset_service.ImportDataRequest`): + The request object. Request message for + ``DatasetService.ImportData``. + name (:class:`str`): + Required. The name of the Dataset resource. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + import_configs (:class:`Sequence[~.dataset.ImportDataConfig]`): + Required. The desired input + locations. The contents of all input + locations will be imported in one batch. + This corresponds to the ``import_configs`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. 
+ + The result type for the operation will be + :class:`~.dataset_service.ImportDataResponse`: + Response message for + ``DatasetService.ImportData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name, import_configs]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.ImportDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + if import_configs: + request.import_configs.extend(import_configs) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.import_data, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + dataset_service.ImportDataResponse, + metadata_type=dataset_service.ImportDataOperationMetadata, + ) + + # Done; return the response. + return response + + async def export_data( + self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Exports data from a Dataset. + + Args: + request (:class:`~.dataset_service.ExportDataRequest`): + The request object. Request message for + ``DatasetService.ExportData``. + name (:class:`str`): + Required. The name of the Dataset resource. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + export_config (:class:`~.dataset.ExportDataConfig`): + Required. The desired output + location. + This corresponds to the ``export_config`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.dataset_service.ExportDataResponse`: + Response message for + ``DatasetService.ExportData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([name, export_config]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.ExportDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + if export_config is not None: + request.export_config = export_config + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.export_data, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + dataset_service.ExportDataResponse, + metadata_type=dataset_service.ExportDataOperationMetadata, + ) + + # Done; return the response. + return response + + async def list_data_items( + self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsAsyncPager: + r"""Lists DataItems in a Dataset. + + Args: + request (:class:`~.dataset_service.ListDataItemsRequest`): + The request object. Request message for + ``DatasetService.ListDataItems``. + parent (:class:`str`): + Required. The resource name of the Dataset to list + DataItems from. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListDataItemsAsyncPager: + Response message for + ``DatasetService.ListDataItems``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.ListDataItemsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_data_items, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. 
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListDataItemsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_annotation_spec( + self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: + r"""Gets an AnnotationSpec. + + Args: + request (:class:`~.dataset_service.GetAnnotationSpecRequest`): + The request object. Request message for + ``DatasetService.GetAnnotationSpec``. + name (:class:`str`): + Required. The name of the AnnotationSpec resource. + Format: + + ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.annotation_spec.AnnotationSpec: + Identifies a concept with which + DataItems may be annotated. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.GetAnnotationSpecRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_annotation_spec, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_annotations( + self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsAsyncPager: + r"""Lists Annotations belonging to a DataItem. + + Args: + request (:class:`~.dataset_service.ListAnnotationsRequest`): + The request object. Request message for + ``DatasetService.ListAnnotations``. + parent (:class:`str`): + Required. The resource name of the DataItem to list + Annotations from. 
Format: + + ``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListAnnotationsAsyncPager: + Response message for + ``DatasetService.ListAnnotations``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.ListAnnotationsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_annotations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListAnnotationsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. 
+ return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("DatasetServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 99dcfbc5b6..1e63153291 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -16,17 +16,24 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers from google.cloud.aiplatform_v1beta1.types import annotation from google.cloud.aiplatform_v1beta1.types import annotation_spec @@ -40,8 +47,9 @@ from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import DatasetServiceTransport +from .transports.base import DatasetServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import DatasetServiceGrpcTransport +from .transports.grpc_asyncio import DatasetServiceGrpcAsyncIOTransport class DatasetServiceClientMeta(type): @@ -56,6 +64,7 @@ class DatasetServiceClientMeta(type): OrderedDict() ) # type: Dict[str, Type[DatasetServiceTransport]] _transport_registry["grpc"] = DatasetServiceGrpcTransport + _transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]: """Return an appropriate transport class. @@ -79,8 +88,38 @@ def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport class DatasetServiceClient(metaclass=DatasetServiceClientMeta): """""" - DEFAULT_OPTIONS = ClientOptions.ClientOptions( - api_endpoint="aiplatform.googleapis.com" + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?" 
+ ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT ) @classmethod @@ -103,6 +142,76 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> DatasetServiceTransport: + """Return the transport used by the client instance. + + Returns: + DatasetServiceTransport: The transport used by the client instance. + """ + return self._transport + + @staticmethod + def annotation_path( + project: str, location: str, dataset: str, data_item: str, annotation: str, + ) -> str: + """Return a fully-qualified annotation string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( + project=project, + location=location, + dataset=dataset, + data_item=data_item, + annotation=annotation, + ) + + @staticmethod + def parse_annotation_path(path: str) -> Dict[str, str]: + """Parse an annotation path into its component segments.""" + m = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)/dataItems/(?P<data_item>.+?)/annotations/(?P<annotation>.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def annotation_spec_path( + project: str, location: str, dataset: str, annotation_spec: str, + ) -> str: + """Return a fully-qualified annotation_spec string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) + + @staticmethod + def parse_annotation_spec_path(path: str) -> Dict[str, str]: + """Parse an annotation_spec path into its component segments.""" + m = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)/annotationSpecs/(?P<annotation_spec>.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def data_item_path( + project: str, location: str, dataset: str, data_item: str, + ) -> str: + """Return a fully-qualified data_item string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) + + @staticmethod + def parse_data_item_path(path: str) -> Dict[str, str]: + """Parse a data_item path into its component segments.""" + m = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)/dataItems/(?P<data_item>.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def dataset_path(project: str, location: str, dataset: str,) -> str: """Return a fully-qualified dataset string.""" @@ -110,12 +219,81 @@ def dataset_path(project: str, location: str, dataset: str,) -> str: project=project, location=location, dataset=dataset, ) + @staticmethod + def parse_dataset_path(path: str) -> Dict[str, str]: + """Parse a dataset path into its component segments.""" + m = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified 
billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P<folder>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse an organization path into its component segments.""" + m = re.match(r"^organizations/(?P<organization>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P<project>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path) + return m.groupdict() if m else {} + def __init__( self, *, - credentials: credentials.Credentials = None, - transport: Union[str, DatasetServiceTransport] = None, - client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, DatasetServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the dataset service client. @@ -128,26 +306,102 @@ def __init__( transport (Union[str, ~.DatasetServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. 
+ (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, DatasetServiceTransport): - if credentials: + # transport is a DatasetServiceTransport instance. + if credentials or client_options.credentials_file: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def create_dataset( @@ -197,28 +451,36 @@ def create_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
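
For context, the check below enforces that callers pass either a fully-formed request object or the flattened convenience arguments, never both. A minimal caller-side sketch (project, location, and display name are illustrative):

from google.cloud import aiplatform_v1beta1

client = aiplatform_v1beta1.DatasetServiceClient()

# Either flattened arguments...
operation = client.create_dataset(
    parent="projects/my-project/locations/us-central1",
    dataset=aiplatform_v1beta1.Dataset(display_name="my-dataset"),
)

# ...or an explicit request object, but never both at once.
request = aiplatform_v1beta1.CreateDatasetRequest(
    parent="projects/my-project/locations/us-central1",
    dataset=aiplatform_v1beta1.Dataset(display_name="my-dataset"),
)
operation = client.create_dataset(request=request)
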
- if request is not None and any([parent, dataset]): + has_flattened_params = any([parent, dataset]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = dataset_service.CreateDatasetRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.CreateDatasetRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.CreateDatasetRequest): + request = dataset_service.CreateDatasetRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if dataset is not None: - request.dataset = dataset + if parent is not None: + request.parent = parent + if dataset is not None: + request.dataset = dataset # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_dataset, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_dataset] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -272,25 +534,29 @@ def get_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = dataset_service.GetDatasetRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.GetDatasetRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.GetDatasetRequest): + request = dataset_service.GetDatasetRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_dataset, default_timeout=None, client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_dataset] # Certain fields should be provided within the metadata header; # add these here. @@ -355,28 +621,38 @@ def update_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([dataset, update_mask]): + has_flattened_params = any([dataset, update_mask]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
) - request = dataset_service.UpdateDatasetRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.UpdateDatasetRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.UpdateDatasetRequest): + request = dataset_service.UpdateDatasetRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if dataset is not None: - request.dataset = dataset - if update_mask is not None: - request.update_mask = update_mask + if dataset is not None: + request.dataset = dataset + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.update_dataset, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.update_dataset] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), ) # Send the request. @@ -425,27 +701,29 @@ def list_datasets( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = dataset_service.ListDatasetsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ListDatasetsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ListDatasetsRequest): + request = dataset_service.ListDatasetsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_datasets, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_datasets] # Certain fields should be provided within the metadata header; # add these here. @@ -459,7 +737,7 @@ def list_datasets( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDatasetsPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -518,26 +796,34 @@ def delete_dataset( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
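
The update_dataset hunk above starts attaching a routing header derived from ``request.dataset.name``. A sketch of what ``gapic_v1.routing_header.to_grpc_metadata`` produces (the exact percent-encoding can vary across api-core releases):

from google.api_core import gapic_v1

params = (("dataset.name", "projects/my-project/locations/us-central1/datasets/123"),)
metadata = gapic_v1.routing_header.to_grpc_metadata(params)
# Roughly: [("x-goog-request-params", "dataset.name=projects%2Fmy-project%2F...")]
print(metadata)
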
- if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = dataset_service.DeleteDatasetRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.DeleteDatasetRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.DeleteDatasetRequest): + request = dataset_service.DeleteDatasetRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_dataset, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_dataset] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -603,26 +889,37 @@ def import_data( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name, import_configs]): + has_flattened_params = any([name, import_configs]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = dataset_service.ImportDataRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ImportDataRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ImportDataRequest): + request = dataset_service.ImportDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. - # If we have keyword arguments corresponding to fields on the - # request, apply these. + if name is not None: + request.name = name - if name is not None: - request.name = name - if import_configs is not None: - request.import_configs = import_configs + if import_configs: + request.import_configs.extend(import_configs) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.import_data, default_timeout=None, client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.import_data] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -687,26 +984,36 @@ def export_data( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
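
Note that ``import_data`` above now extends the repeated ``import_configs`` field instead of assigning it, and still returns a long-running operation wrapper whose ``result()`` blocks until the import finishes. A minimal sketch (resource names and schema URI are placeholders):

from google.cloud import aiplatform_v1beta1

client = aiplatform_v1beta1.DatasetServiceClient()

config = aiplatform_v1beta1.ImportDataConfig(
    gcs_source=aiplatform_v1beta1.GcsSource(uris=["gs://my-bucket/data.jsonl"]),
    import_schema_uri="gs://my-bucket/import_schema.yaml",
)
operation = client.import_data(
    name="projects/my-project/locations/us-central1/datasets/123",
    import_configs=[config],
)

# Block until the operation completes (or raises).
response = operation.result()
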
- if request is not None and any([name, export_config]): + has_flattened_params = any([name, export_config]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = dataset_service.ExportDataRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ExportDataRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ExportDataRequest): + request = dataset_service.ExportDataRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name - if export_config is not None: - request.export_config = export_config + if name is not None: + request.name = name + if export_config is not None: + request.export_config = export_config # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.export_data, default_timeout=None, client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.export_data] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -764,27 +1071,29 @@ def list_data_items( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = dataset_service.ListDataItemsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ListDataItemsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ListDataItemsRequest): + request = dataset_service.ListDataItemsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_data_items, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_data_items] # Certain fields should be provided within the metadata header; # add these here. @@ -798,7 +1107,7 @@ def list_data_items( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataItemsPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -843,27 +1152,29 @@ def get_annotation_spec( # Create or coerce a protobuf request object. 
# Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = dataset_service.GetAnnotationSpecRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.GetAnnotationSpecRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.GetAnnotationSpecRequest): + request = dataset_service.GetAnnotationSpecRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_annotation_spec, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_annotation_spec] # Certain fields should be provided within the metadata header; # add these here. @@ -919,27 +1230,29 @@ def list_annotations( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = dataset_service.ListAnnotationsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ListAnnotationsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ListAnnotationsRequest): + request = dataset_service.ListAnnotationsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_annotations, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_annotations] # Certain fields should be provided within the metadata header; # add these here. @@ -953,7 +1266,7 @@ def list_annotations( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListAnnotationsPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. 
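
With the pager changes above, each follow-up page request re-sends the metadata (including the routing header) captured on the first call. Iteration is unchanged from the caller's side; a minimal sketch (parent value illustrative):

from google.cloud import aiplatform_v1beta1

client = aiplatform_v1beta1.DatasetServiceClient()
pager = client.list_datasets(parent="projects/my-project/locations/us-central1")

# Additional pages are fetched lazily as the loop advances.
for dataset in pager:
    print(dataset.display_name)
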
@@ -961,13 +1274,13 @@ def list_annotations( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("DatasetServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py index 0dd2e668cc..43c3156caf 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import annotation from google.cloud.aiplatform_v1beta1.types import data_item @@ -43,11 +43,11 @@ class ListDatasetsPager: def __init__( self, - method: Callable[ - [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse - ], + method: Callable[..., dataset_service.ListDatasetsResponse], request: dataset_service.ListDatasetsRequest, response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -58,10 +58,13 @@ def __init__( The initial request object. response (:class:`~.dataset_service.ListDatasetsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = dataset_service.ListDatasetsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -71,7 +74,7 @@ def pages(self) -> Iterable[dataset_service.ListDatasetsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[dataset.Dataset]: @@ -82,6 +85,72 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) +class ListDatasetsAsyncPager: + """A pager for iterating through ``list_datasets`` requests. + + This class thinly wraps an initial + :class:`~.dataset_service.ListDatasetsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``datasets`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListDatasets`` requests and continue to iterate + through the ``datasets`` field on the + corresponding responses. + + All the usual :class:`~.dataset_service.ListDatasetsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. 
+ request (:class:`~.dataset_service.ListDatasetsRequest`): + The initial request object. + response (:class:`~.dataset_service.ListDatasetsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = dataset_service.ListDatasetsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[dataset_service.ListDatasetsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[dataset.Dataset]: + async def async_generator(): + async for page in self.pages: + for response in page.datasets: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + class ListDataItemsPager: """A pager for iterating through ``list_data_items`` requests. @@ -102,12 +171,11 @@ class ListDataItemsPager: def __init__( self, - method: Callable[ - [dataset_service.ListDataItemsRequest], - dataset_service.ListDataItemsResponse, - ], + method: Callable[..., dataset_service.ListDataItemsResponse], request: dataset_service.ListDataItemsRequest, response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -118,10 +186,13 @@ def __init__( The initial request object. response (:class:`~.dataset_service.ListDataItemsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = dataset_service.ListDataItemsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -131,7 +202,7 @@ def pages(self) -> Iterable[dataset_service.ListDataItemsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[data_item.DataItem]: @@ -142,6 +213,72 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) +class ListDataItemsAsyncPager: + """A pager for iterating through ``list_data_items`` requests. + + This class thinly wraps an initial + :class:`~.dataset_service.ListDataItemsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``data_items`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListDataItems`` requests and continue to iterate + through the ``data_items`` field on the + corresponding responses. + + All the usual :class:`~.dataset_service.ListDataItemsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. 
+ """ + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.dataset_service.ListDataItemsRequest`): + The initial request object. + response (:class:`~.dataset_service.ListDataItemsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = dataset_service.ListDataItemsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[dataset_service.ListDataItemsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[data_item.DataItem]: + async def async_generator(): + async for page in self.pages: + for response in page.data_items: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + class ListAnnotationsPager: """A pager for iterating through ``list_annotations`` requests. @@ -162,12 +299,11 @@ class ListAnnotationsPager: def __init__( self, - method: Callable[ - [dataset_service.ListAnnotationsRequest], - dataset_service.ListAnnotationsResponse, - ], + method: Callable[..., dataset_service.ListAnnotationsResponse], request: dataset_service.ListAnnotationsRequest, response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -178,10 +314,13 @@ def __init__( The initial request object. response (:class:`~.dataset_service.ListAnnotationsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = dataset_service.ListAnnotationsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -191,7 +330,7 @@ def pages(self) -> Iterable[dataset_service.ListAnnotationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[annotation.Annotation]: @@ -200,3 +339,69 @@ def __iter__(self) -> Iterable[annotation.Annotation]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListAnnotationsAsyncPager: + """A pager for iterating through ``list_annotations`` requests. + + This class thinly wraps an initial + :class:`~.dataset_service.ListAnnotationsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``annotations`` field. 
+ + If there are more pages, the ``__aiter__`` method will make additional + ``ListAnnotations`` requests and continue to iterate + through the ``annotations`` field on the + corresponding responses. + + All the usual :class:`~.dataset_service.ListAnnotationsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.dataset_service.ListAnnotationsRequest`): + The initial request object. + response (:class:`~.dataset_service.ListAnnotationsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = dataset_service.ListAnnotationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[dataset_service.ListAnnotationsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[annotation.Annotation]: + async def async_generator(): + async for page in self.pages: + for response in page.annotations: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py index 7f1cb8ca21..f8496b801c 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py @@ -20,14 +20,17 @@ from .base import DatasetServiceTransport from .grpc import DatasetServiceGrpcTransport +from .grpc_asyncio import DatasetServiceGrpcAsyncIOTransport # Compile a registry of transports. 
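
The new async pagers above are consumed with ``async for``; a minimal sketch against the async client introduced elsewhere in this change (client class name assumed, parent value illustrative):

import asyncio

from google.cloud import aiplatform_v1beta1


async def main():
    client = aiplatform_v1beta1.DatasetServiceAsyncClient()
    pager = await client.list_datasets(
        parent="projects/my-project/locations/us-central1"
    )
    async for dataset in pager:
        print(dataset.display_name)


asyncio.run(main())
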
_transport_registry = OrderedDict() # type: Dict[str, Type[DatasetServiceTransport]] _transport_registry["grpc"] = DatasetServiceGrpcTransport +_transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport __all__ = ( "DatasetServiceTransport", "DatasetServiceGrpcTransport", + "DatasetServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py index f00538959f..8cceeb197c 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py @@ -17,8 +17,12 @@ import abc import typing +import pkg_resources -from google import auth +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -29,7 +33,17 @@ from google.longrunning import operations_pb2 as operations # type: ignore -class DatasetServiceTransport(metaclass=abc.ABCMeta): +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +class DatasetServiceTransport(abc.ABC): """Abstract transport class for DatasetService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -39,6 +53,11 @@ def __init__( *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, ) -> None: """Instantiate the transport. @@ -49,6 +68,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scope (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -57,85 +87,168 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. 
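
The branch below resolves credentials in a fixed order: an explicit ``credentials`` object, then ``credentials_file`` (via ``google.auth.load_credentials_from_file``), then Application Default Credentials; passing both explicit forms raises ``DuplicateCredentialArgs``. From the client surface that looks roughly like this (the file path is a placeholder):

from google.api_core import client_options
from google.cloud import aiplatform_v1beta1

# Credentials loaded from a file.
options = client_options.ClientOptions(
    credentials_file="/path/to/service-account.json"
)
client = aiplatform_v1beta1.DatasetServiceClient(client_options=options)

# Neither credentials nor credentials_file: Application Default Credentials.
client = aiplatform_v1beta1.DatasetServiceClient()
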
- if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES) + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials + # Lifted into its own function so it can be stubbed out during tests. + self._prep_wrapped_messages(client_info) + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_dataset: gapic_v1.method.wrap_method( + self.create_dataset, default_timeout=None, client_info=client_info, + ), + self.get_dataset: gapic_v1.method.wrap_method( + self.get_dataset, default_timeout=None, client_info=client_info, + ), + self.update_dataset: gapic_v1.method.wrap_method( + self.update_dataset, default_timeout=None, client_info=client_info, + ), + self.list_datasets: gapic_v1.method.wrap_method( + self.list_datasets, default_timeout=None, client_info=client_info, + ), + self.delete_dataset: gapic_v1.method.wrap_method( + self.delete_dataset, default_timeout=None, client_info=client_info, + ), + self.import_data: gapic_v1.method.wrap_method( + self.import_data, default_timeout=None, client_info=client_info, + ), + self.export_data: gapic_v1.method.wrap_method( + self.export_data, default_timeout=None, client_info=client_info, + ), + self.list_data_items: gapic_v1.method.wrap_method( + self.list_data_items, default_timeout=None, client_info=client_info, + ), + self.get_annotation_spec: gapic_v1.method.wrap_method( + self.get_annotation_spec, default_timeout=None, client_info=client_info, + ), + self.list_annotations: gapic_v1.method.wrap_method( + self.list_annotations, default_timeout=None, client_info=client_info, + ), + } + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property def create_dataset( self, - ) -> typing.Callable[[dataset_service.CreateDatasetRequest], operations.Operation]: - raise NotImplementedError + ) -> typing.Callable[ + [dataset_service.CreateDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() @property def get_dataset( self, - ) -> typing.Callable[[dataset_service.GetDatasetRequest], dataset.Dataset]: - raise NotImplementedError + ) -> typing.Callable[ + [dataset_service.GetDatasetRequest], + typing.Union[dataset.Dataset, typing.Awaitable[dataset.Dataset]], + ]: + raise NotImplementedError() @property def update_dataset( self, - ) -> typing.Callable[[dataset_service.UpdateDatasetRequest], gca_dataset.Dataset]: - raise NotImplementedError + ) -> typing.Callable[ + [dataset_service.UpdateDatasetRequest], + typing.Union[gca_dataset.Dataset, typing.Awaitable[gca_dataset.Dataset]], + ]: + raise NotImplementedError() @property def list_datasets( self, ) -> typing.Callable[ - [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse + [dataset_service.ListDatasetsRequest], + typing.Union[ + dataset_service.ListDatasetsResponse, + typing.Awaitable[dataset_service.ListDatasetsResponse], + ], ]: - raise 
NotImplementedError + raise NotImplementedError() @property def delete_dataset( self, - ) -> typing.Callable[[dataset_service.DeleteDatasetRequest], operations.Operation]: - raise NotImplementedError + ) -> typing.Callable[ + [dataset_service.DeleteDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() @property def import_data( self, - ) -> typing.Callable[[dataset_service.ImportDataRequest], operations.Operation]: - raise NotImplementedError + ) -> typing.Callable[ + [dataset_service.ImportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() @property def export_data( self, - ) -> typing.Callable[[dataset_service.ExportDataRequest], operations.Operation]: - raise NotImplementedError + ) -> typing.Callable[ + [dataset_service.ExportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() @property def list_data_items( self, ) -> typing.Callable[ - [dataset_service.ListDataItemsRequest], dataset_service.ListDataItemsResponse + [dataset_service.ListDataItemsRequest], + typing.Union[ + dataset_service.ListDataItemsResponse, + typing.Awaitable[dataset_service.ListDataItemsResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def get_annotation_spec( self, ) -> typing.Callable[ - [dataset_service.GetAnnotationSpecRequest], annotation_spec.AnnotationSpec + [dataset_service.GetAnnotationSpecRequest], + typing.Union[ + annotation_spec.AnnotationSpec, + typing.Awaitable[annotation_spec.AnnotationSpec], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def list_annotations( self, ) -> typing.Callable[ [dataset_service.ListAnnotationsRequest], - dataset_service.ListAnnotationsResponse, + typing.Union[ + dataset_service.ListAnnotationsResponse, + typing.Awaitable[dataset_service.ListAnnotationsResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() __all__ = ("DatasetServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index 8110a97c2c..2647c4bd9c 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -15,11 +15,15 @@ # limitations under the License. # -from typing import Callable, Dict +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -29,7 +33,7 @@ from google.cloud.aiplatform_v1beta1.types import dataset_service from google.longrunning import operations_pb2 as operations # type: ignore -from .base import DatasetServiceTransport +from .base import DatasetServiceTransport, DEFAULT_CLIENT_INFO class DatasetServiceGrpcTransport(DatasetServiceTransport): @@ -43,12 +47,21 @@ class DatasetServiceGrpcTransport(DatasetServiceTransport): top of HTTP/2); the ``grpcio`` package must be installed. 
""" + _stubs: Dict[str, Callable] + def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - channel: grpc.Channel = None + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -60,28 +73,123 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. """ - # Sanity check: Ensure that channel and credentials are not both - # provided. + self._ssl_channel_credentials = ssl_channel_credentials + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. credentials = False - # Run the base constructor. - super().__init__(host=host, credentials=credentials) + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. 
+ self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._stubs = {} # type: Dict[str, Callable] - # If a channel was explicitly provided, set it. - if channel: - self._grpc_channel = channel + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) @classmethod def create_channel( cls, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - **kwargs + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -91,30 +199,37 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. """ + scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. 
return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..1f22b10f3e --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py @@ -0,0 +1,534 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import annotation_spec +from google.cloud.aiplatform_v1beta1.types import dataset +from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1beta1.types import dataset_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import DatasetServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import DatasetServiceGrpcTransport + + +class DatasetServiceGrpcAsyncIOTransport(DatasetServiceTransport): + """gRPC AsyncIO backend transport for DatasetService. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. 
+        quota_project_id (Optional[str]): An optional project to use for billing
+            and quota.
+        kwargs (Optional[dict]): Keyword arguments, which are passed to the
+            channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        channel: aio.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        quota_project_id=None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        self._ssl_channel_credentials = ssl_channel_credentials
+
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
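
As with the sync transport, the credential-related arguments are ignored once a ``channel`` is supplied, as handled just below; a sketch of building one with the ``create_channel`` classmethod above (host value is the default endpoint):

from google.cloud.aiplatform_v1beta1.services.dataset_service.transports import (
    DatasetServiceGrpcAsyncIOTransport,
)

# create_channel falls back to Application Default Credentials and the
# cloud-platform scope when none are supplied.
channel = DatasetServiceGrpcAsyncIOTransport.create_channel(
    "aiplatform.googleapis.com:443"
)
transport = DatasetServiceGrpcAsyncIOTransport(channel=channel)
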
+ self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__["operations_client"] + + @property + def create_dataset( + self, + ) -> Callable[ + [dataset_service.CreateDatasetRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the create dataset method over gRPC. + + Creates a Dataset. + + Returns: + Callable[[~.CreateDatasetRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
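+        # The stub is cached in self._stubs, so repeated property access
+        # returns the same callable instead of creating a new one.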
+ if "create_dataset" not in self._stubs: + self._stubs["create_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/CreateDataset", + request_serializer=dataset_service.CreateDatasetRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["create_dataset"] + + @property + def get_dataset( + self, + ) -> Callable[[dataset_service.GetDatasetRequest], Awaitable[dataset.Dataset]]: + r"""Return a callable for the get dataset method over gRPC. + + Gets a Dataset. + + Returns: + Callable[[~.GetDatasetRequest], + Awaitable[~.Dataset]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_dataset" not in self._stubs: + self._stubs["get_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/GetDataset", + request_serializer=dataset_service.GetDatasetRequest.serialize, + response_deserializer=dataset.Dataset.deserialize, + ) + return self._stubs["get_dataset"] + + @property + def update_dataset( + self, + ) -> Callable[ + [dataset_service.UpdateDatasetRequest], Awaitable[gca_dataset.Dataset] + ]: + r"""Return a callable for the update dataset method over gRPC. + + Updates a Dataset. + + Returns: + Callable[[~.UpdateDatasetRequest], + Awaitable[~.Dataset]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_dataset" not in self._stubs: + self._stubs["update_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/UpdateDataset", + request_serializer=dataset_service.UpdateDatasetRequest.serialize, + response_deserializer=gca_dataset.Dataset.deserialize, + ) + return self._stubs["update_dataset"] + + @property + def list_datasets( + self, + ) -> Callable[ + [dataset_service.ListDatasetsRequest], + Awaitable[dataset_service.ListDatasetsResponse], + ]: + r"""Return a callable for the list datasets method over gRPC. + + Lists Datasets in a Location. + + Returns: + Callable[[~.ListDatasetsRequest], + Awaitable[~.ListDatasetsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_datasets" not in self._stubs: + self._stubs["list_datasets"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ListDatasets", + request_serializer=dataset_service.ListDatasetsRequest.serialize, + response_deserializer=dataset_service.ListDatasetsResponse.deserialize, + ) + return self._stubs["list_datasets"] + + @property + def delete_dataset( + self, + ) -> Callable[ + [dataset_service.DeleteDatasetRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the delete dataset method over gRPC. + + Deletes a Dataset. + + Returns: + Callable[[~.DeleteDatasetRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. 
+ """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_dataset" not in self._stubs: + self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/DeleteDataset", + request_serializer=dataset_service.DeleteDatasetRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_dataset"] + + @property + def import_data( + self, + ) -> Callable[[dataset_service.ImportDataRequest], Awaitable[operations.Operation]]: + r"""Return a callable for the import data method over gRPC. + + Imports data into a Dataset. + + Returns: + Callable[[~.ImportDataRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "import_data" not in self._stubs: + self._stubs["import_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ImportData", + request_serializer=dataset_service.ImportDataRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["import_data"] + + @property + def export_data( + self, + ) -> Callable[[dataset_service.ExportDataRequest], Awaitable[operations.Operation]]: + r"""Return a callable for the export data method over gRPC. + + Exports data from a Dataset. + + Returns: + Callable[[~.ExportDataRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "export_data" not in self._stubs: + self._stubs["export_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.DatasetService/ExportData", + request_serializer=dataset_service.ExportDataRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["export_data"] + + @property + def list_data_items( + self, + ) -> Callable[ + [dataset_service.ListDataItemsRequest], + Awaitable[dataset_service.ListDataItemsResponse], + ]: + r"""Return a callable for the list data items method over gRPC. + + Lists DataItems in a Dataset. + + Returns: + Callable[[~.ListDataItemsRequest], + Awaitable[~.ListDataItemsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
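+        # unary_unary pairs one request message with one response message;
+        # the generated proto-plus types supply the serializer and
+        # deserializer functions.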
+        if "list_data_items" not in self._stubs:
+            self._stubs["list_data_items"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.DatasetService/ListDataItems",
+                request_serializer=dataset_service.ListDataItemsRequest.serialize,
+                response_deserializer=dataset_service.ListDataItemsResponse.deserialize,
+            )
+        return self._stubs["list_data_items"]
+
+    @property
+    def get_annotation_spec(
+        self,
+    ) -> Callable[
+        [dataset_service.GetAnnotationSpecRequest],
+        Awaitable[annotation_spec.AnnotationSpec],
+    ]:
+        r"""Return a callable for the get annotation spec method over gRPC.
+
+        Gets an AnnotationSpec.
+
+        Returns:
+            Callable[[~.GetAnnotationSpecRequest],
+                    Awaitable[~.AnnotationSpec]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "get_annotation_spec" not in self._stubs:
+            self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.DatasetService/GetAnnotationSpec",
+                request_serializer=dataset_service.GetAnnotationSpecRequest.serialize,
+                response_deserializer=annotation_spec.AnnotationSpec.deserialize,
+            )
+        return self._stubs["get_annotation_spec"]
+
+    @property
+    def list_annotations(
+        self,
+    ) -> Callable[
+        [dataset_service.ListAnnotationsRequest],
+        Awaitable[dataset_service.ListAnnotationsResponse],
+    ]:
+        r"""Return a callable for the list annotations method over gRPC.
+
+        Lists Annotations belonging to a DataItem.
+
+        Returns:
+            Callable[[~.ListAnnotationsRequest],
+                    Awaitable[~.ListAnnotationsResponse]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "list_annotations" not in self._stubs:
+            self._stubs["list_annotations"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.DatasetService/ListAnnotations",
+                request_serializer=dataset_service.ListAnnotationsRequest.serialize,
+                response_deserializer=dataset_service.ListAnnotationsResponse.deserialize,
+            )
+        return self._stubs["list_annotations"]
+
+
+__all__ = ("DatasetServiceGrpcAsyncIOTransport",)
diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py
index af0b93f5a8..035a5b2388 100644
--- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/__init__.py
@@ -16,5 +16,9 @@
 #
 
 from .client import EndpointServiceClient
+from .async_client import EndpointServiceAsyncClient
 
-__all__ = ("EndpointServiceClient",)
+__all__ = (
+    "EndpointServiceClient",
+    "EndpointServiceAsyncClient",
+)
diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py
new file mode 100644
index 0000000000..9056e7a149
--- /dev/null
+++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py
@@ -0,0 +1,831 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers +from google.cloud.aiplatform_v1beta1.types import endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import EndpointServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import EndpointServiceGrpcAsyncIOTransport +from .client import EndpointServiceClient + + +class EndpointServiceAsyncClient: + """""" + + _client: EndpointServiceClient + + DEFAULT_ENDPOINT = EndpointServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = EndpointServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(EndpointServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(EndpointServiceClient.parse_endpoint_path) + model_path = staticmethod(EndpointServiceClient.model_path) + parse_model_path = staticmethod(EndpointServiceClient.parse_model_path) + + common_billing_account_path = staticmethod( + EndpointServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + EndpointServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(EndpointServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + EndpointServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + EndpointServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + EndpointServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(EndpointServiceClient.common_project_path) + parse_common_project_path = staticmethod( + EndpointServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(EndpointServiceClient.common_location_path) + parse_common_location_path = staticmethod( + EndpointServiceClient.parse_common_location_path + ) + + from_service_account_file = EndpointServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> EndpointServiceTransport: + """Return the transport used by the client instance. 
+ + Returns: + EndpointServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, EndpointServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the endpoint service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.EndpointServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = EndpointServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def create_endpoint( + self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates an Endpoint. + + Args: + request (:class:`~.endpoint_service.CreateEndpointRequest`): + The request object. Request message for + ``EndpointService.CreateEndpoint``. + parent (:class:`str`): + Required. The resource name of the Location to create + the Endpoint in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + endpoint (:class:`~.gca_endpoint.Endpoint`): + Required. The Endpoint to create. + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.gca_endpoint.Endpoint`: Models are deployed + into it, and afterwards Endpoint is called to obtain + predictions and explanations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, endpoint]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.CreateEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if endpoint is not None: + request.endpoint = endpoint + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_endpoint.Endpoint, + metadata_type=endpoint_service.CreateEndpointOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_endpoint( + self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: + r"""Gets an Endpoint. + + Args: + request (:class:`~.endpoint_service.GetEndpointRequest`): + The request object. Request message for + ``EndpointService.GetEndpoint`` + name (:class:`str`): + Required. The name of the Endpoint resource. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.endpoint.Endpoint: + Models are deployed into it, and + afterwards Endpoint is called to obtain + predictions and explanations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.GetEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
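+        # Each flattened keyword argument simply overwrites the matching
+        # field on the request message.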
+ + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_endpoints( + self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsAsyncPager: + r"""Lists Endpoints in a Location. + + Args: + request (:class:`~.endpoint_service.ListEndpointsRequest`): + The request object. Request message for + ``EndpointService.ListEndpoints``. + parent (:class:`str`): + Required. The resource name of the Location from which + to list the Endpoints. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListEndpointsAsyncPager: + Response message for + ``EndpointService.ListEndpoints``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.ListEndpointsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_endpoints, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListEndpointsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def update_endpoint( + self, + request: endpoint_service.UpdateEndpointRequest = None, + *, + endpoint: gca_endpoint.Endpoint = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_endpoint.Endpoint: + r"""Updates an Endpoint. + + Args: + request (:class:`~.endpoint_service.UpdateEndpointRequest`): + The request object. Request message for + ``EndpointService.UpdateEndpoint``. + endpoint (:class:`~.gca_endpoint.Endpoint`): + Required. The Endpoint which replaces + the resource on the server. + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`~.field_mask.FieldMask`): + Required. The update mask applies to + the resource. + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_endpoint.Endpoint: + Models are deployed into it, and + afterwards Endpoint is called to obtain + predictions and explanations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([endpoint, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.UpdateEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("endpoint.name", request.endpoint.name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def delete_endpoint( + self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes an Endpoint. + + Args: + request (:class:`~.endpoint_service.DeleteEndpointRequest`): + The request object. Request message for + ``EndpointService.DeleteEndpoint``. + name (:class:`str`): + Required. The name of the Endpoint resource to be + deleted. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.DeleteEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def deploy_model( + self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[ + endpoint_service.DeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deploys a Model into this Endpoint, creating a + DeployedModel within it. + + Args: + request (:class:`~.endpoint_service.DeployModelRequest`): + The request object. Request message for + ``EndpointService.DeployModel``. + endpoint (:class:`str`): + Required. The name of the Endpoint resource into which + to deploy a Model. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model (:class:`~.gca_endpoint.DeployedModel`): + Required. The DeployedModel to be created within the + Endpoint. Note that + ``Endpoint.traffic_split`` + must be updated for the DeployedModel to start receiving + traffic, either as part of this call, or via + ``EndpointService.UpdateEndpoint``. 
+                This corresponds to the ``deployed_model`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            traffic_split (:class:`Sequence[~.endpoint_service.DeployModelRequest.TrafficSplitEntry]`):
+                A map from a DeployedModel's ID to the percentage of
+                this Endpoint's traffic that should be forwarded to that
+                DeployedModel.
+
+                If this field is non-empty, then the Endpoint's
+                ``traffic_split``
+                will be overwritten with it. To refer to the ID of the
+                Model being deployed in this request, use "0"; the
+                actual ID of the new DeployedModel will be filled in its
+                place by this method. The traffic percentage values must
+                add up to 100.
+
+                If this field is empty, then the Endpoint's
+                ``traffic_split``
+                is not updated.
+                This corresponds to the ``traffic_split`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.operation_async.AsyncOperation:
+                An object representing a long-running operation.
+
+                The result type for the operation will be
+                :class:`~.endpoint_service.DeployModelResponse`:
+                Response message for
+                ``EndpointService.DeployModel``.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([endpoint, deployed_model, traffic_split])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = endpoint_service.DeployModelRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if endpoint is not None:
+            request.endpoint = endpoint
+        if deployed_model is not None:
+            request.deployed_model = deployed_model
+
+        if traffic_split:
+            request.traffic_split.update(traffic_split)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.deploy_model,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)),
+        )
+
+        # Send the request.
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Wrap the response in an operation future.
+        response = operation_async.from_gapic(
+            response,
+            self._client._transport.operations_client,
+            endpoint_service.DeployModelResponse,
+            metadata_type=endpoint_service.DeployModelOperationMetadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def undeploy_model(
+        self,
+        request: endpoint_service.UndeployModelRequest = None,
+        *,
+        endpoint: str = None,
+        deployed_model_id: str = None,
+        traffic_split: Sequence[
+            endpoint_service.UndeployModelRequest.TrafficSplitEntry
+        ] = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operation_async.AsyncOperation:
+        r"""Undeploys a Model from an Endpoint, removing a
+        DeployedModel from it, and freeing all resources it's
+        using.
+
+        Args:
+            request (:class:`~.endpoint_service.UndeployModelRequest`):
+                The request object. Request message for
+                ``EndpointService.UndeployModel``.
+            endpoint (:class:`str`):
+                Required. The name of the Endpoint resource from which
+                to undeploy a Model. Format:
+                ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+                This corresponds to the ``endpoint`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            deployed_model_id (:class:`str`):
+                Required. The ID of the DeployedModel
+                to be undeployed from the Endpoint.
+                This corresponds to the ``deployed_model_id`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            traffic_split (:class:`Sequence[~.endpoint_service.UndeployModelRequest.TrafficSplitEntry]`):
+                If this field is provided, then the Endpoint's
+                ``traffic_split``
+                will be overwritten with it. If the last DeployedModel is
+                being undeployed from the Endpoint, the
+                [Endpoint.traffic_split] will always end up empty when
+                this call returns. A DeployedModel will be successfully
+                undeployed only if it doesn't have any traffic assigned
+                to it when this method executes, or if this field
+                unassigns all traffic from it.
+                This corresponds to the ``traffic_split`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.operation_async.AsyncOperation:
+                An object representing a long-running operation.
+
+                The result type for the operation will be
+                :class:`~.endpoint_service.UndeployModelResponse`:
+                Response message for
+                ``EndpointService.UndeployModel``.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([endpoint, deployed_model_id, traffic_split])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = endpoint_service.UndeployModelRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if endpoint is not None:
+            request.endpoint = endpoint
+        if deployed_model_id is not None:
+            request.deployed_model_id = deployed_model_id
+
+        if traffic_split:
+            request.traffic_split.update(traffic_split)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
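+        # wrap_method attaches the retry/timeout defaults and the client
+        # info used for the user-agent header to the transport method.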
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.undeploy_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + endpoint_service.UndeployModelResponse, + metadata_type=endpoint_service.UndeployModelOperationMetadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("EndpointServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 0ed3efe87a..5ea003b827 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -16,17 +16,24 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint @@ -36,8 +43,9 @@ from google.protobuf import field_mask_pb2 as field_mask # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import EndpointServiceTransport +from .transports.base import EndpointServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import EndpointServiceGrpcTransport +from .transports.grpc_asyncio import EndpointServiceGrpcAsyncIOTransport class EndpointServiceClientMeta(type): @@ -52,6 +60,7 @@ class EndpointServiceClientMeta(type): OrderedDict() ) # type: Dict[str, Type[EndpointServiceTransport]] _transport_registry["grpc"] = EndpointServiceGrpcTransport + _transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]: """Return an appropriate transport class. 
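For orientation, here is a minimal usage sketch of the async client added above; the project and location values are illustrative placeholders, and application default credentials are assumed to be configured:

    import asyncio

    from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
        EndpointServiceAsyncClient,
    )

    async def main():
        client = EndpointServiceAsyncClient()
        # Iterating the pager resolves additional pages transparently.
        pager = await client.list_endpoints(
            parent="projects/my-project/locations/us-central1"
        )
        async for endpoint in pager:
            print(endpoint.name)

    asyncio.run(main())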
@@ -75,8 +84,38 @@ def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTranspor
 class EndpointServiceClient(metaclass=EndpointServiceClientMeta):
     """"""
 
-    DEFAULT_OPTIONS = ClientOptions.ClientOptions(
-        api_endpoint="aiplatform.googleapis.com"
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
     )
 
     @classmethod
@@ -99,6 +138,15 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
 
     from_service_account_json = from_service_account_file
 
+    @property
+    def transport(self) -> EndpointServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            EndpointServiceTransport: The transport used by the client instance.
+        """
+        return self._transport
+
     @staticmethod
     def endpoint_path(project: str, location: str, endpoint: str,) -> str:
         """Return a fully-qualified endpoint string."""
@@ -106,12 +154,97 @@ def endpoint_path(project: str, location: str, endpoint: str,) -> str:
             project=project, location=location, endpoint=endpoint,
         )
 
+    @staticmethod
+    def parse_endpoint_path(path: str) -> Dict[str, str]:
+        """Parse an endpoint path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def model_path(project: str, location: str, model: str,) -> str:
+        """Return a fully-qualified model string."""
+        return "projects/{project}/locations/{location}/models/{model}".format(
+            project=project, location=location, model=model,
+        )
+
+    @staticmethod
+    def parse_model_path(path: str) -> Dict[str, str]:
+        """Parse a model path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_billing_account_path(billing_account: str,) -> str:
+        """Return a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(
+            billing_account=billing_account,
+        )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str, str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str,) -> str:
+        """Return a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder,)
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str, str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
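+    # For example (illustrative values):
+    #   endpoint_path("my-project", "us-central1", "my-endpoint")
+    #     -> "projects/my-project/locations/us-central1/endpoints/my-endpoint"
+    #   parse_endpoint_path(
+    #       "projects/my-project/locations/us-central1/endpoints/my-endpoint"
+    #   ) -> {"project": "my-project", "location": "us-central1",
+    #         "endpoint": "my-endpoint"}
+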
+    @staticmethod
+    def common_organization_path(organization: str,) -> str:
+        """Return a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization,)
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str, str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str,) -> str:
+        """Return a fully-qualified project string."""
+        return "projects/{project}".format(project=project,)
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str, str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str,) -> str:
+        """Return a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(
+            project=project, location=location,
+        )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str, str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
     def __init__(
         self,
         *,
-        credentials: credentials.Credentials = None,
-        transport: Union[str, EndpointServiceTransport] = None,
-        client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS,
+        credentials: Optional[credentials.Credentials] = None,
+        transport: Union[str, EndpointServiceTransport, None] = None,
+        client_options: Optional[client_options_lib.ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
     ) -> None:
         """Instantiate the endpoint service client.
 
@@ -124,26 +257,102 @@ def __init__(
             transport (Union[str, ~.EndpointServiceTransport]): The
                 transport to use. If set to None, a transport is chosen
                 automatically.
-            client_options (ClientOptions): Custom options for the client.
+            client_options (client_options_lib.ClientOptions): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
""" if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, EndpointServiceTransport): - if credentials: + # transport is a EndpointServiceTransport instance. + if credentials or client_options.credentials_file: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def create_endpoint( @@ -194,28 +403,36 @@ def create_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, endpoint]): + has_flattened_params = any([parent, endpoint]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = endpoint_service.CreateEndpointRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.CreateEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.CreateEndpointRequest): + request = endpoint_service.CreateEndpointRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if parent is not None: - request.parent = parent - if endpoint is not None: - request.endpoint = endpoint + if parent is not None: + request.parent = parent + if endpoint is not None: + request.endpoint = endpoint # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_endpoint, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -270,27 +487,29 @@ def get_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = endpoint_service.GetEndpointRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.GetEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.GetEndpointRequest): + request = endpoint_service.GetEndpointRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_endpoint, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_endpoint] # Certain fields should be provided within the metadata header; # add these here. @@ -345,27 +564,29 @@ def list_endpoints( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = endpoint_service.ListEndpointsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.ListEndpointsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.ListEndpointsRequest): + request = endpoint_service.ListEndpointsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method.wrap_method( - self._transport.list_endpoints, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_endpoints] # Certain fields should be provided within the metadata header; # add these here. @@ -379,7 +600,7 @@ def list_endpoints( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListEndpointsPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -430,28 +651,38 @@ def update_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, update_mask]): + has_flattened_params = any([endpoint, update_mask]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = endpoint_service.UpdateEndpointRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.UpdateEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.UpdateEndpointRequest): + request = endpoint_service.UpdateEndpointRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if endpoint is not None: - request.endpoint = endpoint - if update_mask is not None: - request.update_mask = update_mask + if endpoint is not None: + request.endpoint = endpoint + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.update_endpoint, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.update_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("endpoint.name", request.endpoint.name),) + ), ) # Send the request. @@ -513,26 +744,34 @@ def delete_endpoint( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = endpoint_service.DeleteEndpointRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.DeleteEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.DeleteEndpointRequest): + request = endpoint_service.DeleteEndpointRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
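Passing `metadata=metadata` into `ListEndpointsPager` matters because the pager re-invokes the wrapped RPC for each subsequent page; without it, the routing header would only accompany the first page. Typical consumption, reusing the `client` and `parent` from the sketch above:

```python
# Item iteration resolves pages lazily behind the scenes.
for ep in client.list_endpoints(parent=parent):
    print(ep.name, ep.display_name)

# Or walk whole pages to observe the page-token plumbing directly.
for page in client.list_endpoints(parent=parent).pages:
    print(len(page.endpoints), page.next_page_token)
```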
- if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_endpoint, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -625,30 +864,39 @@ def deploy_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, deployed_model, traffic_split]): + has_flattened_params = any([endpoint, deployed_model, traffic_split]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = endpoint_service.DeployModelRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.DeployModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.DeployModelRequest): + request = endpoint_service.DeployModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if endpoint is not None: - request.endpoint = endpoint - if deployed_model is not None: - request.deployed_model = deployed_model - if traffic_split is not None: - request.traffic_split = traffic_split + if endpoint is not None: + request.endpoint = endpoint + if deployed_model is not None: + request.deployed_model = deployed_model + + if traffic_split: + request.traffic_split.update(traffic_split) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.deploy_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.deploy_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. @@ -732,30 +980,39 @@ def undeploy_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, deployed_model_id, traffic_split]): + has_flattened_params = any([endpoint, deployed_model_id, traffic_split]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = endpoint_service.UndeployModelRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.UndeployModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
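The `routing_header` helper used in these methods serializes the named request fields into the `x-goog-request-params` header so the service frontend can route the call. Assuming values are percent-encoded the way `urllib.parse.urlencode` does it, the result looks roughly like:

```python
from google.api_core import gapic_v1

md = gapic_v1.routing_header.to_grpc_metadata(
    (("endpoint", "projects/p/locations/l/endpoints/e"),)
)
print(md)
# ('x-goog-request-params', 'endpoint=projects%2Fp%2Flocations%2Fl%2Fendpoints%2Fe')
```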
+ if not isinstance(request, endpoint_service.UndeployModelRequest): + request = endpoint_service.UndeployModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if endpoint is not None: - request.endpoint = endpoint - if deployed_model_id is not None: - request.deployed_model_id = deployed_model_id - if traffic_split is not None: - request.traffic_split = traffic_split + if endpoint is not None: + request.endpoint = endpoint + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id + + if traffic_split: + request.traffic_split.update(traffic_split) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.undeploy_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.undeploy_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. @@ -774,13 +1031,13 @@ def undeploy_model( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("EndpointServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py index 4c797e56fd..86320c2178 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service @@ -41,12 +41,11 @@ class ListEndpointsPager: def __init__( self, - method: Callable[ - [endpoint_service.ListEndpointsRequest], - endpoint_service.ListEndpointsResponse, - ], + method: Callable[..., endpoint_service.ListEndpointsResponse], request: endpoint_service.ListEndpointsRequest, response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -57,10 +56,13 @@ def __init__( The initial request object. response (:class:`~.endpoint_service.ListEndpointsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
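`deploy_model` and `undeploy_model` now call `request.traffic_split.update(...)` instead of assigning, because `traffic_split` is a protobuf map field: it mutates like a dict rather than accepting plain attribute assignment. A quick sketch (resource name hypothetical):

```python
from google.cloud.aiplatform_v1beta1.types import endpoint_service

req = endpoint_service.DeployModelRequest(
    endpoint="projects/p/locations/l/endpoints/e"
)
req.traffic_split.update({"deployed-model-id": 100})  # merges keys, dict-style
print(dict(req.traffic_split))  # {'deployed-model-id': 100}
```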
""" self._method = method self._request = endpoint_service.ListEndpointsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -70,7 +72,7 @@ def pages(self) -> Iterable[endpoint_service.ListEndpointsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[endpoint.Endpoint]: @@ -79,3 +81,69 @@ def __iter__(self) -> Iterable[endpoint.Endpoint]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListEndpointsAsyncPager: + """A pager for iterating through ``list_endpoints`` requests. + + This class thinly wraps an initial + :class:`~.endpoint_service.ListEndpointsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``endpoints`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListEndpoints`` requests and continue to iterate + through the ``endpoints`` field on the + corresponding responses. + + All the usual :class:`~.endpoint_service.ListEndpointsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.endpoint_service.ListEndpointsRequest`): + The initial request object. + response (:class:`~.endpoint_service.ListEndpointsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = endpoint_service.ListEndpointsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[endpoint_service.ListEndpointsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[endpoint.Endpoint]: + async def async_generator(): + async for page in self.pages: + for response in page.endpoints: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py index 62eff450a6..70a87e920e 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py @@ -20,14 +20,17 @@ from .base import EndpointServiceTransport from .grpc import EndpointServiceGrpcTransport +from .grpc_asyncio import EndpointServiceGrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] _transport_registry["grpc"] = EndpointServiceGrpcTransport +_transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport __all__ = ( "EndpointServiceTransport", "EndpointServiceGrpcTransport", + "EndpointServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py index 43baa080e0..63965464b7 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py @@ -17,8 +17,12 @@ import abc import typing +import pkg_resources -from google import auth +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -28,7 +32,17 @@ from google.longrunning import operations_pb2 as operations # type: ignore -class EndpointServiceTransport(metaclass=abc.ABCMeta): +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +class EndpointServiceTransport(abc.ABC): """Abstract transport class for EndpointService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -38,6 +52,11 @@ def __init__( *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, ) -> None: """Instantiate the transport. 
@@ -48,6 +67,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scope (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -56,66 +86,123 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. - if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES) + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials + # Lifted into its own function so it can be stubbed out during tests. + self._prep_wrapped_messages(client_info) + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_endpoint: gapic_v1.method.wrap_method( + self.create_endpoint, default_timeout=None, client_info=client_info, + ), + self.get_endpoint: gapic_v1.method.wrap_method( + self.get_endpoint, default_timeout=None, client_info=client_info, + ), + self.list_endpoints: gapic_v1.method.wrap_method( + self.list_endpoints, default_timeout=None, client_info=client_info, + ), + self.update_endpoint: gapic_v1.method.wrap_method( + self.update_endpoint, default_timeout=None, client_info=client_info, + ), + self.delete_endpoint: gapic_v1.method.wrap_method( + self.delete_endpoint, default_timeout=None, client_info=client_info, + ), + self.deploy_model: gapic_v1.method.wrap_method( + self.deploy_model, default_timeout=None, client_info=client_info, + ), + self.undeploy_model: gapic_v1.method.wrap_method( + self.undeploy_model, default_timeout=None, client_info=client_info, + ), + } + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property def create_endpoint( self, ) -> typing.Callable[ - [endpoint_service.CreateEndpointRequest], operations.Operation + [endpoint_service.CreateEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], ]: - raise NotImplementedError + raise NotImplementedError() @property def get_endpoint( self, - ) -> typing.Callable[[endpoint_service.GetEndpointRequest], endpoint.Endpoint]: - raise NotImplementedError + ) -> typing.Callable[ + [endpoint_service.GetEndpointRequest], + typing.Union[endpoint.Endpoint, typing.Awaitable[endpoint.Endpoint]], + ]: + raise NotImplementedError() @property def list_endpoints( self, ) -> 
typing.Callable[ - [endpoint_service.ListEndpointsRequest], endpoint_service.ListEndpointsResponse + [endpoint_service.ListEndpointsRequest], + typing.Union[ + endpoint_service.ListEndpointsResponse, + typing.Awaitable[endpoint_service.ListEndpointsResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def update_endpoint( self, ) -> typing.Callable[ - [endpoint_service.UpdateEndpointRequest], gca_endpoint.Endpoint + [endpoint_service.UpdateEndpointRequest], + typing.Union[gca_endpoint.Endpoint, typing.Awaitable[gca_endpoint.Endpoint]], ]: - raise NotImplementedError + raise NotImplementedError() @property def delete_endpoint( self, ) -> typing.Callable[ - [endpoint_service.DeleteEndpointRequest], operations.Operation + [endpoint_service.DeleteEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], ]: - raise NotImplementedError + raise NotImplementedError() @property def deploy_model( self, - ) -> typing.Callable[[endpoint_service.DeployModelRequest], operations.Operation]: - raise NotImplementedError + ) -> typing.Callable[ + [endpoint_service.DeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() @property def undeploy_model( self, - ) -> typing.Callable[[endpoint_service.UndeployModelRequest], operations.Operation]: - raise NotImplementedError + ) -> typing.Callable[ + [endpoint_service.UndeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() __all__ = ("EndpointServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index 9bde20f31f..70915facf0 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -15,11 +15,15 @@ # limitations under the License. # -from typing import Callable, Dict +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -28,7 +32,7 @@ from google.cloud.aiplatform_v1beta1.types import endpoint_service from google.longrunning import operations_pb2 as operations # type: ignore -from .base import EndpointServiceTransport +from .base import EndpointServiceTransport, DEFAULT_CLIENT_INFO class EndpointServiceGrpcTransport(EndpointServiceTransport): @@ -42,12 +46,21 @@ class EndpointServiceGrpcTransport(EndpointServiceTransport): top of HTTP/2); the ``grpcio`` package must be installed. 
""" + _stubs: Dict[str, Callable] + def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - channel: grpc.Channel = None + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -59,28 +72,123 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. """ - # Sanity check: Ensure that channel and credentials are not both - # provided. + self._ssl_channel_credentials = ssl_channel_credentials + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. credentials = False - # Run the base constructor. - super().__init__(host=host, credentials=credentials) + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. 
+ self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._stubs = {} # type: Dict[str, Callable] - # If a channel was explicitly provided, set it. - if channel: - self._grpc_channel = channel + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) @classmethod def create_channel( cls, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - **kwargs + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -90,30 +198,37 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. """ + scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. 
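With the constructor above accepting a prebuilt channel (note the `credentials = False` sentinel that suppresses ADC lookup), the transport can be pointed at a local emulator with no credential or mTLS resolution at all; the address is hypothetical:

```python
import grpc

from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
    EndpointServiceClient,
)
from google.cloud.aiplatform_v1beta1.services.endpoint_service.transports import (
    EndpointServiceGrpcTransport,
)

channel = grpc.insecure_channel("localhost:8080")
transport = EndpointServiceGrpcTransport(channel=channel)
client = EndpointServiceClient(transport=transport)
```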
return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..f4e362281b --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py @@ -0,0 +1,453 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import EndpointServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import EndpointServiceGrpcTransport + + +class EndpointServiceGrpcAsyncIOTransport(EndpointServiceTransport): + """gRPC AsyncIO backend transport for EndpointService. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. 
+ kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. 
+ self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__["operations_client"] + + @property + def create_endpoint( + self, + ) -> Callable[ + [endpoint_service.CreateEndpointRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the create endpoint method over gRPC. + + Creates an Endpoint. + + Returns: + Callable[[~.CreateEndpointRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
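The RPC properties that follow all share one pattern: build the `unary_unary` stub on first access and memoize it in `self._stubs`, so repeated calls reuse a single stub per channel. That caching is directly observable (emulator-style address is invented):

```python
from grpc.experimental import aio

from google.cloud.aiplatform_v1beta1.services.endpoint_service.transports import (
    EndpointServiceGrpcAsyncIOTransport,
)

transport = EndpointServiceGrpcAsyncIOTransport(
    channel=aio.insecure_channel("localhost:8080")
)
assert transport.create_endpoint is transport.create_endpoint  # cached stub
```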
+ if "create_endpoint" not in self._stubs: + self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/CreateEndpoint", + request_serializer=endpoint_service.CreateEndpointRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["create_endpoint"] + + @property + def get_endpoint( + self, + ) -> Callable[[endpoint_service.GetEndpointRequest], Awaitable[endpoint.Endpoint]]: + r"""Return a callable for the get endpoint method over gRPC. + + Gets an Endpoint. + + Returns: + Callable[[~.GetEndpointRequest], + Awaitable[~.Endpoint]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_endpoint" not in self._stubs: + self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/GetEndpoint", + request_serializer=endpoint_service.GetEndpointRequest.serialize, + response_deserializer=endpoint.Endpoint.deserialize, + ) + return self._stubs["get_endpoint"] + + @property + def list_endpoints( + self, + ) -> Callable[ + [endpoint_service.ListEndpointsRequest], + Awaitable[endpoint_service.ListEndpointsResponse], + ]: + r"""Return a callable for the list endpoints method over gRPC. + + Lists Endpoints in a Location. + + Returns: + Callable[[~.ListEndpointsRequest], + Awaitable[~.ListEndpointsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_endpoints" not in self._stubs: + self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/ListEndpoints", + request_serializer=endpoint_service.ListEndpointsRequest.serialize, + response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, + ) + return self._stubs["list_endpoints"] + + @property + def update_endpoint( + self, + ) -> Callable[ + [endpoint_service.UpdateEndpointRequest], Awaitable[gca_endpoint.Endpoint] + ]: + r"""Return a callable for the update endpoint method over gRPC. + + Updates an Endpoint. + + Returns: + Callable[[~.UpdateEndpointRequest], + Awaitable[~.Endpoint]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_endpoint" not in self._stubs: + self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/UpdateEndpoint", + request_serializer=endpoint_service.UpdateEndpointRequest.serialize, + response_deserializer=gca_endpoint.Endpoint.deserialize, + ) + return self._stubs["update_endpoint"] + + @property + def delete_endpoint( + self, + ) -> Callable[ + [endpoint_service.DeleteEndpointRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the delete endpoint method over gRPC. + + Deletes an Endpoint. 
+ + Returns: + Callable[[~.DeleteEndpointRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_endpoint" not in self._stubs: + self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/DeleteEndpoint", + request_serializer=endpoint_service.DeleteEndpointRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_endpoint"] + + @property + def deploy_model( + self, + ) -> Callable[ + [endpoint_service.DeployModelRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the deploy model method over gRPC. + + Deploys a Model into this Endpoint, creating a + DeployedModel within it. + + Returns: + Callable[[~.DeployModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "deploy_model" not in self._stubs: + self._stubs["deploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/DeployModel", + request_serializer=endpoint_service.DeployModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["deploy_model"] + + @property + def undeploy_model( + self, + ) -> Callable[ + [endpoint_service.UndeployModelRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the undeploy model method over gRPC. + + Undeploys a Model from an Endpoint, removing a + DeployedModel from it, and freeing all resources it's + using. + + Returns: + Callable[[~.UndeployModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "undeploy_model" not in self._stubs: + self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.EndpointService/UndeployModel", + request_serializer=endpoint_service.UndeployModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["undeploy_model"] + + +__all__ = ("EndpointServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py index bf1248d281..5f157047f5 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/__init__.py @@ -16,5 +16,9 @@ # from .client import JobServiceClient +from .async_client import JobServiceAsyncClient -__all__ = ("JobServiceClient",) +__all__ = ( + "JobServiceClient", + "JobServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py new file mode 100644 index 0000000000..d988c81d3c --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -0,0 +1,1860 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.job_service import pagers +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) +from google.cloud.aiplatform_v1beta1.types import completion_stats +from google.cloud.aiplatform_v1beta1.types import custom_job +from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job +from google.cloud.aiplatform_v1beta1.types import data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) +from google.cloud.aiplatform_v1beta1.types import job_service +from google.cloud.aiplatform_v1beta1.types import job_state +from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import study +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore +from google.type import money_pb2 as money # type: ignore + +from .transports.base import JobServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import JobServiceGrpcAsyncIOTransport +from .client import JobServiceClient + + +class JobServiceAsyncClient: + """A service for creating and managing AI Platform's jobs.""" + + _client: JobServiceClient + + DEFAULT_ENDPOINT = JobServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = JobServiceClient.DEFAULT_MTLS_ENDPOINT + + batch_prediction_job_path = staticmethod(JobServiceClient.batch_prediction_job_path) + parse_batch_prediction_job_path = staticmethod( + JobServiceClient.parse_batch_prediction_job_path + ) + custom_job_path = staticmethod(JobServiceClient.custom_job_path) + parse_custom_job_path = staticmethod(JobServiceClient.parse_custom_job_path) + data_labeling_job_path = staticmethod(JobServiceClient.data_labeling_job_path) + parse_data_labeling_job_path = staticmethod( + JobServiceClient.parse_data_labeling_job_path + ) + dataset_path = staticmethod(JobServiceClient.dataset_path) + parse_dataset_path = staticmethod(JobServiceClient.parse_dataset_path) + hyperparameter_tuning_job_path = staticmethod( + JobServiceClient.hyperparameter_tuning_job_path + ) + parse_hyperparameter_tuning_job_path = staticmethod( + JobServiceClient.parse_hyperparameter_tuning_job_path + ) 
+ model_path = staticmethod(JobServiceClient.model_path) + parse_model_path = staticmethod(JobServiceClient.parse_model_path) + + common_billing_account_path = staticmethod( + JobServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + JobServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(JobServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(JobServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(JobServiceClient.common_organization_path) + parse_common_organization_path = staticmethod( + JobServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(JobServiceClient.common_project_path) + parse_common_project_path = staticmethod(JobServiceClient.parse_common_project_path) + + common_location_path = staticmethod(JobServiceClient.common_location_path) + parse_common_location_path = staticmethod( + JobServiceClient.parse_common_location_path + ) + + from_service_account_file = JobServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> JobServiceTransport: + """Return the transport used by the client instance. + + Returns: + JobServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(JobServiceClient).get_transport_class, type(JobServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, JobServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the job service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.JobServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. 
+ """ + + self._client = JobServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def create_custom_job( + self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: + r"""Creates a CustomJob. A created CustomJob right away + will be attempted to be run. + + Args: + request (:class:`~.job_service.CreateCustomJobRequest`): + The request object. Request message for + ``JobService.CreateCustomJob``. + parent (:class:`str`): + Required. The resource name of the Location to create + the CustomJob in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + custom_job (:class:`~.gca_custom_job.CustomJob`): + Required. The CustomJob to create. + This corresponds to the ``custom_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_custom_job.CustomJob: + Represents a job that runs custom + workloads such as a Docker container or + a Python package. A CustomJob can have + multiple worker pools and each worker + pool can have its own machine and input + spec. A CustomJob will be cleaned up + once the job enters terminal state + (failed or succeeded). + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, custom_job]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.CreateCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if custom_job is not None: + request.custom_job = custom_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_custom_job( + self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: + r"""Gets a CustomJob. + + Args: + request (:class:`~.job_service.GetCustomJobRequest`): + The request object. Request message for + ``JobService.GetCustomJob``. 
+ name (:class:`str`): + Required. The name of the CustomJob resource. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.custom_job.CustomJob: + Represents a job that runs custom + workloads such as a Docker container or + a Python package. A CustomJob can have + multiple worker pools and each worker + pool can have its own machine and input + spec. A CustomJob will be cleaned up + once the job enters terminal state + (failed or succeeded). + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.GetCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_custom_jobs( + self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsAsyncPager: + r"""Lists CustomJobs in a Location. + + Args: + request (:class:`~.job_service.ListCustomJobsRequest`): + The request object. Request message for + ``JobService.ListCustomJobs``. + parent (:class:`str`): + Required. The resource name of the Location to list the + CustomJobs from. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListCustomJobsAsyncPager: + Response message for + ``JobService.ListCustomJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
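+        # Illustrative sketch (not generated code; project and location values
+        # are assumed): callers pass either the flattened field or a request
+        # object, never both. Both of these calls are valid:
+        #   await client.list_custom_jobs(
+        #       parent="projects/my-project/locations/us-central1")
+        #   await client.list_custom_jobs(
+        #       request=job_service.ListCustomJobsRequest(
+        #           parent="projects/my-project/locations/us-central1"))
+        # Mixing the two styles raises the ValueError below.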
+ has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.ListCustomJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_custom_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListCustomJobsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_custom_job( + self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a CustomJob. + + Args: + request (:class:`~.job_service.DeleteCustomJobRequest`): + The request object. Request message for + ``JobService.DeleteCustomJob``. + name (:class:`str`): + Required. The name of the CustomJob resource to be + deleted. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.DeleteCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def cancel_custom_job( + self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a CustomJob. Starts asynchronous cancellation on the + CustomJob. The server makes a best effort to cancel the job, but + success is not guaranteed. Clients can use + ``JobService.GetCustomJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the CustomJob is not deleted; instead it becomes a + job with a + ``CustomJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``CustomJob.state`` + is set to ``CANCELLED``. + + Args: + request (:class:`~.job_service.CancelCustomJobRequest`): + The request object. Request message for + ``JobService.CancelCustomJob``. + name (:class:`str`): + Required. The name of the CustomJob to cancel. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.CancelCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ await rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def create_data_labeling_job( + self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: + r"""Creates a DataLabelingJob. + + Args: + request (:class:`~.job_service.CreateDataLabelingJobRequest`): + The request object. Request message for + [DataLabelingJobService.CreateDataLabelingJob][]. + parent (:class:`str`): + Required. The parent of the DataLabelingJob. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + data_labeling_job (:class:`~.gca_data_labeling_job.DataLabelingJob`): + Required. The DataLabelingJob to + create. + This corresponds to the ``data_labeling_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_data_labeling_job.DataLabelingJob: + DataLabelingJob is used to trigger a + human labeling job on unlabeled data + from the following Dataset: + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, data_labeling_job]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.CreateDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if data_labeling_job is not None: + request.data_labeling_job = data_labeling_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_data_labeling_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_data_labeling_job( + self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: + r"""Gets a DataLabelingJob. + + Args: + request (:class:`~.job_service.GetDataLabelingJobRequest`): + The request object. Request message for + [DataLabelingJobService.GetDataLabelingJob][]. + name (:class:`str`): + Required. The name of the DataLabelingJob. 
Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.data_labeling_job.DataLabelingJob: + DataLabelingJob is used to trigger a + human labeling job on unlabeled data + from the following Dataset: + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.GetDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_data_labeling_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_data_labeling_jobs( + self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsAsyncPager: + r"""Lists DataLabelingJobs in a Location. + + Args: + request (:class:`~.job_service.ListDataLabelingJobsRequest`): + The request object. Request message for + [DataLabelingJobService.ListDataLabelingJobs][]. + parent (:class:`str`): + Required. The parent of the DataLabelingJob. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListDataLabelingJobsAsyncPager: + Response message for + ``JobService.ListDataLabelingJobs``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) + + request = job_service.ListDataLabelingJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_data_labeling_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListDataLabelingJobsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_data_labeling_job( + self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a DataLabelingJob. + + Args: + request (:class:`~.job_service.DeleteDataLabelingJobRequest`): + The request object. Request message for + ``JobService.DeleteDataLabelingJob``. + name (:class:`str`): + Required. The name of the DataLabelingJob to be deleted. + Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.DeleteDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_data_labeling_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def cancel_data_labeling_job( + self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a DataLabelingJob. Success of cancellation is + not guaranteed. + + Args: + request (:class:`~.job_service.CancelDataLabelingJobRequest`): + The request object. Request message for + [DataLabelingJobService.CancelDataLabelingJob][]. + name (:class:`str`): + Required. The name of the DataLabelingJob. Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.CancelDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_data_labeling_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def create_hyperparameter_tuning_job( + self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + r"""Creates a HyperparameterTuningJob + + Args: + request (:class:`~.job_service.CreateHyperparameterTuningJobRequest`): + The request object. 
Request message for + ``JobService.CreateHyperparameterTuningJob``. + parent (:class:`str`): + Required. The resource name of the Location to create + the HyperparameterTuningJob in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + hyperparameter_tuning_job (:class:`~.gca_hyperparameter_tuning_job.HyperparameterTuningJob`): + Required. The HyperparameterTuningJob + to create. + This corresponds to the ``hyperparameter_tuning_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_hyperparameter_tuning_job.HyperparameterTuningJob: + Represents a HyperparameterTuningJob. + A HyperparameterTuningJob has a Study + specification and multiple CustomJobs + with identical CustomJob specification. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, hyperparameter_tuning_job]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.CreateHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if hyperparameter_tuning_job is not None: + request.hyperparameter_tuning_job = hyperparameter_tuning_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_hyperparameter_tuning_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_hyperparameter_tuning_job( + self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + r"""Gets a HyperparameterTuningJob + + Args: + request (:class:`~.job_service.GetHyperparameterTuningJobRequest`): + The request object. Request message for + ``JobService.GetHyperparameterTuningJob``. + name (:class:`str`): + Required. The name of the HyperparameterTuningJob + resource. Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.hyperparameter_tuning_job.HyperparameterTuningJob: + Represents a HyperparameterTuningJob. + A HyperparameterTuningJob has a Study + specification and multiple CustomJobs + with identical CustomJob specification. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.GetHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_hyperparameter_tuning_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_hyperparameter_tuning_jobs( + self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsAsyncPager: + r"""Lists HyperparameterTuningJobs in a Location. + + Args: + request (:class:`~.job_service.ListHyperparameterTuningJobsRequest`): + The request object. Request message for + ``JobService.ListHyperparameterTuningJobs``. + parent (:class:`str`): + Required. The resource name of the Location to list the + HyperparameterTuningJobs from. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListHyperparameterTuningJobsAsyncPager: + Response message for + ``JobService.ListHyperparameterTuningJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.ListHyperparameterTuningJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
+ + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_hyperparameter_tuning_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListHyperparameterTuningJobsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_hyperparameter_tuning_job( + self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a HyperparameterTuningJob. + + Args: + request (:class:`~.job_service.DeleteHyperparameterTuningJobRequest`): + The request object. Request message for + ``JobService.DeleteHyperparameterTuningJob``. + name (:class:`str`): + Required. The name of the HyperparameterTuningJob + resource to be deleted. Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.DeleteHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_hyperparameter_tuning_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. 
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def cancel_hyperparameter_tuning_job( + self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a HyperparameterTuningJob. Starts asynchronous + cancellation on the HyperparameterTuningJob. The server makes a + best effort to cancel the job, but success is not guaranteed. + Clients can use + ``JobService.GetHyperparameterTuningJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the HyperparameterTuningJob is not deleted; + instead it becomes a job with a + ``HyperparameterTuningJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``HyperparameterTuningJob.state`` + is set to ``CANCELLED``. + + Args: + request (:class:`~.job_service.CancelHyperparameterTuningJobRequest`): + The request object. Request message for + ``JobService.CancelHyperparameterTuningJob``. + name (:class:`str`): + Required. The name of the HyperparameterTuningJob to + cancel. Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.CancelHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_hyperparameter_tuning_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+        await rpc(
+            request, retry=retry, timeout=timeout, metadata=metadata,
+        )
+
+    async def create_batch_prediction_job(
+        self,
+        request: job_service.CreateBatchPredictionJobRequest = None,
+        *,
+        parent: str = None,
+        batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_batch_prediction_job.BatchPredictionJob:
+        r"""Creates a BatchPredictionJob. The service immediately
+        attempts to start a newly created BatchPredictionJob.
+
+        Args:
+            request (:class:`~.job_service.CreateBatchPredictionJobRequest`):
+                The request object. Request message for
+                ``JobService.CreateBatchPredictionJob``.
+            parent (:class:`str`):
+                Required. The resource name of the Location to create
+                the BatchPredictionJob in. Format:
+                ``projects/{project}/locations/{location}``
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            batch_prediction_job (:class:`~.gca_batch_prediction_job.BatchPredictionJob`):
+                Required. The BatchPredictionJob to
+                create.
+                This corresponds to the ``batch_prediction_job`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.gca_batch_prediction_job.BatchPredictionJob:
+                A job that uses a
+                ``Model``
+                to produce predictions on multiple [input
+                instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config].
+                If predictions for a significant portion of the instances
+                fail, the job may finish without attempting predictions
+                for all remaining instances.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([parent, batch_prediction_job])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = job_service.CreateBatchPredictionJobRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if parent is not None:
+            request.parent = parent
+        if batch_prediction_job is not None:
+            request.batch_prediction_job = batch_prediction_job
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.create_batch_prediction_job,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+        )
+
+        # Send the request.
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Done; return the response.
+        return response
+
+    async def get_batch_prediction_job(
+        self,
+        request: job_service.GetBatchPredictionJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> batch_prediction_job.BatchPredictionJob:
+        r"""Gets a BatchPredictionJob.
+
+        Args:
+            request (:class:`~.job_service.GetBatchPredictionJobRequest`):
+                The request object. Request message for
+                ``JobService.GetBatchPredictionJob``.
+            name (:class:`str`):
+                Required. The name of the BatchPredictionJob resource.
+                Format:
+
+                ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}``
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.batch_prediction_job.BatchPredictionJob:
+                A job that uses a
+                ``Model``
+                to produce predictions on multiple [input
+                instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config].
+                If predictions for a significant portion of the instances
+                fail, the job may finish without attempting predictions
+                for all remaining instances.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([name])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = job_service.GetBatchPredictionJobRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if name is not None:
+            request.name = name
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.get_batch_prediction_job,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Done; return the response.
+        return response
+
+    async def list_batch_prediction_jobs(
+        self,
+        request: job_service.ListBatchPredictionJobsRequest = None,
+        *,
+        parent: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> pagers.ListBatchPredictionJobsAsyncPager:
+        r"""Lists BatchPredictionJobs in a Location.
+
+        Args:
+            request (:class:`~.job_service.ListBatchPredictionJobsRequest`):
+                The request object. Request message for
+                ``JobService.ListBatchPredictionJobs``.
+            parent (:class:`str`):
+                Required. The resource name of the Location to list the
+                BatchPredictionJobs from. Format:
+                ``projects/{project}/locations/{location}``
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.pagers.ListBatchPredictionJobsAsyncPager:
+                Response message for
+                ``JobService.ListBatchPredictionJobs``
+
+                Iterating over this object will yield results and
+                resolve additional pages automatically.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([parent])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = job_service.ListBatchPredictionJobsRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if parent is not None:
+            request.parent = parent
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.list_batch_prediction_jobs,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+        )
+
+        # Send the request.
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__aiter__` convenience method.
+        response = pagers.ListBatchPredictionJobsAsyncPager(
+            method=rpc, request=request, response=response, metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def delete_batch_prediction_job(
+        self,
+        request: job_service.DeleteBatchPredictionJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operation_async.AsyncOperation:
+        r"""Deletes a BatchPredictionJob. Can only be called on
+        jobs that have already finished.
+
+        Args:
+            request (:class:`~.job_service.DeleteBatchPredictionJobRequest`):
+                The request object. Request message for
+                ``JobService.DeleteBatchPredictionJob``.
+            name (:class:`str`):
+                Required. The name of the BatchPredictionJob resource to
+                be deleted. Format:
+
+                ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}``
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.operation_async.AsyncOperation:
+                An object representing a long-running operation.
+
+                The result type for the operation will be
+                :class:`~.empty.Empty`: A generic empty message that
+                you can re-use to avoid defining duplicated empty
+                messages in your APIs. A typical example is to use it as
+                the request or the response type of an API method. For
+                instance:
+
+                ::
+
+                    service Foo {
+                        rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty);
+                    }
+
+                The JSON representation for ``Empty`` is empty JSON
+                object ``{}``.
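+
+        Example:
+            An illustrative sketch rather than generated output; assumes
+            ``client`` is an instantiated ``JobServiceAsyncClient`` and the
+            job has already finished:
+
+            .. code-block:: python
+
+                operation = await client.delete_batch_prediction_job(
+                    name="projects/my-project/locations/us-central1/batchPredictionJobs/123",
+                )
+                await operation.result()  # waits until the deletion completes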
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([name])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = job_service.DeleteBatchPredictionJobRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if name is not None:
+            request.name = name
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.delete_batch_prediction_job,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Wrap the response in an operation future.
+        response = operation_async.from_gapic(
+            response,
+            self._client._transport.operations_client,
+            empty.Empty,
+            metadata_type=gca_operation.DeleteOperationMetadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def cancel_batch_prediction_job(
+        self,
+        request: job_service.CancelBatchPredictionJobRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> None:
+        r"""Cancels a BatchPredictionJob.
+
+        Starts asynchronous cancellation on the BatchPredictionJob. The
+        server makes a best effort to cancel the job, but success is
+        not guaranteed. Clients can use
+        ``JobService.GetBatchPredictionJob``
+        or other methods to check whether the cancellation succeeded or
+        whether the job completed despite cancellation. On a successful
+        cancellation, the BatchPredictionJob is not deleted; instead its
+        ``BatchPredictionJob.state``
+        is set to ``CANCELLED``. Any files already outputted by the job
+        are not deleted.
+
+        Args:
+            request (:class:`~.job_service.CancelBatchPredictionJobRequest`):
+                The request object. Request message for
+                ``JobService.CancelBatchPredictionJob``.
+            name (:class:`str`):
+                Required. The name of the BatchPredictionJob to cancel.
+                Format:
+
+                ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}``
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([name])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = job_service.CancelBatchPredictionJobRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+ + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_batch_prediction_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("JobServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index b56a9a7871..a1eb7c38ce 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -16,17 +16,24 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.job_service import pagers from google.cloud.aiplatform_v1beta1.types import batch_prediction_job from google.cloud.aiplatform_v1beta1.types import ( @@ -39,6 +46,7 @@ from google.cloud.aiplatform_v1beta1.types import ( data_labeling_job as gca_data_labeling_job, ) +from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import ( hyperparameter_tuning_job as gca_hyperparameter_tuning_job, @@ -55,8 +63,9 @@ from google.rpc import status_pb2 as status # type: ignore from google.type import money_pb2 as money # type: ignore -from .transports.base import JobServiceTransport +from .transports.base import JobServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import JobServiceGrpcTransport +from .transports.grpc_asyncio import JobServiceGrpcAsyncIOTransport class JobServiceClientMeta(type): @@ -69,6 +78,7 @@ class JobServiceClientMeta(type): _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] _transport_registry["grpc"] = JobServiceGrpcTransport + _transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport def get_transport_class(cls, 
label: str = None,) -> Type[JobServiceTransport]:
         """Return an appropriate transport class.
@@ -92,8 +102,38 @@ def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]:
 class JobServiceClient(metaclass=JobServiceClientMeta):
     """A service for creating and managing AI Platform's jobs."""
 
-    DEFAULT_OPTIONS = ClientOptions.ClientOptions(
-        api_endpoint="aiplatform.googleapis.com"
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
     )
 
     @classmethod
@@ -116,12 +156,14 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
 
     from_service_account_json = from_service_account_file
 
-    @staticmethod
-    def custom_job_path(project: str, location: str, custom_job: str,) -> str:
-        """Return a fully-qualified custom_job string."""
-        return "projects/{project}/locations/{location}/customJobs/{custom_job}".format(
-            project=project, location=location, custom_job=custom_job,
-        )
+    @property
+    def transport(self) -> JobServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            JobServiceTransport: The transport used by the client instance.
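+
+        Example (an illustrative sketch; assumes the client was built with
+        the default gRPC transport, which exposes a ``grpc_channel``
+        property)::
+
+            channel = client.transport.grpc_channel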
+ """ + return self._transport @staticmethod def batch_prediction_job_path( @@ -134,6 +176,65 @@ def batch_prediction_job_path( batch_prediction_job=batch_prediction_job, ) + @staticmethod + def parse_batch_prediction_job_path(path: str) -> Dict[str, str]: + """Parse a batch_prediction_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def custom_job_path(project: str, location: str, custom_job: str,) -> str: + """Return a fully-qualified custom_job string.""" + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) + + @staticmethod + def parse_custom_job_path(path: str) -> Dict[str, str]: + """Parse a custom_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def data_labeling_job_path( + project: str, location: str, data_labeling_job: str, + ) -> str: + """Return a fully-qualified data_labeling_job string.""" + return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, location=location, data_labeling_job=data_labeling_job, + ) + + @staticmethod + def parse_data_labeling_job_path(path: str) -> Dict[str, str]: + """Parse a data_labeling_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def dataset_path(project: str, location: str, dataset: str,) -> str: + """Return a fully-qualified dataset string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) + + @staticmethod + def parse_dataset_path(path: str) -> Dict[str, str]: + """Parse a dataset path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def hyperparameter_tuning_job_path( project: str, location: str, hyperparameter_tuning_job: str, @@ -146,20 +247,96 @@ def hyperparameter_tuning_job_path( ) @staticmethod - def data_labeling_job_path( - project: str, location: str, data_labeling_job: str, - ) -> str: - """Return a fully-qualified data_labeling_job string.""" - return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( - project=project, location=location, data_labeling_job=data_labeling_job, + def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str, str]: + """Parse a hyperparameter_tuning_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def model_path(project: str, location: str, model: str,) -> str: + """Return a fully-qualified model string.""" + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str, str]: + """Parse a model path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def 
common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, ) + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + def __init__( self, *, - credentials: credentials.Credentials = None, - transport: Union[str, JobServiceTransport] = None, - client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the job service client. @@ -172,26 +349,102 @@ def __init__( transport (Union[str, ~.JobServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. 
+ (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, JobServiceTransport): - if credentials: + # transport is a JobServiceTransport instance. + if credentials or client_options.credentials_file: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def create_custom_job( @@ -245,28 +498,36 @@ def create_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
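Stepping back from the hunks for a moment: the ``__init__`` rework above is driven entirely by ``client_options`` plus the GOOGLE_API_USE_CLIENT_CERTIFICATE and GOOGLE_API_USE_MTLS_ENDPOINT environment variables. A minimal sketch of how a caller opts into mutual TLS, assuming local PEM files and the dict form of ``client_options`` (both are illustrative assumptions, not part of this patch):

    import os

    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient

    os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "true"

    def client_cert_source():
        # Assumed local files; must return (certificate_chain, private_key) bytes.
        with open("client_cert.pem", "rb") as f:
            cert = f.read()
        with open("client_key.pem", "rb") as f:
            key = f.read()
        return cert, key

    # With a certificate available and GOOGLE_API_USE_MTLS_ENDPOINT left at
    # "auto", the resolution logic above selects DEFAULT_MTLS_ENDPOINT.
    client = JobServiceClient(client_options={"client_cert_source": client_cert_source})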
- if request is not None and any([parent, custom_job]): + has_flattened_params = any([parent, custom_job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.CreateCustomJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateCustomJobRequest): + request = job_service.CreateCustomJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if custom_job is not None: - request.custom_job = custom_job + if parent is not None: + request.parent = parent + if custom_job is not None: + request.custom_job = custom_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_custom_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_custom_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -318,27 +579,29 @@ def get_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.GetCustomJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetCustomJobRequest): + request = job_service.GetCustomJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_custom_job, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_custom_job] # Certain fields should be provided within the metadata header; # add these here. @@ -393,27 +656,29 @@ def list_custom_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
) - request = job_service.ListCustomJobsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListCustomJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListCustomJobsRequest): + request = job_service.ListCustomJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_custom_jobs, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_custom_jobs] # Certain fields should be provided within the metadata header; # add these here. @@ -427,7 +692,7 @@ def list_custom_jobs( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListCustomJobsPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -486,26 +751,34 @@ def delete_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.DeleteCustomJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteCustomJobRequest): + request = job_service.DeleteCustomJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_custom_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_custom_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -565,26 +838,34 @@ def cancel_custom_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
) - request = job_service.CancelCustomJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelCustomJobRequest): + request = job_service.CancelCustomJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.cancel_custom_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.cancel_custom_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -637,28 +918,36 @@ def create_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, data_labeling_job]): + has_flattened_params = any([parent, data_labeling_job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.CreateDataLabelingJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateDataLabelingJobRequest): + request = job_service.CreateDataLabelingJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if data_labeling_job is not None: - request.data_labeling_job = data_labeling_job + if parent is not None: + request.parent = parent + if data_labeling_job is not None: + request.data_labeling_job = data_labeling_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_data_labeling_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_data_labeling_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -706,27 +995,29 @@ def get_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
) - request = job_service.GetDataLabelingJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetDataLabelingJobRequest): + request = job_service.GetDataLabelingJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_data_labeling_job, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_data_labeling_job] # Certain fields should be provided within the metadata header; # add these here. @@ -780,27 +1071,29 @@ def list_data_labeling_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.ListDataLabelingJobsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListDataLabelingJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListDataLabelingJobsRequest): + request = job_service.ListDataLabelingJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_data_labeling_jobs, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_data_labeling_jobs] # Certain fields should be provided within the metadata header; # add these here. @@ -814,7 +1107,7 @@ def list_data_labeling_jobs( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListDataLabelingJobsPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -874,26 +1167,34 @@ def delete_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
) - request = job_service.DeleteDataLabelingJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteDataLabelingJobRequest): + request = job_service.DeleteDataLabelingJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_data_labeling_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_data_labeling_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -943,26 +1244,34 @@ def cancel_data_labeling_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.CancelDataLabelingJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelDataLabelingJobRequest): + request = job_service.CancelDataLabelingJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.cancel_data_labeling_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.cancel_data_labeling_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -1017,28 +1326,38 @@ def create_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, hyperparameter_tuning_job]): + has_flattened_params = any([parent, hyperparameter_tuning_job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." 
) - request = job_service.CreateHyperparameterTuningJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateHyperparameterTuningJobRequest): + request = job_service.CreateHyperparameterTuningJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if hyperparameter_tuning_job is not None: - request.hyperparameter_tuning_job = hyperparameter_tuning_job + if parent is not None: + request.parent = parent + if hyperparameter_tuning_job is not None: + request.hyperparameter_tuning_job = hyperparameter_tuning_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_hyperparameter_tuning_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[ + self._transport.create_hyperparameter_tuning_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -1088,27 +1407,31 @@ def get_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.GetHyperparameterTuningJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetHyperparameterTuningJobRequest): + request = job_service.GetHyperparameterTuningJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_hyperparameter_tuning_job, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[ + self._transport.get_hyperparameter_tuning_job + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1163,27 +1486,31 @@ def list_hyperparameter_tuning_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. 
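Every method in this file repeats the guard above, so one worked example of the convention it enforces may help (a sketch assuming application default credentials; the project and location IDs are made up):

    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
    from google.cloud.aiplatform_v1beta1.types import job_service

    client = JobServiceClient()
    parent = JobServiceClient.common_location_path("my-project", "us-central1")

    # Either flattened fields...
    client.list_custom_jobs(parent=parent)

    # ...or a fully-formed request object...
    client.list_custom_jobs(request=job_service.ListCustomJobsRequest(parent=parent))

    # ...but supplying both trips the ValueError shown in these hunks.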
- if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.ListHyperparameterTuningJobsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListHyperparameterTuningJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListHyperparameterTuningJobsRequest): + request = job_service.ListHyperparameterTuningJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_hyperparameter_tuning_jobs, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[ + self._transport.list_hyperparameter_tuning_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1197,7 +1524,7 @@ def list_hyperparameter_tuning_jobs( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListHyperparameterTuningJobsPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -1257,26 +1584,36 @@ def delete_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.DeleteHyperparameterTuningJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteHyperparameterTuningJobRequest): + request = job_service.DeleteHyperparameterTuningJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_hyperparameter_tuning_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[ + self._transport.delete_hyperparameter_tuning_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. 
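A note between hunks: the delete methods return long-running operations rather than completing inline, so callers typically block on the result. A hedged usage sketch (the timeout value and resource ID are arbitrary assumptions):

    # delete_* returns a google.api_core.operation.Operation wrapping Empty;
    # result() waits for server-side completion or raises on failure.
    operation = client.delete_hyperparameter_tuning_job(
        name="projects/my-project/locations/us-central1/hyperparameterTuningJobs/123"
    )
    operation.result(timeout=300)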
@@ -1339,26 +1676,36 @@ def cancel_hyperparameter_tuning_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.CancelHyperparameterTuningJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelHyperparameterTuningJobRequest): + request = job_service.CancelHyperparameterTuningJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.cancel_hyperparameter_tuning_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[ + self._transport.cancel_hyperparameter_tuning_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -1417,28 +1764,38 @@ def create_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, batch_prediction_job]): + has_flattened_params = any([parent, batch_prediction_job]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.CreateBatchPredictionJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateBatchPredictionJobRequest): + request = job_service.CreateBatchPredictionJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if batch_prediction_job is not None: - request.batch_prediction_job = batch_prediction_job + if parent is not None: + request.parent = parent + if batch_prediction_job is not None: + request.batch_prediction_job = batch_prediction_job # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
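The recurring switch from per-call ``gapic_v1.method.wrap_method`` to the transport's ``_wrapped_methods`` table means each RPC is now wrapped exactly once, when the transport is constructed. Roughly, the table looks like the sketch below (a paraphrase of the generated pattern, not an excerpt from this patch):

    # Built once in the transport: raw gRPC callables keyed to their wrapped
    # forms, each carrying retry/timeout defaults and the shared client_info.
    self._wrapped_methods = {
        self.create_custom_job: gapic_v1.method.wrap_method(
            self.create_custom_job, default_timeout=None, client_info=client_info,
        ),
        self.get_custom_job: gapic_v1.method.wrap_method(
            self.get_custom_job, default_timeout=None, client_info=client_info,
        ),
        # ...one entry per RPC on the service.
    }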
- rpc = gapic_v1.method.wrap_method( - self._transport.create_batch_prediction_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[ + self._transport.create_batch_prediction_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -1491,27 +1848,29 @@ def get_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.GetBatchPredictionJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetBatchPredictionJobRequest): + request = job_service.GetBatchPredictionJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_batch_prediction_job, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_batch_prediction_job] # Certain fields should be provided within the metadata header; # add these here. @@ -1566,27 +1925,31 @@ def list_batch_prediction_jobs( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.ListBatchPredictionJobsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListBatchPredictionJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListBatchPredictionJobsRequest): + request = job_service.ListBatchPredictionJobsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
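For context on the ``routing_header`` lines now appearing in every method: they emit the ``x-goog-request-params`` metadata entry that the backend uses for request routing. Roughly (the percent-encoded output shown is an assumption based on ``urllib``-style encoding):

    from google.api_core import gapic_v1

    md = gapic_v1.routing_header.to_grpc_metadata(
        (("parent", "projects/my-project/locations/us-central1"),)
    )
    # Expected shape:
    # ("x-goog-request-params", "parent=projects%2Fmy-project%2Flocations%2Fus-central1")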
- rpc = gapic_v1.method.wrap_method( - self._transport.list_batch_prediction_jobs, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[ + self._transport.list_batch_prediction_jobs + ] # Certain fields should be provided within the metadata header; # add these here. @@ -1600,7 +1963,7 @@ def list_batch_prediction_jobs( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListBatchPredictionJobsPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -1661,26 +2024,36 @@ def delete_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.DeleteBatchPredictionJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteBatchPredictionJobRequest): + request = job_service.DeleteBatchPredictionJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_batch_prediction_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[ + self._transport.delete_batch_prediction_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -1741,26 +2114,36 @@ def cancel_batch_prediction_job( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = job_service.CancelBatchPredictionJobRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelBatchPredictionJobRequest): + request = job_service.CancelBatchPredictionJobRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
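The ``metadata=metadata`` additions to the pager constructors (here and throughout ``pagers.py`` below) ensure the routing header is re-sent on every follow-up page request; iteration itself is unchanged for callers. An illustrative loop, reusing the client and parent from the sketch above:

    # The pager issues additional ListBatchPredictionJobs calls lazily,
    # re-attaching the same metadata, as the caller iterates.
    for job in client.list_batch_prediction_jobs(parent=parent):
        print(job.name, job.state)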
- if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.cancel_batch_prediction_job, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[ + self._transport.cancel_batch_prediction_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -1770,13 +2153,13 @@ def cancel_batch_prediction_job( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("JobServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py index 0dd74763fa..05e5be73ca 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import batch_prediction_job from google.cloud.aiplatform_v1beta1.types import custom_job @@ -44,11 +44,11 @@ class ListCustomJobsPager: def __init__( self, - method: Callable[ - [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse - ], + method: Callable[..., job_service.ListCustomJobsResponse], request: job_service.ListCustomJobsRequest, response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -59,10 +59,13 @@ def __init__( The initial request object. response (:class:`~.job_service.ListCustomJobsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = job_service.ListCustomJobsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -72,7 +75,7 @@ def pages(self) -> Iterable[job_service.ListCustomJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[custom_job.CustomJob]: @@ -83,6 +86,72 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) +class ListCustomJobsAsyncPager: + """A pager for iterating through ``list_custom_jobs`` requests. + + This class thinly wraps an initial + :class:`~.job_service.ListCustomJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``custom_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListCustomJobs`` requests and continue to iterate + through the ``custom_jobs`` field on the + corresponding responses. 
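Since every pager in this file now gains an async twin, one end-to-end sketch of the async surface may be useful (assumes default credentials and that ``JobServiceAsyncClient`` is exported alongside the sync client):

    import asyncio

    from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceAsyncClient

    async def main():
        client = JobServiceAsyncClient()  # selects the grpc_asyncio transport
        pager = await client.list_custom_jobs(
            parent="projects/my-project/locations/us-central1"
        )
        # __aiter__ drives the follow-up page requests shown above.
        async for job in pager:
            print(job.display_name)

    asyncio.run(main())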
+ + All the usual :class:`~.job_service.ListCustomJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.job_service.ListCustomJobsRequest`): + The initial request object. + response (:class:`~.job_service.ListCustomJobsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListCustomJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.ListCustomJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[custom_job.CustomJob]: + async def async_generator(): + async for page in self.pages: + for response in page.custom_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + class ListDataLabelingJobsPager: """A pager for iterating through ``list_data_labeling_jobs`` requests. @@ -103,12 +172,11 @@ class ListDataLabelingJobsPager: def __init__( self, - method: Callable[ - [job_service.ListDataLabelingJobsRequest], - job_service.ListDataLabelingJobsResponse, - ], + method: Callable[..., job_service.ListDataLabelingJobsResponse], request: job_service.ListDataLabelingJobsRequest, response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -119,10 +187,13 @@ def __init__( The initial request object. response (:class:`~.job_service.ListDataLabelingJobsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = job_service.ListDataLabelingJobsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -132,7 +203,7 @@ def pages(self) -> Iterable[job_service.ListDataLabelingJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[data_labeling_job.DataLabelingJob]: @@ -143,6 +214,72 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) +class ListDataLabelingJobsAsyncPager: + """A pager for iterating through ``list_data_labeling_jobs`` requests. 
+ + This class thinly wraps an initial + :class:`~.job_service.ListDataLabelingJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``data_labeling_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListDataLabelingJobs`` requests and continue to iterate + through the ``data_labeling_jobs`` field on the + corresponding responses. + + All the usual :class:`~.job_service.ListDataLabelingJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.job_service.ListDataLabelingJobsRequest`): + The initial request object. + response (:class:`~.job_service.ListDataLabelingJobsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListDataLabelingJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.ListDataLabelingJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[data_labeling_job.DataLabelingJob]: + async def async_generator(): + async for page in self.pages: + for response in page.data_labeling_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + class ListHyperparameterTuningJobsPager: """A pager for iterating through ``list_hyperparameter_tuning_jobs`` requests. @@ -163,12 +300,11 @@ class ListHyperparameterTuningJobsPager: def __init__( self, - method: Callable[ - [job_service.ListHyperparameterTuningJobsRequest], - job_service.ListHyperparameterTuningJobsResponse, - ], + method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], request: job_service.ListHyperparameterTuningJobsRequest, response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -179,10 +315,13 @@ def __init__( The initial request object. response (:class:`~.job_service.ListHyperparameterTuningJobsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
""" self._method = method self._request = job_service.ListHyperparameterTuningJobsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -192,7 +331,7 @@ def pages(self) -> Iterable[job_service.ListHyperparameterTuningJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[hyperparameter_tuning_job.HyperparameterTuningJob]: @@ -203,6 +342,78 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) +class ListHyperparameterTuningJobsAsyncPager: + """A pager for iterating through ``list_hyperparameter_tuning_jobs`` requests. + + This class thinly wraps an initial + :class:`~.job_service.ListHyperparameterTuningJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``hyperparameter_tuning_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListHyperparameterTuningJobs`` requests and continue to iterate + through the ``hyperparameter_tuning_jobs`` field on the + corresponding responses. + + All the usual :class:`~.job_service.ListHyperparameterTuningJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[job_service.ListHyperparameterTuningJobsResponse] + ], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.job_service.ListHyperparameterTuningJobsRequest`): + The initial request object. + response (:class:`~.job_service.ListHyperparameterTuningJobsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListHyperparameterTuningJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__( + self, + ) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: + async def async_generator(): + async for page in self.pages: + for response in page.hyperparameter_tuning_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + class ListBatchPredictionJobsPager: """A pager for iterating through ``list_batch_prediction_jobs`` requests. 
@@ -223,12 +434,11 @@ class ListBatchPredictionJobsPager: def __init__( self, - method: Callable[ - [job_service.ListBatchPredictionJobsRequest], - job_service.ListBatchPredictionJobsResponse, - ], + method: Callable[..., job_service.ListBatchPredictionJobsResponse], request: job_service.ListBatchPredictionJobsRequest, response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -239,10 +449,13 @@ def __init__( The initial request object. response (:class:`~.job_service.ListBatchPredictionJobsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = job_service.ListBatchPredictionJobsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -252,7 +465,7 @@ def pages(self) -> Iterable[job_service.ListBatchPredictionJobsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: @@ -261,3 +474,69 @@ def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListBatchPredictionJobsAsyncPager: + """A pager for iterating through ``list_batch_prediction_jobs`` requests. + + This class thinly wraps an initial + :class:`~.job_service.ListBatchPredictionJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``batch_prediction_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListBatchPredictionJobs`` requests and continue to iterate + through the ``batch_prediction_jobs`` field on the + corresponding responses. + + All the usual :class:`~.job_service.ListBatchPredictionJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.job_service.ListBatchPredictionJobsRequest`): + The initial request object. + response (:class:`~.job_service.ListBatchPredictionJobsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = job_service.ListBatchPredictionJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.ListBatchPredictionJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[batch_prediction_job.BatchPredictionJob]: + async def async_generator(): + async for page in self.pages: + for response in page.batch_prediction_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py index 2f081266a0..ca4d929cb5 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py @@ -20,14 +20,17 @@ from .base import JobServiceTransport from .grpc import JobServiceGrpcTransport +from .grpc_asyncio import JobServiceGrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] _transport_registry["grpc"] = JobServiceGrpcTransport +_transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport __all__ = ( "JobServiceTransport", "JobServiceGrpcTransport", + "JobServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py index 6e11bb87ea..04c05890bc 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py @@ -17,8 +17,12 @@ import abc import typing +import pkg_resources -from google import auth +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -41,7 +45,17 @@ from google.protobuf import empty_pb2 as empty # type: ignore -class JobServiceTransport(metaclass=abc.ABCMeta): +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +class JobServiceTransport(abc.ABC): """Abstract transport class for JobService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -51,6 +65,11 @@ def __init__( *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, ) -> None: """Instantiate the transport. 
@@ -61,6 +80,17 @@ def __init__(
             credentials identify the application to the service; if none
             are specified, the client will attempt to ascertain the
             credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
         """
         # Save the hostname. Default to port 443 (HTTPS) if none is specified.
         if ":" not in host:
@@ -69,174 +99,338 @@ def __init__(

         # If no credentials are provided, then determine the appropriate
         # defaults.
-        if credentials is None:
-            credentials, _ = auth.default(scopes=self.AUTH_SCOPES)
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs(
+                "'credentials_file' and 'credentials' are mutually exclusive"
+            )
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file, scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(
+                scopes=scopes, quota_project_id=quota_project_id
+            )

         # Save the credentials.
         self._credentials = credentials

+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
+        self._wrapped_methods = {
+            self.create_custom_job: gapic_v1.method.wrap_method(
+                self.create_custom_job, default_timeout=None, client_info=client_info,
+            ),
+            self.get_custom_job: gapic_v1.method.wrap_method(
+                self.get_custom_job, default_timeout=None, client_info=client_info,
+            ),
+            self.list_custom_jobs: gapic_v1.method.wrap_method(
+                self.list_custom_jobs, default_timeout=None, client_info=client_info,
+            ),
+            self.delete_custom_job: gapic_v1.method.wrap_method(
+                self.delete_custom_job, default_timeout=None, client_info=client_info,
+            ),
+            self.cancel_custom_job: gapic_v1.method.wrap_method(
+                self.cancel_custom_job, default_timeout=None, client_info=client_info,
+            ),
+            self.create_data_labeling_job: gapic_v1.method.wrap_method(
+                self.create_data_labeling_job,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.get_data_labeling_job: gapic_v1.method.wrap_method(
+                self.get_data_labeling_job,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.list_data_labeling_jobs: gapic_v1.method.wrap_method(
+                self.list_data_labeling_jobs,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.delete_data_labeling_job: gapic_v1.method.wrap_method(
+                self.delete_data_labeling_job,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.cancel_data_labeling_job: gapic_v1.method.wrap_method(
+                self.cancel_data_labeling_job,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.create_hyperparameter_tuning_job: gapic_v1.method.wrap_method(
+                self.create_hyperparameter_tuning_job,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.get_hyperparameter_tuning_job: gapic_v1.method.wrap_method(
+                self.get_hyperparameter_tuning_job,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+
self.list_hyperparameter_tuning_jobs: gapic_v1.method.wrap_method( + self.list_hyperparameter_tuning_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_hyperparameter_tuning_job: gapic_v1.method.wrap_method( + self.delete_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_hyperparameter_tuning_job: gapic_v1.method.wrap_method( + self.cancel_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.create_batch_prediction_job: gapic_v1.method.wrap_method( + self.create_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.get_batch_prediction_job: gapic_v1.method.wrap_method( + self.get_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.list_batch_prediction_jobs: gapic_v1.method.wrap_method( + self.list_batch_prediction_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_batch_prediction_job: gapic_v1.method.wrap_method( + self.delete_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_batch_prediction_job: gapic_v1.method.wrap_method( + self.cancel_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + } + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property def create_custom_job( self, ) -> typing.Callable[ - [job_service.CreateCustomJobRequest], gca_custom_job.CustomJob + [job_service.CreateCustomJobRequest], + typing.Union[ + gca_custom_job.CustomJob, typing.Awaitable[gca_custom_job.CustomJob] + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def get_custom_job( self, - ) -> typing.Callable[[job_service.GetCustomJobRequest], custom_job.CustomJob]: - raise NotImplementedError + ) -> typing.Callable[ + [job_service.GetCustomJobRequest], + typing.Union[custom_job.CustomJob, typing.Awaitable[custom_job.CustomJob]], + ]: + raise NotImplementedError() @property def list_custom_jobs( self, ) -> typing.Callable[ - [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse + [job_service.ListCustomJobsRequest], + typing.Union[ + job_service.ListCustomJobsResponse, + typing.Awaitable[job_service.ListCustomJobsResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def delete_custom_job( self, - ) -> typing.Callable[[job_service.DeleteCustomJobRequest], operations.Operation]: - raise NotImplementedError + ) -> typing.Callable[ + [job_service.DeleteCustomJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() @property def cancel_custom_job( self, - ) -> typing.Callable[[job_service.CancelCustomJobRequest], empty.Empty]: - raise NotImplementedError + ) -> typing.Callable[ + [job_service.CancelCustomJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() @property def create_data_labeling_job( self, ) -> typing.Callable[ [job_service.CreateDataLabelingJobRequest], - gca_data_labeling_job.DataLabelingJob, + typing.Union[ + gca_data_labeling_job.DataLabelingJob, + typing.Awaitable[gca_data_labeling_job.DataLabelingJob], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def get_data_labeling_job( self, ) -> typing.Callable[ - [job_service.GetDataLabelingJobRequest], 
data_labeling_job.DataLabelingJob + [job_service.GetDataLabelingJobRequest], + typing.Union[ + data_labeling_job.DataLabelingJob, + typing.Awaitable[data_labeling_job.DataLabelingJob], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def list_data_labeling_jobs( self, ) -> typing.Callable[ [job_service.ListDataLabelingJobsRequest], - job_service.ListDataLabelingJobsResponse, + typing.Union[ + job_service.ListDataLabelingJobsResponse, + typing.Awaitable[job_service.ListDataLabelingJobsResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def delete_data_labeling_job( self, ) -> typing.Callable[ - [job_service.DeleteDataLabelingJobRequest], operations.Operation + [job_service.DeleteDataLabelingJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], ]: - raise NotImplementedError + raise NotImplementedError() @property def cancel_data_labeling_job( self, - ) -> typing.Callable[[job_service.CancelDataLabelingJobRequest], empty.Empty]: - raise NotImplementedError + ) -> typing.Callable[ + [job_service.CancelDataLabelingJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() @property def create_hyperparameter_tuning_job( self, ) -> typing.Callable[ [job_service.CreateHyperparameterTuningJobRequest], - gca_hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Union[ + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def get_hyperparameter_tuning_job( self, ) -> typing.Callable[ [job_service.GetHyperparameterTuningJobRequest], - hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Union[ + hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def list_hyperparameter_tuning_jobs( self, ) -> typing.Callable[ [job_service.ListHyperparameterTuningJobsRequest], - job_service.ListHyperparameterTuningJobsResponse, + typing.Union[ + job_service.ListHyperparameterTuningJobsResponse, + typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def delete_hyperparameter_tuning_job( self, ) -> typing.Callable[ - [job_service.DeleteHyperparameterTuningJobRequest], operations.Operation + [job_service.DeleteHyperparameterTuningJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], ]: - raise NotImplementedError + raise NotImplementedError() @property def cancel_hyperparameter_tuning_job( self, ) -> typing.Callable[ - [job_service.CancelHyperparameterTuningJobRequest], empty.Empty + [job_service.CancelHyperparameterTuningJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], ]: - raise NotImplementedError + raise NotImplementedError() @property def create_batch_prediction_job( self, ) -> typing.Callable[ [job_service.CreateBatchPredictionJobRequest], - gca_batch_prediction_job.BatchPredictionJob, + typing.Union[ + gca_batch_prediction_job.BatchPredictionJob, + typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def get_batch_prediction_job( self, ) -> typing.Callable[ [job_service.GetBatchPredictionJobRequest], - 
batch_prediction_job.BatchPredictionJob, + typing.Union[ + batch_prediction_job.BatchPredictionJob, + typing.Awaitable[batch_prediction_job.BatchPredictionJob], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def list_batch_prediction_jobs( self, ) -> typing.Callable[ [job_service.ListBatchPredictionJobsRequest], - job_service.ListBatchPredictionJobsResponse, + typing.Union[ + job_service.ListBatchPredictionJobsResponse, + typing.Awaitable[job_service.ListBatchPredictionJobsResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def delete_batch_prediction_job( self, ) -> typing.Callable[ - [job_service.DeleteBatchPredictionJobRequest], operations.Operation + [job_service.DeleteBatchPredictionJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], ]: - raise NotImplementedError + raise NotImplementedError() @property def cancel_batch_prediction_job( self, - ) -> typing.Callable[[job_service.CancelBatchPredictionJobRequest], empty.Empty]: - raise NotImplementedError + ) -> typing.Callable[ + [job_service.CancelBatchPredictionJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() __all__ = ("JobServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index a598c180cf..f4b610bd53 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -15,11 +15,15 @@ # limitations under the License. # -from typing import Callable, Dict +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -41,7 +45,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import JobServiceTransport +from .base import JobServiceTransport, DEFAULT_CLIENT_INFO class JobServiceGrpcTransport(JobServiceTransport): @@ -57,12 +61,21 @@ class JobServiceGrpcTransport(JobServiceTransport): top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] + def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - channel: grpc.Channel = None + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -74,28 +87,123 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. 
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
             channel (Optional[grpc.Channel]): A ``Channel`` instance through
                 which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
         """
-        # Sanity check: Ensure that channel and credentials are not both
-        # provided.
+        self._ssl_channel_credentials = ssl_channel_credentials
+
         if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
             credentials = False

-        # Run the base constructor.
-        super().__init__(host=host, credentials=credentials)
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+        elif api_mtls_endpoint:
+            warnings.warn(
+                "api_mtls_endpoint and client_cert_source are deprecated",
+                DeprecationWarning,
+            )
+
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+            self._ssl_channel_credentials = ssl_credentials
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+
         self._stubs = {}  # type: Dict[str, Callable]

-        # If a channel was explicitly provided, set it.
-        if channel:
-            self._grpc_channel = channel
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )

     @classmethod
     def create_channel(
         cls,
         host: str = "aiplatform.googleapis.com",
         credentials: credentials.Credentials = None,
-        **kwargs
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
     ) -> grpc.Channel:
         """Create and return a gRPC channel object.
         Args:
@@ -105,30 +213,37 @@ def create_channel(
             credentials identify this application to the service. If
             none are specified, the client will attempt to ascertain
             the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
             kwargs (Optional[dict]): Keyword arguments, which are passed to the
                 channel creation.
         Returns:
             grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
         """
+        scopes = scopes or cls.AUTH_SCOPES
         return grpc_helpers.create_channel(
-            host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
         )

     @property
     def grpc_channel(self) -> grpc.Channel:
-        """Create the channel designed to connect to this service.
-
-        This property caches on the instance; repeated calls return
-        the same channel.
+        """Return the channel designed to connect to this service.
         """
-        # Sanity check: Only create a new channel if we do not already
-        # have one.
-        if not hasattr(self, "_grpc_channel"):
-            self._grpc_channel = self.create_channel(
-                self._host, credentials=self._credentials,
-            )
-
-        # Return the channel from cache.
         return self._grpc_channel

     @property
diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py
new file mode 100644
index 0000000000..83cc826484
--- /dev/null
+++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py
@@ -0,0 +1,889 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1beta1.types import batch_prediction_job
+from google.cloud.aiplatform_v1beta1.types import (
+    batch_prediction_job as gca_batch_prediction_job,
+)
+from google.cloud.aiplatform_v1beta1.types import custom_job
+from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job
+from google.cloud.aiplatform_v1beta1.types import data_labeling_job
+from google.cloud.aiplatform_v1beta1.types import (
+    data_labeling_job as gca_data_labeling_job,
+)
+from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job
+from google.cloud.aiplatform_v1beta1.types import (
+    hyperparameter_tuning_job as gca_hyperparameter_tuning_job,
+)
+from google.cloud.aiplatform_v1beta1.types import job_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+from google.protobuf import empty_pb2 as empty  # type: ignore
+
+from .base import JobServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import JobServiceGrpcTransport
+
+
+class JobServiceGrpcAsyncIOTransport(JobServiceTransport):
+    """gRPC AsyncIO backend transport for JobService.
+
+    A service for creating and managing AI Platform's jobs.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        channel: aio.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        quota_project_id=None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        self._ssl_channel_credentials = ssl_channel_credentials
+
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+ self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__["operations_client"] + + @property + def create_custom_job( + self, + ) -> Callable[ + [job_service.CreateCustomJobRequest], Awaitable[gca_custom_job.CustomJob] + ]: + r"""Return a callable for the create custom job method over gRPC. + + Creates a CustomJob. A created CustomJob right away + will be attempted to be run. + + Returns: + Callable[[~.CreateCustomJobRequest], + Awaitable[~.CustomJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "create_custom_job" not in self._stubs: + self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateCustomJob", + request_serializer=job_service.CreateCustomJobRequest.serialize, + response_deserializer=gca_custom_job.CustomJob.deserialize, + ) + return self._stubs["create_custom_job"] + + @property + def get_custom_job( + self, + ) -> Callable[[job_service.GetCustomJobRequest], Awaitable[custom_job.CustomJob]]: + r"""Return a callable for the get custom job method over gRPC. + + Gets a CustomJob. + + Returns: + Callable[[~.GetCustomJobRequest], + Awaitable[~.CustomJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_custom_job" not in self._stubs: + self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetCustomJob", + request_serializer=job_service.GetCustomJobRequest.serialize, + response_deserializer=custom_job.CustomJob.deserialize, + ) + return self._stubs["get_custom_job"] + + @property + def list_custom_jobs( + self, + ) -> Callable[ + [job_service.ListCustomJobsRequest], + Awaitable[job_service.ListCustomJobsResponse], + ]: + r"""Return a callable for the list custom jobs method over gRPC. + + Lists CustomJobs in a Location. + + Returns: + Callable[[~.ListCustomJobsRequest], + Awaitable[~.ListCustomJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_custom_jobs" not in self._stubs: + self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListCustomJobs", + request_serializer=job_service.ListCustomJobsRequest.serialize, + response_deserializer=job_service.ListCustomJobsResponse.deserialize, + ) + return self._stubs["list_custom_jobs"] + + @property + def delete_custom_job( + self, + ) -> Callable[ + [job_service.DeleteCustomJobRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the delete custom job method over gRPC. + + Deletes a CustomJob. + + Returns: + Callable[[~.DeleteCustomJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_custom_job" not in self._stubs: + self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteCustomJob", + request_serializer=job_service.DeleteCustomJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_custom_job"] + + @property + def cancel_custom_job( + self, + ) -> Callable[[job_service.CancelCustomJobRequest], Awaitable[empty.Empty]]: + r"""Return a callable for the cancel custom job method over gRPC. + + Cancels a CustomJob. Starts asynchronous cancellation on the + CustomJob. The server makes a best effort to cancel the job, but + success is not guaranteed. 
Clients can use + ``JobService.GetCustomJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the CustomJob is not deleted; instead it becomes a + job with a + ``CustomJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``CustomJob.state`` + is set to ``CANCELLED``. + + Returns: + Callable[[~.CancelCustomJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "cancel_custom_job" not in self._stubs: + self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelCustomJob", + request_serializer=job_service.CancelCustomJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_custom_job"] + + @property + def create_data_labeling_job( + self, + ) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + Awaitable[gca_data_labeling_job.DataLabelingJob], + ]: + r"""Return a callable for the create data labeling job method over gRPC. + + Creates a DataLabelingJob. + + Returns: + Callable[[~.CreateDataLabelingJobRequest], + Awaitable[~.DataLabelingJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_data_labeling_job" not in self._stubs: + self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateDataLabelingJob", + request_serializer=job_service.CreateDataLabelingJobRequest.serialize, + response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, + ) + return self._stubs["create_data_labeling_job"] + + @property + def get_data_labeling_job( + self, + ) -> Callable[ + [job_service.GetDataLabelingJobRequest], + Awaitable[data_labeling_job.DataLabelingJob], + ]: + r"""Return a callable for the get data labeling job method over gRPC. + + Gets a DataLabelingJob. + + Returns: + Callable[[~.GetDataLabelingJobRequest], + Awaitable[~.DataLabelingJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_data_labeling_job" not in self._stubs: + self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetDataLabelingJob", + request_serializer=job_service.GetDataLabelingJobRequest.serialize, + response_deserializer=data_labeling_job.DataLabelingJob.deserialize, + ) + return self._stubs["get_data_labeling_job"] + + @property + def list_data_labeling_jobs( + self, + ) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + Awaitable[job_service.ListDataLabelingJobsResponse], + ]: + r"""Return a callable for the list data labeling jobs method over gRPC. + + Lists DataLabelingJobs in a Location. 
+ + Returns: + Callable[[~.ListDataLabelingJobsRequest], + Awaitable[~.ListDataLabelingJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_data_labeling_jobs" not in self._stubs: + self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListDataLabelingJobs", + request_serializer=job_service.ListDataLabelingJobsRequest.serialize, + response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, + ) + return self._stubs["list_data_labeling_jobs"] + + @property + def delete_data_labeling_job( + self, + ) -> Callable[ + [job_service.DeleteDataLabelingJobRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the delete data labeling job method over gRPC. + + Deletes a DataLabelingJob. + + Returns: + Callable[[~.DeleteDataLabelingJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_data_labeling_job" not in self._stubs: + self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteDataLabelingJob", + request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_data_labeling_job"] + + @property + def cancel_data_labeling_job( + self, + ) -> Callable[[job_service.CancelDataLabelingJobRequest], Awaitable[empty.Empty]]: + r"""Return a callable for the cancel data labeling job method over gRPC. + + Cancels a DataLabelingJob. Success of cancellation is + not guaranteed. + + Returns: + Callable[[~.CancelDataLabelingJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "cancel_data_labeling_job" not in self._stubs: + self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelDataLabelingJob", + request_serializer=job_service.CancelDataLabelingJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_data_labeling_job"] + + @property + def create_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], + ]: + r"""Return a callable for the create hyperparameter tuning + job method over gRPC. + + Creates a HyperparameterTuningJob + + Returns: + Callable[[~.CreateHyperparameterTuningJobRequest], + Awaitable[~.HyperparameterTuningJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "create_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "create_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateHyperparameterTuningJob", + request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, + response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, + ) + return self._stubs["create_hyperparameter_tuning_job"] + + @property + def get_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], + ]: + r"""Return a callable for the get hyperparameter tuning job method over gRPC. + + Gets a HyperparameterTuningJob + + Returns: + Callable[[~.GetHyperparameterTuningJobRequest], + Awaitable[~.HyperparameterTuningJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "get_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetHyperparameterTuningJob", + request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, + response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, + ) + return self._stubs["get_hyperparameter_tuning_job"] + + @property + def list_hyperparameter_tuning_jobs( + self, + ) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + Awaitable[job_service.ListHyperparameterTuningJobsResponse], + ]: + r"""Return a callable for the list hyperparameter tuning + jobs method over gRPC. + + Lists HyperparameterTuningJobs in a Location. + + Returns: + Callable[[~.ListHyperparameterTuningJobsRequest], + Awaitable[~.ListHyperparameterTuningJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_hyperparameter_tuning_jobs" not in self._stubs: + self._stubs[ + "list_hyperparameter_tuning_jobs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListHyperparameterTuningJobs", + request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, + response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, + ) + return self._stubs["list_hyperparameter_tuning_jobs"] + + @property + def delete_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the delete hyperparameter tuning + job method over gRPC. + + Deletes a HyperparameterTuningJob. + + Returns: + Callable[[~.DeleteHyperparameterTuningJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "delete_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "delete_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteHyperparameterTuningJob", + request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_hyperparameter_tuning_job"] + + @property + def cancel_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CancelHyperparameterTuningJobRequest], Awaitable[empty.Empty] + ]: + r"""Return a callable for the cancel hyperparameter tuning + job method over gRPC. + + Cancels a HyperparameterTuningJob. Starts asynchronous + cancellation on the HyperparameterTuningJob. The server makes a + best effort to cancel the job, but success is not guaranteed. + Clients can use + ``JobService.GetHyperparameterTuningJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the HyperparameterTuningJob is not deleted; + instead it becomes a job with a + ``HyperparameterTuningJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``HyperparameterTuningJob.state`` + is set to ``CANCELLED``. + + Returns: + Callable[[~.CancelHyperparameterTuningJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "cancel_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "cancel_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CancelHyperparameterTuningJob", + request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_hyperparameter_tuning_job"] + + @property + def create_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + Awaitable[gca_batch_prediction_job.BatchPredictionJob], + ]: + r"""Return a callable for the create batch prediction job method over gRPC. + + Creates a BatchPredictionJob. A BatchPredictionJob + once created will right away be attempted to start. + + Returns: + Callable[[~.CreateBatchPredictionJobRequest], + Awaitable[~.BatchPredictionJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_batch_prediction_job" not in self._stubs: + self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/CreateBatchPredictionJob", + request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, + response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, + ) + return self._stubs["create_batch_prediction_job"] + + @property + def get_batch_prediction_job( + self, + ) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + Awaitable[batch_prediction_job.BatchPredictionJob], + ]: + r"""Return a callable for the get batch prediction job method over gRPC. 
+ + Gets a BatchPredictionJob + + Returns: + Callable[[~.GetBatchPredictionJobRequest], + Awaitable[~.BatchPredictionJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_batch_prediction_job" not in self._stubs: + self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/GetBatchPredictionJob", + request_serializer=job_service.GetBatchPredictionJobRequest.serialize, + response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, + ) + return self._stubs["get_batch_prediction_job"] + + @property + def list_batch_prediction_jobs( + self, + ) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + Awaitable[job_service.ListBatchPredictionJobsResponse], + ]: + r"""Return a callable for the list batch prediction jobs method over gRPC. + + Lists BatchPredictionJobs in a Location. + + Returns: + Callable[[~.ListBatchPredictionJobsRequest], + Awaitable[~.ListBatchPredictionJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_batch_prediction_jobs" not in self._stubs: + self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/ListBatchPredictionJobs", + request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, + response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, + ) + return self._stubs["list_batch_prediction_jobs"] + + @property + def delete_batch_prediction_job( + self, + ) -> Callable[ + [job_service.DeleteBatchPredictionJobRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the delete batch prediction job method over gRPC. + + Deletes a BatchPredictionJob. Can only be called on + jobs that already finished. + + Returns: + Callable[[~.DeleteBatchPredictionJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_batch_prediction_job" not in self._stubs: + self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.JobService/DeleteBatchPredictionJob", + request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_batch_prediction_job"] + + @property + def cancel_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CancelBatchPredictionJobRequest], Awaitable[empty.Empty] + ]: + r"""Return a callable for the cancel batch prediction job method over gRPC. + + Cancels a BatchPredictionJob. + + Starts asynchronous cancellation on the BatchPredictionJob. The + server makes the best effort to cancel the job, but success is + not guaranteed. 
Clients can use
+        ``JobService.GetBatchPredictionJob``
+        or other methods to check whether the cancellation succeeded or
+        whether the job completed despite cancellation. On a successful
+        cancellation, the BatchPredictionJob is not deleted; instead its
+        ``BatchPredictionJob.state``
+        is set to ``CANCELLED``. Any files already outputted by the job
+        are not deleted.
+
+        Returns:
+            Callable[[~.CancelBatchPredictionJobRequest],
+                    Awaitable[~.Empty]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "cancel_batch_prediction_job" not in self._stubs:
+            self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1beta1.JobService/CancelBatchPredictionJob",
+                request_serializer=job_service.CancelBatchPredictionJobRequest.serialize,
+                response_deserializer=empty.Empty.FromString,
+            )
+        return self._stubs["cancel_batch_prediction_job"]
+
+
+__all__ = ("JobServiceGrpcAsyncIOTransport",)
diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py
new file mode 100644
index 0000000000..1d6216d1f7
--- /dev/null
+++ b/google/cloud/aiplatform_v1beta1/services/migration_service/__init__.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .client import MigrationServiceClient
+from .async_client import MigrationServiceAsyncClient
+
+__all__ = (
+    "MigrationServiceClient",
+    "MigrationServiceAsyncClient",
+)
diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py
new file mode 100644
index 0000000000..af13c4d4fb
--- /dev/null
+++ b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py
@@ -0,0 +1,363 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.migration_service import pagers +from google.cloud.aiplatform_v1beta1.types import migratable_resource +from google.cloud.aiplatform_v1beta1.types import migration_service + +from .transports.base import MigrationServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import MigrationServiceGrpcAsyncIOTransport +from .client import MigrationServiceClient + + +class MigrationServiceAsyncClient: + """A service that migrates resources from automl.googleapis.com, + datalabeling.googleapis.com and ml.googleapis.com to AI + Platform. + """ + + _client: MigrationServiceClient + + DEFAULT_ENDPOINT = MigrationServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = MigrationServiceClient.DEFAULT_MTLS_ENDPOINT + + annotated_dataset_path = staticmethod(MigrationServiceClient.annotated_dataset_path) + parse_annotated_dataset_path = staticmethod( + MigrationServiceClient.parse_annotated_dataset_path + ) + dataset_path = staticmethod(MigrationServiceClient.dataset_path) + parse_dataset_path = staticmethod(MigrationServiceClient.parse_dataset_path) + dataset_path = staticmethod(MigrationServiceClient.dataset_path) + parse_dataset_path = staticmethod(MigrationServiceClient.parse_dataset_path) + dataset_path = staticmethod(MigrationServiceClient.dataset_path) + parse_dataset_path = staticmethod(MigrationServiceClient.parse_dataset_path) + model_path = staticmethod(MigrationServiceClient.model_path) + parse_model_path = staticmethod(MigrationServiceClient.parse_model_path) + model_path = staticmethod(MigrationServiceClient.model_path) + parse_model_path = staticmethod(MigrationServiceClient.parse_model_path) + version_path = staticmethod(MigrationServiceClient.version_path) + parse_version_path = staticmethod(MigrationServiceClient.parse_version_path) + + common_billing_account_path = staticmethod( + MigrationServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + MigrationServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(MigrationServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + MigrationServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + MigrationServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + MigrationServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(MigrationServiceClient.common_project_path) + parse_common_project_path = staticmethod( + MigrationServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(MigrationServiceClient.common_location_path) + parse_common_location_path = staticmethod( + MigrationServiceClient.parse_common_location_path + ) + + from_service_account_file = MigrationServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + 
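+    # NOTE: a minimal usage sketch, for orientation only. The project and
+    # location below are placeholders, and application-default credentials
+    # are assumed; this is not part of the generated surface:
+    #
+    #     client = MigrationServiceAsyncClient()
+    #     pager = await client.search_migratable_resources(
+    #         parent="projects/my-project/locations/us-central1"
+    #     )
+    #     async for resource in pager:
+    #         print(resource)
+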
@property + def transport(self) -> MigrationServiceTransport: + """Return the transport used by the client instance. + + Returns: + MigrationServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, MigrationServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the migration service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.MigrationServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = MigrationServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesAsyncPager: + r"""Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Args: + request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + The request object. Request message for + ``MigrationService.SearchMigratableResources``. + parent (:class:`str`): + Required. The location that the migratable resources + should be searched from. It's the AI Platform location + that the resources can be migrated to, not the + resources' original location. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. 
+ timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.SearchMigratableResourcesAsyncPager: + Response message for + ``MigrationService.SearchMigratableResources``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = migration_service.SearchMigratableResourcesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.search_migratable_resources, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.SearchMigratableResourcesAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def batch_migrate_resources( + self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[ + migration_service.MigrateResourceRequest + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Batch migrates resources from ml.googleapis.com, + automl.googleapis.com, and datalabeling.googleapis.com + to AI Platform (Unified). + + Args: + request (:class:`~.migration_service.BatchMigrateResourcesRequest`): + The request object. Request message for + ``MigrationService.BatchMigrateResources``. + parent (:class:`str`): + Required. The location of the migrated resource will + live in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + migrate_resource_requests (:class:`Sequence[~.migration_service.MigrateResourceRequest]`): + Required. The request messages + specifying the resources to migrate. + They must be in the same location as the + destination. Up to 50 resources can be + migrated in one batch. + This corresponds to the ``migrate_resource_requests`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.migration_service.BatchMigrateResourcesResponse`: + Response message for + ``MigrationService.BatchMigrateResources``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, migrate_resource_requests]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = migration_service.BatchMigrateResourcesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + if migrate_resource_requests: + request.migrate_resource_requests.extend(migrate_resource_requests) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.batch_migrate_resources, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + migration_service.BatchMigrateResourcesResponse, + metadata_type=migration_service.BatchMigrateResourcesOperationMetadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("MigrationServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py new file mode 100644 index 0000000000..bf1f8e5c6b --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -0,0 +1,639 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from collections import OrderedDict
+from distutils import util
+import os
+import re
+from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union
+import pkg_resources
+
+from google.api_core import client_options as client_options_lib  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport import mtls  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+from google.auth.exceptions import MutualTLSChannelError  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+from google.api_core import operation  # type: ignore
+from google.api_core import operation_async  # type: ignore
+from google.cloud.aiplatform_v1beta1.services.migration_service import pagers
+from google.cloud.aiplatform_v1beta1.types import migratable_resource
+from google.cloud.aiplatform_v1beta1.types import migration_service
+
+from .transports.base import MigrationServiceTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc import MigrationServiceGrpcTransport
+from .transports.grpc_asyncio import MigrationServiceGrpcAsyncIOTransport
+
+
+class MigrationServiceClientMeta(type):
+    """Metaclass for the MigrationService client.
+
+    This provides class-level methods for building and retrieving
+    support objects (e.g. transport) without polluting the client instance
+    objects.
+    """
+
+    _transport_registry = (
+        OrderedDict()
+    )  # type: Dict[str, Type[MigrationServiceTransport]]
+    _transport_registry["grpc"] = MigrationServiceGrpcTransport
+    _transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport
+
+    def get_transport_class(cls, label: str = None,) -> Type[MigrationServiceTransport]:
+        """Return an appropriate transport class.
+
+        Args:
+            label: The name of the desired transport. If none is
+                provided, then the first transport in the registry is used.
+
+        Returns:
+            The transport class to use.
+        """
+        # If a specific transport is requested, return that one.
+        if label:
+            return cls._transport_registry[label]
+
+        # No transport is requested; return the default (that is, the first one
+        # in the dictionary).
+        return next(iter(cls._transport_registry.values()))
+
+
+class MigrationServiceClient(metaclass=MigrationServiceClientMeta):
+    """A service that migrates resources from automl.googleapis.com,
+    datalabeling.googleapis.com and ml.googleapis.com to AI
+    Platform.
+    """
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
+    )
+
+    @classmethod
+    def from_service_account_file(cls, filename: str, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+        file.
+
+        Args:
+            filename (str): The path to the service account private key json
+                file.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            MigrationServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_file(filename)
+        kwargs["credentials"] = credentials
+        return cls(*args, **kwargs)
+
+    from_service_account_json = from_service_account_file
+
+    @property
+    def transport(self) -> MigrationServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            MigrationServiceTransport: The transport used by the client instance.
+        """
+        return self._transport
+
+    @staticmethod
+    def annotated_dataset_path(
+        project: str, dataset: str, annotated_dataset: str,
+    ) -> str:
+        """Return a fully-qualified annotated_dataset string."""
+        return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(
+            project=project, dataset=dataset, annotated_dataset=annotated_dataset,
+        )
+
+    @staticmethod
+    def parse_annotated_dataset_path(path: str) -> Dict[str, str]:
+        """Parse an annotated_dataset path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)/annotatedDatasets/(?P<annotated_dataset>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def dataset_path(project: str, location: str, dataset: str,) -> str:
+        """Return a fully-qualified dataset string."""
+        return "projects/{project}/locations/{location}/datasets/{dataset}".format(
+            project=project, location=location, dataset=dataset,
+        )
+
+    @staticmethod
+    def parse_dataset_path(path: str) -> Dict[str, str]:
+        """Parse a dataset path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def dataset_path(project: str, dataset: str,) -> str:
+        """Return a fully-qualified dataset string."""
+        return "projects/{project}/datasets/{dataset}".format(
+            project=project, dataset=dataset,
+        )
+
+    @staticmethod
+    def parse_dataset_path(path: str) -> Dict[str, str]:
+        """Parse a dataset path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def dataset_path(project: str, location: str, dataset: str,) -> str:
+        """Return a fully-qualified dataset string."""
+        return "projects/{project}/locations/{location}/datasets/{dataset}".format(
+            project=project, location=location, dataset=dataset,
+        )
+
+    @staticmethod
+    def parse_dataset_path(path: str) -> Dict[str, str]:
+        """Parse a dataset path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def model_path(project: str, location: str, model: str,) -> str:
+        """Return a fully-qualified model string."""
+        return "projects/{project}/locations/{location}/models/{model}".format(
+            project=project, location=location, model=model,
+        )
+
+    @staticmethod
+    def parse_model_path(path: str) -> Dict[str, str]:
+        """Parse a model path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def model_path(project: str, location: str, model: str,) -> str:
+        """Return a fully-qualified model string."""
+        return "projects/{project}/locations/{location}/models/{model}".format(
+            project=project, location=location, model=model,
+        )
+
+    @staticmethod
+    def parse_model_path(path: str) -> Dict[str, str]:
+        """Parse a model path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def version_path(project: str, model: str, version: str,) -> str:
+        """Return a fully-qualified version string."""
+        return "projects/{project}/models/{model}/versions/{version}".format(
+            project=project, model=model, version=version,
+        )
+
+    @staticmethod
+    def parse_version_path(path: str) -> Dict[str, str]:
+        """Parse a version path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/models/(?P<model>.+?)/versions/(?P<version>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_billing_account_path(billing_account: str,) -> str:
+        """Return a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(
+            billing_account=billing_account,
+        )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str, str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str,) -> str:
+        """Return a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder,)
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str, str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(organization: str,) -> str:
+        """Return a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization,)
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str, str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str,) -> str:
+        """Return a fully-qualified project string."""
+        return "projects/{project}".format(project=project,)
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str, str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str,) -> str:
+        """Return a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(
+            project=project, location=location,
+        )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str, str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    def __init__(
+        self,
+        *,
+        credentials: Optional[credentials.Credentials] = None,
+        transport: Union[str, MigrationServiceTransport, None] = None,
+        client_options: Optional[client_options_lib.ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the migration service client.
+
+        Args:
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, ~.MigrationServiceTransport]): The
+                transport to use. If set to None, a transport is chosen
+                automatically.
+            client_options (client_options_lib.ClientOptions): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+        """
+        if isinstance(client_options, dict):
+            client_options = client_options_lib.from_dict(client_options)
+        if client_options is None:
+            client_options = client_options_lib.ClientOptions()
+
+        # Create SSL credentials for mutual TLS if needed.
+        use_client_cert = bool(
+            util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
+        )
+
+        ssl_credentials = None
+        is_mtls = False
+        if use_client_cert:
+            if client_options.client_cert_source:
+                import grpc  # type: ignore
+
+                cert, key = client_options.client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+                is_mtls = True
+            else:
+                creds = SslCredentials()
+                is_mtls = creds.is_mtls
+                ssl_credentials = creds.ssl_credentials if is_mtls else None
+
+        # Figure out which api endpoint to use.
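+        # In brief (mirroring the docstring above): an explicit
+        # client_options.api_endpoint always wins; otherwise the
+        # GOOGLE_API_USE_MTLS_ENDPOINT variable picks the endpoint, with
+        # "never" -> DEFAULT_ENDPOINT, "always" -> DEFAULT_MTLS_ENDPOINT,
+        # and "auto" (the default) -> the mTLS endpoint only when a client
+        # certificate is available.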
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, MigrationServiceTransport): + # transport is a MigrationServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesPager: + r"""Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Args: + request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + The request object. Request message for + ``MigrationService.SearchMigratableResources``. + parent (:class:`str`): + Required. The location that the migratable resources + should be searched from. It's the AI Platform location + that the resources can be migrated to, not the + resources' original location. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.SearchMigratableResourcesPager: + Response message for + ``MigrationService.SearchMigratableResources``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) + + # Minor optimization to avoid making a copy if the user passes + # in a migration_service.SearchMigratableResourcesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, migration_service.SearchMigratableResourcesRequest): + request = migration_service.SearchMigratableResourcesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.search_migratable_resources + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.SearchMigratableResourcesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def batch_migrate_resources( + self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[ + migration_service.MigrateResourceRequest + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: + r"""Batch migrates resources from ml.googleapis.com, + automl.googleapis.com, and datalabeling.googleapis.com + to AI Platform (Unified). + + Args: + request (:class:`~.migration_service.BatchMigrateResourcesRequest`): + The request object. Request message for + ``MigrationService.BatchMigrateResources``. + parent (:class:`str`): + Required. The location of the migrated resource will + live in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + migrate_resource_requests (:class:`Sequence[~.migration_service.MigrateResourceRequest]`): + Required. The request messages + specifying the resources to migrate. + They must be in the same location as the + destination. Up to 50 resources can be + migrated in one batch. + This corresponds to the ``migrate_resource_requests`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.migration_service.BatchMigrateResourcesResponse`: + Response message for + ``MigrationService.BatchMigrateResources``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
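+        # For example, passing request=BatchMigrateResourcesRequest(...)
+        # together with parent=... is rejected below; callers supply either
+        # the fully-formed request object or the flattened fields, never both.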
+ has_flattened_params = any([parent, migrate_resource_requests]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a migration_service.BatchMigrateResourcesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, migration_service.BatchMigrateResourcesRequest): + request = migration_service.BatchMigrateResourcesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + if migrate_resource_requests: + request.migrate_resource_requests.extend(migrate_resource_requests) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.batch_migrate_resources] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + migration_service.BatchMigrateResourcesResponse, + metadata_type=migration_service.BatchMigrateResourcesOperationMetadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("MigrationServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py new file mode 100644 index 0000000000..826325f2e4 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple + +from google.cloud.aiplatform_v1beta1.types import migratable_resource +from google.cloud.aiplatform_v1beta1.types import migration_service + + +class SearchMigratableResourcesPager: + """A pager for iterating through ``search_migratable_resources`` requests. + + This class thinly wraps an initial + :class:`~.migration_service.SearchMigratableResourcesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``migratable_resources`` field. 
+ + If there are more pages, the ``__iter__`` method will make additional + ``SearchMigratableResources`` requests and continue to iterate + through the ``migratable_resources`` field on the + corresponding responses. + + All the usual :class:`~.migration_service.SearchMigratableResourcesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., migration_service.SearchMigratableResourcesResponse], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + The initial request object. + response (:class:`~.migration_service.SearchMigratableResourcesResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = migration_service.SearchMigratableResourcesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[migration_service.SearchMigratableResourcesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[migratable_resource.MigratableResource]: + for page in self.pages: + yield from page.migratable_resources + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class SearchMigratableResourcesAsyncPager: + """A pager for iterating through ``search_migratable_resources`` requests. + + This class thinly wraps an initial + :class:`~.migration_service.SearchMigratableResourcesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``migratable_resources`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``SearchMigratableResources`` requests and continue to iterate + through the ``migratable_resources`` field on the + corresponding responses. + + All the usual :class:`~.migration_service.SearchMigratableResourcesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[migration_service.SearchMigratableResourcesResponse] + ], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + The initial request object. + response (:class:`~.migration_service.SearchMigratableResourcesResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = migration_service.SearchMigratableResourcesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[migratable_resource.MigratableResource]: + async def async_generator(): + async for page in self.pages: + for response in page.migratable_resources: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py new file mode 100644 index 0000000000..af727857e7 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import MigrationServiceTransport +from .grpc import MigrationServiceGrpcTransport +from .grpc_asyncio import MigrationServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] +_transport_registry["grpc"] = MigrationServiceGrpcTransport +_transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport + + +__all__ = ( + "MigrationServiceTransport", + "MigrationServiceGrpcTransport", + "MigrationServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py new file mode 100644 index 0000000000..cbcb288489 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import abc
+import typing
+import pkg_resources
+
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google.auth import credentials  # type: ignore
+
+from google.cloud.aiplatform_v1beta1.types import migration_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            "google-cloud-aiplatform",
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+class MigrationServiceTransport(abc.ABC):
+    """Abstract transport class for MigrationService."""
+
+    AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: typing.Optional[str] = None,
+        scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+        quota_project_id: typing.Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        **kwargs,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+        """
+        # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+        if ":" not in host:
+            host += ":443"
+        self._host = host
+
+        # If no credentials are provided, then determine the appropriate
+        # defaults.
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs(
+                "'credentials_file' and 'credentials' are mutually exclusive"
+            )
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file, scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(
+                scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        # Save the credentials.
+        self._credentials = credentials
+
+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
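+        # Each wrapped callable carries the method's default retry/timeout
+        # settings plus the user-agent metadata from client_info; the client
+        # later dispatches through this mapping rather than the raw stubs.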
+ self._wrapped_methods = { + self.search_migratable_resources: gapic_v1.method.wrap_method( + self.search_migratable_resources, + default_timeout=None, + client_info=client_info, + ), + self.batch_migrate_resources: gapic_v1.method.wrap_method( + self.batch_migrate_resources, + default_timeout=None, + client_info=client_info, + ), + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def search_migratable_resources( + self, + ) -> typing.Callable[ + [migration_service.SearchMigratableResourcesRequest], + typing.Union[ + migration_service.SearchMigratableResourcesResponse, + typing.Awaitable[migration_service.SearchMigratableResourcesResponse], + ], + ]: + raise NotImplementedError() + + @property + def batch_migrate_resources( + self, + ) -> typing.Callable[ + [migration_service.BatchMigrateResourcesRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + +__all__ = ("MigrationServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py new file mode 100644 index 0000000000..efd4c4b6a4 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1beta1.types import migration_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import MigrationServiceTransport, DEFAULT_CLIENT_INFO + + +class MigrationServiceGrpcTransport(MigrationServiceTransport): + """gRPC backend transport for MigrationService. + + A service that migrates resources from automl.googleapis.com, + datalabeling.googleapis.com and ml.googleapis.com to AI + Platform. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. 
+    """
+
+    _stubs: Dict[str, Callable]
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Sequence[str] = None,
+        channel: grpc.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            channel (Optional[grpc.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        self._ssl_channel_credentials = ssl_channel_credentials
+
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+        elif api_mtls_endpoint:
+            warnings.warn(
+                "api_mtls_endpoint and client_cert_source are deprecated",
+                DeprecationWarning,
+            )
+
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
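+            # Note: client_cert_source, when given, is a zero-argument callable
+            # returning (certificate_bytes, private_key_bytes) in PEM form --
+            # for example, the callback produced by
+            # google.auth.transport.mtls.default_client_cert_source().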
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+            self._ssl_channel_credentials = ssl_credentials
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+
+        self._stubs = {}  # type: Dict[str, Callable]
+
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> grpc.Channel:
+        """Create and return a gRPC channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with ``credentials``.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    @property
+    def grpc_channel(self) -> grpc.Channel:
+        """Return the channel designed to connect to this service."""
+        return self._grpc_channel
+
+    @property
+    def operations_client(self) -> operations_v1.OperationsClient:
+        """Create the client designed to process long-running operations.
+
+        This property caches on the instance; repeated calls return the same
+        client.
+        """
+        # Sanity check: Only create a new client if we do not already have one.
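+        # The cached client lives on the instance's ``__dict__``, so repeated
+        # accesses reuse a single client per transport, e.g. (illustrative
+        # only, assuming an already-constructed transport):
+        #
+        #     assert transport.operations_client is transport.operations_client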
+ if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__["operations_client"] + + @property + def search_migratable_resources( + self, + ) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + migration_service.SearchMigratableResourcesResponse, + ]: + r"""Return a callable for the search migratable resources method over gRPC. + + Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Returns: + Callable[[~.SearchMigratableResourcesRequest], + ~.SearchMigratableResourcesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "search_migratable_resources" not in self._stubs: + self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources", + request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, + response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, + ) + return self._stubs["search_migratable_resources"] + + @property + def batch_migrate_resources( + self, + ) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], operations.Operation + ]: + r"""Return a callable for the batch migrate resources method over gRPC. + + Batch migrates resources from ml.googleapis.com, + automl.googleapis.com, and datalabeling.googleapis.com + to AI Platform (Unified). + + Returns: + Callable[[~.BatchMigrateResourcesRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_migrate_resources" not in self._stubs: + self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources", + request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["batch_migrate_resources"] + + +__all__ = ("MigrationServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..ba038f57c5 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1beta1.types import migration_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+from .base import MigrationServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import MigrationServiceGrpcTransport
+
+
+class MigrationServiceGrpcAsyncIOTransport(MigrationServiceTransport):
+    """gRPC AsyncIO backend transport for MigrationService.
+
+    A service that migrates resources from automl.googleapis.com,
+    datalabeling.googleapis.com and ml.googleapis.com to AI
+    Platform.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with ``credentials``.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
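+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.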
+ """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. 
+ self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__["operations_client"] + + @property + def search_migratable_resources( + self, + ) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + Awaitable[migration_service.SearchMigratableResourcesResponse], + ]: + r"""Return a callable for the search migratable resources method over gRPC. + + Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Returns: + Callable[[~.SearchMigratableResourcesRequest], + Awaitable[~.SearchMigratableResourcesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
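+        # Illustrative usage of the returned callable (not part of the
+        # generated code), assuming an already-constructed async transport
+        # and the hypothetical project ``my-project``:
+        #
+        #     request = migration_service.SearchMigratableResourcesRequest(
+        #         parent="projects/my-project/locations/us-central1",
+        #     )
+        #     response = await transport.search_migratable_resources(request)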
+ if "search_migratable_resources" not in self._stubs: + self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/SearchMigratableResources", + request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, + response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, + ) + return self._stubs["search_migratable_resources"] + + @property + def batch_migrate_resources( + self, + ) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the batch migrate resources method over gRPC. + + Batch migrates resources from ml.googleapis.com, + automl.googleapis.com, and datalabeling.googleapis.com + to AI Platform (Unified). + + Returns: + Callable[[~.BatchMigrateResourcesRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_migrate_resources" not in self._stubs: + self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.MigrationService/BatchMigrateResources", + request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["batch_migrate_resources"] + + +__all__ = ("MigrationServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py index b0d80fcc98..b39295ebfe 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/__init__.py @@ -16,5 +16,9 @@ # from .client import ModelServiceClient +from .async_client import ModelServiceAsyncClient -__all__ = ("ModelServiceClient",) +__all__ = ( + "ModelServiceClient", + "ModelServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py new file mode 100644 index 0000000000..81c1f9cb51 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -0,0 +1,1021 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.model_service import pagers +from google.cloud.aiplatform_v1beta1.types import deployed_model_ref +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import model as gca_model +from google.cloud.aiplatform_v1beta1.types import model_evaluation +from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice +from google.cloud.aiplatform_v1beta1.types import model_service +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport +from .client import ModelServiceClient + + +class ModelServiceAsyncClient: + """A service for managing AI Platform's machine learning Models.""" + + _client: ModelServiceClient + + DEFAULT_ENDPOINT = ModelServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = ModelServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(ModelServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(ModelServiceClient.parse_endpoint_path) + model_path = staticmethod(ModelServiceClient.model_path) + parse_model_path = staticmethod(ModelServiceClient.parse_model_path) + model_evaluation_path = staticmethod(ModelServiceClient.model_evaluation_path) + parse_model_evaluation_path = staticmethod( + ModelServiceClient.parse_model_evaluation_path + ) + model_evaluation_slice_path = staticmethod( + ModelServiceClient.model_evaluation_slice_path + ) + parse_model_evaluation_slice_path = staticmethod( + ModelServiceClient.parse_model_evaluation_slice_path + ) + training_pipeline_path = staticmethod(ModelServiceClient.training_pipeline_path) + parse_training_pipeline_path = staticmethod( + ModelServiceClient.parse_training_pipeline_path + ) + + common_billing_account_path = staticmethod( + ModelServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + ModelServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(ModelServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(ModelServiceClient.common_organization_path) + parse_common_organization_path = staticmethod( + ModelServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(ModelServiceClient.common_project_path) + parse_common_project_path = staticmethod( + 
ModelServiceClient.parse_common_project_path
+    )
+
+    common_location_path = staticmethod(ModelServiceClient.common_location_path)
+    parse_common_location_path = staticmethod(
+        ModelServiceClient.parse_common_location_path
+    )
+
+    from_service_account_file = ModelServiceClient.from_service_account_file
+    from_service_account_json = from_service_account_file
+
+    @property
+    def transport(self) -> ModelServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            ModelServiceTransport: The transport used by the client instance.
+        """
+        return self._client.transport
+
+    get_transport_class = functools.partial(
+        type(ModelServiceClient).get_transport_class, type(ModelServiceClient)
+    )
+
+    def __init__(
+        self,
+        *,
+        credentials: credentials.Credentials = None,
+        transport: Union[str, ModelServiceTransport] = "grpc_asyncio",
+        client_options: ClientOptions = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the model service client.
+
+        Args:
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, ~.ModelServiceTransport]): The
+                transport to use. If set to None, a transport is chosen
+                automatically.
+            client_options (ClientOptions): Custom options for the client. It
+                won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+        """
+
+        self._client = ModelServiceClient(
+            credentials=credentials,
+            transport=transport,
+            client_options=client_options,
+            client_info=client_info,
+        )
+
+    async def upload_model(
+        self,
+        request: model_service.UploadModelRequest = None,
+        *,
+        parent: str = None,
+        model: gca_model.Model = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operation_async.AsyncOperation:
+        r"""Uploads a Model artifact into AI Platform.
+
+        Args:
+            request (:class:`~.model_service.UploadModelRequest`):
+                The request object. Request message for
+                ``ModelService.UploadModel``.
+            parent (:class:`str`):
+                Required. The resource name of the Location into which
+                to upload the Model. Format:
+                ``projects/{project}/locations/{location}``
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            model (:class:`~.gca_model.Model`):
+                Required.
The Model to create. + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.model_service.UploadModelResponse`: Response + message of + ``ModelService.UploadModel`` + operation. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, model]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.UploadModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if model is not None: + request.model = model + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.upload_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + model_service.UploadModelResponse, + metadata_type=model_service.UploadModelOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_model( + self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: + r"""Gets a Model. + + Args: + request (:class:`~.model_service.GetModelRequest`): + The request object. Request message for + ``ModelService.GetModel``. + name (:class:`str`): + Required. The name of the Model resource. Format: + ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.model.Model: + A trained machine learning Model. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) + + request = model_service.GetModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_models( + self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsAsyncPager: + r"""Lists Models in a Location. + + Args: + request (:class:`~.model_service.ListModelsRequest`): + The request object. Request message for + ``ModelService.ListModels``. + parent (:class:`str`): + Required. The resource name of the Location to list the + Models from. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListModelsAsyncPager: + Response message for + ``ModelService.ListModels`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.ListModelsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_models, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListModelsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. 
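+        # Illustrative (not generated): the pager is consumed with
+        # ``async for`` and fetches further pages transparently, e.g.:
+        #
+        #     pager = await client.list_models(parent=parent)
+        #     async for model in pager:
+        #         print(model.display_name)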
+ return response + + async def update_model( + self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: + r"""Updates a Model. + + Args: + request (:class:`~.model_service.UpdateModelRequest`): + The request object. Request message for + ``ModelService.UpdateModel``. + model (:class:`~.gca_model.Model`): + Required. The Model which replaces + the resource on the server. + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`~.field_mask.FieldMask`): + Required. The update mask applies to the resource. For + the ``FieldMask`` definition, see + + [FieldMask](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask). + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_model.Model: + A trained machine learning Model. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.UpdateModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if model is not None: + request.model = model + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("model.name", request.model.name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def delete_model( + self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a Model. + Note: Model can only be deleted if there are no + DeployedModels created from it. + + Args: + request (:class:`~.model_service.DeleteModelRequest`): + The request object. Request message for + ``ModelService.DeleteModel``. + name (:class:`str`): + Required. The name of the Model resource to be deleted. + Format: + ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            ~.operation_async.AsyncOperation:
+                An object representing a long-running operation.
+
+                The result type for the operation will be
+                :class:`~.empty.Empty`: A generic empty message that
+                you can re-use to avoid defining duplicated empty
+                messages in your APIs. A typical example is to use it as
+                the request or the response type of an API method. For
+                instance:
+
+                ::
+
+                    service Foo {
+                      rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty);
+                    }
+
+                The JSON representation for ``Empty`` is empty JSON
+                object ``{}``.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([name])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = model_service.DeleteModelRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if name is not None:
+            request.name = name
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.delete_model,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Wrap the response in an operation future.
+        response = operation_async.from_gapic(
+            response,
+            self._client._transport.operations_client,
+            empty.Empty,
+            metadata_type=gca_operation.DeleteOperationMetadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def export_model(
+        self,
+        request: model_service.ExportModelRequest = None,
+        *,
+        name: str = None,
+        output_config: model_service.ExportModelRequest.OutputConfig = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operation_async.AsyncOperation:
+        r"""Exports a trained, exportable Model to a location specified by
+        the user. A Model is considered to be exportable if it has at
+        least one [supported export
+        format][google.cloud.aiplatform.v1beta1.Model.supported_export_formats].
+
+        Args:
+            request (:class:`~.model_service.ExportModelRequest`):
+                The request object. Request message for
+                ``ModelService.ExportModel``.
+            name (:class:`str`):
+                Required. The resource name of the Model to export.
+                Format:
+                ``projects/{project}/locations/{location}/models/{model}``
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            output_config (:class:`~.model_service.ExportModelRequest.OutputConfig`):
+                Required. The desired output location
+                and configuration.
+                This corresponds to the ``output_config`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+ + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.model_service.ExportModelResponse`: Response + message of + ``ModelService.ExportModel`` + operation. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name, output_config]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.ExportModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + if output_config is not None: + request.output_config = output_config + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.export_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + model_service.ExportModelResponse, + metadata_type=model_service.ExportModelOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_model_evaluation( + self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: + r"""Gets a ModelEvaluation. + + Args: + request (:class:`~.model_service.GetModelEvaluationRequest`): + The request object. Request message for + ``ModelService.GetModelEvaluation``. + name (:class:`str`): + Required. The name of the ModelEvaluation resource. + Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.model_evaluation.ModelEvaluation: + A collection of metrics calculated by + comparing Model's predictions on all of + the test data against annotations from + the test data. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
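+        # Illustrative (not generated): callers pass either a fully-formed
+        # request object or the flattened ``name`` field, never both, e.g.
+        # with a hypothetical resource name:
+        #
+        #     evaluation = await client.get_model_evaluation(
+        #         name=("projects/my-project/locations/us-central1"
+        #               "/models/123/evaluations/456"),
+        #     )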
+ has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.GetModelEvaluationRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_model_evaluation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_model_evaluations( + self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsAsyncPager: + r"""Lists ModelEvaluations in a Model. + + Args: + request (:class:`~.model_service.ListModelEvaluationsRequest`): + The request object. Request message for + ``ModelService.ListModelEvaluations``. + parent (:class:`str`): + Required. The resource name of the Model to list the + ModelEvaluations from. Format: + ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListModelEvaluationsAsyncPager: + Response message for + ``ModelService.ListModelEvaluations``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.ListModelEvaluationsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_model_evaluations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. 
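+        # The routing header assembled above is transmitted as gRPC metadata
+        # of the form ``("x-goog-request-params", "parent=<url-encoded value>")``
+        # (illustrative of google.api_core.gapic_v1.routing_header behavior).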
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListModelEvaluationsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_model_evaluation_slice( + self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: + r"""Gets a ModelEvaluationSlice. + + Args: + request (:class:`~.model_service.GetModelEvaluationSliceRequest`): + The request object. Request message for + ``ModelService.GetModelEvaluationSlice``. + name (:class:`str`): + Required. The name of the ModelEvaluationSlice resource. + Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.model_evaluation_slice.ModelEvaluationSlice: + A collection of metrics calculated by + comparing Model's predictions on a slice + of the test data against ground truth + annotations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.GetModelEvaluationSliceRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_model_evaluation_slice, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_model_evaluation_slices( + self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesAsyncPager: + r"""Lists ModelEvaluationSlices in a ModelEvaluation. + + Args: + request (:class:`~.model_service.ListModelEvaluationSlicesRequest`): + The request object. Request message for + ``ModelService.ListModelEvaluationSlices``. + parent (:class:`str`): + Required. The resource name of the ModelEvaluation to + list the ModelEvaluationSlices from. 
Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListModelEvaluationSlicesAsyncPager: + Response message for + ``ModelService.ListModelEvaluationSlices``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.ListModelEvaluationSlicesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_model_evaluation_slices, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListModelEvaluationSlicesAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. 
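+        # Illustrative (not generated): as with the other list methods, the
+        # pager can be iterated asynchronously, e.g. given a hypothetical
+        # parent evaluation resource name:
+        #
+        #     async for slice_ in await client.list_model_evaluation_slices(
+        #         parent=evaluation_name,
+        #     ):
+        #         ...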
+ return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("ModelServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index dab285be4c..30c00c0c9d 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -16,17 +16,24 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.types import deployed_model_ref from google.cloud.aiplatform_v1beta1.types import explanation @@ -41,8 +48,9 @@ from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from .transports.base import ModelServiceTransport +from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import ModelServiceGrpcTransport +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport class ModelServiceClientMeta(type): @@ -55,6 +63,7 @@ class ModelServiceClientMeta(type): _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] _transport_registry["grpc"] = ModelServiceGrpcTransport + _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: """Return an appropriate transport class. @@ -78,8 +87,38 @@ def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: class ModelServiceClient(metaclass=ModelServiceClientMeta): """A service for managing AI Platform's machine learning Models.""" - DEFAULT_OPTIONS = ClientOptions.ClientOptions( - api_endpoint="aiplatform.googleapis.com" + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. 
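+
+        Example:
+            ``aiplatform.googleapis.com`` is converted to
+            ``aiplatform.mtls.googleapis.com``.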
+ """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT ) @classmethod @@ -102,6 +141,31 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> ModelServiceTransport: + """Return the transport used by the client instance. + + Returns: + ModelServiceTransport: The transport used by the client instance. + """ + return self._transport + + @staticmethod + def endpoint_path(project: str, location: str, endpoint: str,) -> str: + """Return a fully-qualified endpoint string.""" + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) + + @staticmethod + def parse_endpoint_path(path: str) -> Dict[str, str]: + """Parse a endpoint path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" @@ -109,12 +173,139 @@ def model_path(project: str, location: str, model: str,) -> str: project=project, location=location, model=model, ) + @staticmethod + def parse_model_path(path: str) -> Dict[str, str]: + """Parse a model path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def model_evaluation_path( + project: str, location: str, model: str, evaluation: str, + ) -> str: + """Return a fully-qualified model_evaluation string.""" + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( + project=project, location=location, model=model, evaluation=evaluation, + ) + + @staticmethod + def parse_model_evaluation_path(path: str) -> Dict[str, str]: + """Parse a model_evaluation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def model_evaluation_slice_path( + project: str, location: str, model: str, evaluation: str, slice: str, + ) -> str: + """Return a fully-qualified model_evaluation_slice string.""" + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( + project=project, + location=location, + model=model, + evaluation=evaluation, + slice=slice, + ) + + @staticmethod + def parse_model_evaluation_slice_path(path: str) -> Dict[str, str]: + """Parse a model_evaluation_slice path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def training_pipeline_path( + project: str, location: str, training_pipeline: str, + ) -> 
+        """Return a fully-qualified training_pipeline string."""
+        return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(
+            project=project, location=location, training_pipeline=training_pipeline,
+        )
+
+    @staticmethod
+    def parse_training_pipeline_path(path: str) -> Dict[str, str]:
+        """Parse a training_pipeline path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/trainingPipelines/(?P<training_pipeline>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_billing_account_path(billing_account: str,) -> str:
+        """Return a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(
+            billing_account=billing_account,
+        )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str, str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str,) -> str:
+        """Return a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder,)
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str, str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(organization: str,) -> str:
+        """Return a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization,)
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str, str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str,) -> str:
+        """Return a fully-qualified project string."""
+        return "projects/{project}".format(project=project,)
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str, str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str,) -> str:
+        """Return a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(
+            project=project, location=location,
+        )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str, str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
     def __init__(
         self,
         *,
-        credentials: credentials.Credentials = None,
-        transport: Union[str, ModelServiceTransport] = None,
-        client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS,
+        credentials: Optional[credentials.Credentials] = None,
+        transport: Union[str, ModelServiceTransport, None] = None,
+        client_options: Optional[client_options_lib.ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
     ) -> None:
         """Instantiate the model service client.

@@ -127,26 +318,102 @@ def __init__(
             transport (Union[str, ~.ModelServiceTransport]): The
                 transport to use. If set to None, a transport is chosen
                 automatically.
-            client_options (ClientOptions): Custom options for the client.
+            client_options (client_options_lib.ClientOptions): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
+ (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, ModelServiceTransport): - if credentials: + # transport is a ModelServiceTransport instance. + if credentials or client_options.credentials_file: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." 
+ ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def upload_model( @@ -198,28 +465,36 @@ def upload_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, model]): + has_flattened_params = any([parent, model]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = model_service.UploadModelRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.UploadModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.UploadModelRequest): + request = model_service.UploadModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if model is not None: - request.model = model + if parent is not None: + request.parent = parent + if model is not None: + request.model = model # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.upload_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.upload_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -271,25 +546,29 @@ def get_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = model_service.GetModelRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetModelRequest): + request = model_service.GetModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method.wrap_method( - self._transport.get_model, default_timeout=None, client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_model] # Certain fields should be provided within the metadata header; # add these here. @@ -344,25 +623,29 @@ def list_models( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = model_service.ListModelsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListModelsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ListModelsRequest): + request = model_service.ListModelsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_models, default_timeout=None, client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_models] # Certain fields should be provided within the metadata header; # add these here. @@ -376,7 +659,7 @@ def list_models( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelsPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -426,28 +709,38 @@ def update_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([model, update_mask]): + has_flattened_params = any([model, update_mask]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = model_service.UpdateModelRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.UpdateModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.UpdateModelRequest): + request = model_service.UpdateModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if model is not None: - request.model = model - if update_mask is not None: - request.update_mask = update_mask + if model is not None: + request.model = model + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method.wrap_method( - self._transport.update_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.update_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("model.name", request.model.name),) + ), ) # Send the request. @@ -511,26 +804,34 @@ def delete_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = model_service.DeleteModelRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.DeleteModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.DeleteModelRequest): + request = model_service.DeleteModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -600,28 +901,36 @@ def export_model( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name, output_config]): + has_flattened_params = any([name, output_config]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = model_service.ExportModelRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ExportModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ExportModelRequest): + request = model_service.ExportModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name - if output_config is not None: - request.output_config = output_config + if name is not None: + request.name = name + if output_config is not None: + request.output_config = output_config # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. 
- rpc = gapic_v1.method.wrap_method( - self._transport.export_model, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.export_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -679,27 +988,29 @@ def get_model_evaluation( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = model_service.GetModelEvaluationRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetModelEvaluationRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetModelEvaluationRequest): + request = model_service.GetModelEvaluationRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_model_evaluation, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation] # Certain fields should be provided within the metadata header; # add these here. @@ -754,27 +1065,29 @@ def list_model_evaluations( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = model_service.ListModelEvaluationsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListModelEvaluationsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ListModelEvaluationsRequest): + request = model_service.ListModelEvaluationsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_model_evaluations, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_model_evaluations] # Certain fields should be provided within the metadata header; # add these here. 
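The hunks above change the calling convention in two user-visible ways: a caller now supplies either a full request object or the flattened fields (mixing the two raises ValueError), and the routing header for each RPC is attached automatically from the request. A minimal sketch of both styles, assuming the google-cloud-aiplatform package is installed; the project, location, and model IDs are placeholders:

    from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient
    from google.cloud.aiplatform_v1beta1.types import model_service

    client = ModelServiceClient()

    # Build a resource name with the path helpers added above.
    name = ModelServiceClient.model_path("my-project", "us-central1", "1234")

    # Style 1: pass a request object (no copy is made for a typed request).
    model = client.get_model(request=model_service.GetModelRequest(name=name))

    # Style 2: pass the flattened field; combining it with `request` raises ValueError.
    model = client.get_model(name=name)

    # The parse_* helpers invert the path templates.
    assert ModelServiceClient.parse_model_path(name) == {
        "project": "my-project", "location": "us-central1", "model": "1234",
    }
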
@@ -788,7 +1101,7 @@ def list_model_evaluations( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationsPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -835,27 +1148,31 @@ def get_model_evaluation_slice( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = model_service.GetModelEvaluationSliceRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetModelEvaluationSliceRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetModelEvaluationSliceRequest): + request = model_service.GetModelEvaluationSliceRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_model_evaluation_slice, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[ + self._transport.get_model_evaluation_slice + ] # Certain fields should be provided within the metadata header; # add these here. @@ -911,27 +1228,31 @@ def list_model_evaluation_slices( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = model_service.ListModelEvaluationSlicesRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListModelEvaluationSlicesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ListModelEvaluationSlicesRequest): + request = model_service.ListModelEvaluationSlicesRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_model_evaluation_slices, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[ + self._transport.list_model_evaluation_slices + ] # Certain fields should be provided within the metadata header; # add these here. 
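One consequence of threading ``metadata`` into the pager constructors (see the pagers.py hunks below) is that the headers computed for the first call are reused on every hidden next-page request instead of being dropped after page one. A rough caller-side sketch; the parent value is a placeholder:

    pager = client.list_models(parent="projects/my-project/locations/us-central1")

    # Plain iteration walks items across page boundaries; each next-page
    # request is sent with the same metadata as the initial call.
    for model in pager:
        print(model.display_name)
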
@@ -945,7 +1266,7 @@ def list_model_evaluation_slices( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListModelEvaluationSlicesPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -953,13 +1274,13 @@ def list_model_evaluation_slices( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("ModelServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py index 4169c27b85..1ab3aacb91 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import model_evaluation @@ -43,11 +43,11 @@ class ListModelsPager: def __init__( self, - method: Callable[ - [model_service.ListModelsRequest], model_service.ListModelsResponse - ], + method: Callable[..., model_service.ListModelsResponse], request: model_service.ListModelsRequest, response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -58,10 +58,13 @@ def __init__( The initial request object. response (:class:`~.model_service.ListModelsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = model_service.ListModelsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -71,7 +74,7 @@ def pages(self) -> Iterable[model_service.ListModelsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[model.Model]: @@ -82,6 +85,72 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) +class ListModelsAsyncPager: + """A pager for iterating through ``list_models`` requests. + + This class thinly wraps an initial + :class:`~.model_service.ListModelsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``models`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModels`` requests and continue to iterate + through the ``models`` field on the + corresponding responses. + + All the usual :class:`~.model_service.ListModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. 
+ """ + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.model_service.ListModelsRequest`): + The initial request object. + response (:class:`~.model_service.ListModelsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = model_service.ListModelsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[model_service.ListModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[model.Model]: + async def async_generator(): + async for page in self.pages: + for response in page.models: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + class ListModelEvaluationsPager: """A pager for iterating through ``list_model_evaluations`` requests. @@ -102,12 +171,11 @@ class ListModelEvaluationsPager: def __init__( self, - method: Callable[ - [model_service.ListModelEvaluationsRequest], - model_service.ListModelEvaluationsResponse, - ], + method: Callable[..., model_service.ListModelEvaluationsResponse], request: model_service.ListModelEvaluationsRequest, response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -118,10 +186,13 @@ def __init__( The initial request object. response (:class:`~.model_service.ListModelEvaluationsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = model_service.ListModelEvaluationsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -131,7 +202,7 @@ def pages(self) -> Iterable[model_service.ListModelEvaluationsResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[model_evaluation.ModelEvaluation]: @@ -142,6 +213,72 @@ def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) +class ListModelEvaluationsAsyncPager: + """A pager for iterating through ``list_model_evaluations`` requests. + + This class thinly wraps an initial + :class:`~.model_service.ListModelEvaluationsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``model_evaluations`` field. 
+ + If there are more pages, the ``__aiter__`` method will make additional + ``ListModelEvaluations`` requests and continue to iterate + through the ``model_evaluations`` field on the + corresponding responses. + + All the usual :class:`~.model_service.ListModelEvaluationsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.model_service.ListModelEvaluationsRequest`): + The initial request object. + response (:class:`~.model_service.ListModelEvaluationsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = model_service.ListModelEvaluationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[model_service.ListModelEvaluationsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[model_evaluation.ModelEvaluation]: + async def async_generator(): + async for page in self.pages: + for response in page.model_evaluations: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + class ListModelEvaluationSlicesPager: """A pager for iterating through ``list_model_evaluation_slices`` requests. @@ -162,12 +299,11 @@ class ListModelEvaluationSlicesPager: def __init__( self, - method: Callable[ - [model_service.ListModelEvaluationSlicesRequest], - model_service.ListModelEvaluationSlicesResponse, - ], + method: Callable[..., model_service.ListModelEvaluationSlicesResponse], request: model_service.ListModelEvaluationSlicesRequest, response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -178,10 +314,13 @@ def __init__( The initial request object. response (:class:`~.model_service.ListModelEvaluationSlicesResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
""" self._method = method self._request = model_service.ListModelEvaluationSlicesRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -191,7 +330,7 @@ def pages(self) -> Iterable[model_service.ListModelEvaluationSlicesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: @@ -200,3 +339,73 @@ def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListModelEvaluationSlicesAsyncPager: + """A pager for iterating through ``list_model_evaluation_slices`` requests. + + This class thinly wraps an initial + :class:`~.model_service.ListModelEvaluationSlicesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``model_evaluation_slices`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModelEvaluationSlices`` requests and continue to iterate + through the ``model_evaluation_slices`` field on the + corresponding responses. + + All the usual :class:`~.model_service.ListModelEvaluationSlicesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[model_service.ListModelEvaluationSlicesResponse] + ], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.model_service.ListModelEvaluationSlicesRequest`): + The initial request object. + response (:class:`~.model_service.ListModelEvaluationSlicesResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = model_service.ListModelEvaluationSlicesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[model_evaluation_slice.ModelEvaluationSlice]: + async def async_generator(): + async for page in self.pages: + for response in page.model_evaluation_slices: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py index 7bbcc75582..a521df9229 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py @@ -20,14 +20,17 @@ from .base import ModelServiceTransport from .grpc import ModelServiceGrpcTransport +from .grpc_asyncio import ModelServiceGrpcAsyncIOTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] _transport_registry["grpc"] = ModelServiceGrpcTransport +_transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport __all__ = ( "ModelServiceTransport", "ModelServiceGrpcTransport", + "ModelServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py index 53f94ea393..681d035178 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py @@ -17,8 +17,12 @@ import abc import typing +import pkg_resources -from google import auth +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -30,7 +34,17 @@ from google.longrunning import operations_pb2 as operations # type: ignore -class ModelServiceTransport(metaclass=abc.ABCMeta): +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +class ModelServiceTransport(abc.ABC): """Abstract transport class for ModelService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -40,6 +54,11 @@ def __init__( *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, ) -> None: """Instantiate the transport. 
@@ -50,6 +69,17 @@ def __init__(
             credentials identify the application to the service; if none
             are specified, the client will attempt to ascertain the
             credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
         """
         # Save the hostname. Default to port 443 (HTTPS) if none is specified.
         if ":" not in host:
@@ -58,89 +88,179 @@

         # If no credentials are provided, then determine the appropriate
         # defaults.
-        if credentials is None:
-            credentials, _ = auth.default(scopes=self.AUTH_SCOPES)
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs(
+                "'credentials_file' and 'credentials' are mutually exclusive"
+            )
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file, scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(
+                scopes=scopes, quota_project_id=quota_project_id
+            )

         # Save the credentials.
         self._credentials = credentials

+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
+        self._wrapped_methods = {
+            self.upload_model: gapic_v1.method.wrap_method(
+                self.upload_model, default_timeout=None, client_info=client_info,
+            ),
+            self.get_model: gapic_v1.method.wrap_method(
+                self.get_model, default_timeout=None, client_info=client_info,
+            ),
+            self.list_models: gapic_v1.method.wrap_method(
+                self.list_models, default_timeout=None, client_info=client_info,
+            ),
+            self.update_model: gapic_v1.method.wrap_method(
+                self.update_model, default_timeout=None, client_info=client_info,
+            ),
+            self.delete_model: gapic_v1.method.wrap_method(
+                self.delete_model, default_timeout=None, client_info=client_info,
+            ),
+            self.export_model: gapic_v1.method.wrap_method(
+                self.export_model, default_timeout=None, client_info=client_info,
+            ),
+            self.get_model_evaluation: gapic_v1.method.wrap_method(
+                self.get_model_evaluation,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.list_model_evaluations: gapic_v1.method.wrap_method(
+                self.list_model_evaluations,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.get_model_evaluation_slice: gapic_v1.method.wrap_method(
+                self.get_model_evaluation_slice,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.list_model_evaluation_slices: gapic_v1.method.wrap_method(
+                self.list_model_evaluation_slices,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+        }
+
     @property
     def operations_client(self) -> operations_v1.OperationsClient:
         """Return the client designed to process long-running operations."""
-        raise NotImplementedError
+        raise NotImplementedError()

     @property
     def upload_model(
         self,
-    ) -> typing.Callable[[model_service.UploadModelRequest], operations.Operation]:
-        raise NotImplementedError
+    ) -> typing.Callable[
+        [model_service.UploadModelRequest],
+
typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() @property def get_model( self, - ) -> typing.Callable[[model_service.GetModelRequest], model.Model]: - raise NotImplementedError + ) -> typing.Callable[ + [model_service.GetModelRequest], + typing.Union[model.Model, typing.Awaitable[model.Model]], + ]: + raise NotImplementedError() @property def list_models( self, ) -> typing.Callable[ - [model_service.ListModelsRequest], model_service.ListModelsResponse + [model_service.ListModelsRequest], + typing.Union[ + model_service.ListModelsResponse, + typing.Awaitable[model_service.ListModelsResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def update_model( self, - ) -> typing.Callable[[model_service.UpdateModelRequest], gca_model.Model]: - raise NotImplementedError + ) -> typing.Callable[ + [model_service.UpdateModelRequest], + typing.Union[gca_model.Model, typing.Awaitable[gca_model.Model]], + ]: + raise NotImplementedError() @property def delete_model( self, - ) -> typing.Callable[[model_service.DeleteModelRequest], operations.Operation]: - raise NotImplementedError + ) -> typing.Callable[ + [model_service.DeleteModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() @property def export_model( self, - ) -> typing.Callable[[model_service.ExportModelRequest], operations.Operation]: - raise NotImplementedError + ) -> typing.Callable[ + [model_service.ExportModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() @property def get_model_evaluation( self, ) -> typing.Callable[ - [model_service.GetModelEvaluationRequest], model_evaluation.ModelEvaluation + [model_service.GetModelEvaluationRequest], + typing.Union[ + model_evaluation.ModelEvaluation, + typing.Awaitable[model_evaluation.ModelEvaluation], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def list_model_evaluations( self, ) -> typing.Callable[ [model_service.ListModelEvaluationsRequest], - model_service.ListModelEvaluationsResponse, + typing.Union[ + model_service.ListModelEvaluationsResponse, + typing.Awaitable[model_service.ListModelEvaluationsResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def get_model_evaluation_slice( self, ) -> typing.Callable[ [model_service.GetModelEvaluationSliceRequest], - model_evaluation_slice.ModelEvaluationSlice, + typing.Union[ + model_evaluation_slice.ModelEvaluationSlice, + typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def list_model_evaluation_slices( self, ) -> typing.Callable[ [model_service.ListModelEvaluationSlicesRequest], - model_service.ListModelEvaluationSlicesResponse, + typing.Union[ + model_service.ListModelEvaluationSlicesResponse, + typing.Awaitable[model_service.ListModelEvaluationSlicesResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() __all__ = ("ModelServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index f83c41e879..442b665d3a 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -15,11 +15,15 @@ # limitations under the 
License.
 #

-from typing import Callable, Dict
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple

 from google.api_core import grpc_helpers  # type: ignore
 from google.api_core import operations_v1  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google import auth  # type: ignore
 from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore

 import grpc  # type: ignore

@@ -30,7 +34,7 @@
 from google.cloud.aiplatform_v1beta1.types import model_service
 from google.longrunning import operations_pb2 as operations  # type: ignore

-from .base import ModelServiceTransport
+from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO


 class ModelServiceGrpcTransport(ModelServiceTransport):
@@ -46,12 +50,21 @@ class ModelServiceGrpcTransport(ModelServiceTransport):
     top of HTTP/2); the ``grpcio`` package must be installed.
     """

+    _stubs: Dict[str, Callable]
+
     def __init__(
         self,
         *,
         host: str = "aiplatform.googleapis.com",
         credentials: credentials.Credentials = None,
-        channel: grpc.Channel = None
+        credentials_file: str = None,
+        scopes: Sequence[str] = None,
+        channel: grpc.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
     ) -> None:
         """Instantiate the transport.

@@ -63,28 +76,123 @@ def __init__(
             are specified, the client will attempt to ascertain the
             credentials from the environment.
             This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
             channel (Optional[grpc.Channel]): A ``Channel`` instance through
                 which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
         """
-        # Sanity check: Ensure that channel and credentials are not both
-        # provided.
+        self._ssl_channel_credentials = ssl_channel_credentials
+
         if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
            credentials = False

-        # Run the base constructor.
-        super().__init__(host=host, credentials=credentials)
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+        elif api_mtls_endpoint:
+            warnings.warn(
+                "api_mtls_endpoint and client_cert_source are deprecated",
+                DeprecationWarning,
+            )
+
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+            self._ssl_channel_credentials = ssl_credentials
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+
         self._stubs = {}  # type: Dict[str, Callable]

-        # If a channel was explicitly provided, set it.
-        if channel:
-            self._grpc_channel = channel
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )

     @classmethod
     def create_channel(
         cls,
         host: str = "aiplatform.googleapis.com",
         credentials: credentials.Credentials = None,
-        **kwargs
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
     ) -> grpc.Channel:
         """Create and return a gRPC channel object.
         Args:
@@ -94,30 +202,37 @@ def create_channel(
             credentials identify this application to the service. If
             none are specified, the client will attempt to ascertain
             the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
             kwargs (Optional[dict]): Keyword arguments, which are passed to the
                 channel creation.
         Returns:
             grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
""" + scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..13e9848290 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py @@ -0,0 +1,538 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import model as gca_model +from google.cloud.aiplatform_v1beta1.types import model_evaluation +from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice +from google.cloud.aiplatform_v1beta1.types import model_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import ModelServiceGrpcTransport + + +class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): + """gRPC AsyncIO backend transport for ModelService. + + A service for managing AI Platform's machine learning Models. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. 
+ """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. 
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. 
+ return self.__dict__["operations_client"] + + @property + def upload_model( + self, + ) -> Callable[[model_service.UploadModelRequest], Awaitable[operations.Operation]]: + r"""Return a callable for the upload model method over gRPC. + + Uploads a Model artifact into AI Platform. + + Returns: + Callable[[~.UploadModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "upload_model" not in self._stubs: + self._stubs["upload_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/UploadModel", + request_serializer=model_service.UploadModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["upload_model"] + + @property + def get_model( + self, + ) -> Callable[[model_service.GetModelRequest], Awaitable[model.Model]]: + r"""Return a callable for the get model method over gRPC. + + Gets a Model. + + Returns: + Callable[[~.GetModelRequest], + Awaitable[~.Model]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_model" not in self._stubs: + self._stubs["get_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModel", + request_serializer=model_service.GetModelRequest.serialize, + response_deserializer=model.Model.deserialize, + ) + return self._stubs["get_model"] + + @property + def list_models( + self, + ) -> Callable[ + [model_service.ListModelsRequest], Awaitable[model_service.ListModelsResponse] + ]: + r"""Return a callable for the list models method over gRPC. + + Lists Models in a Location. + + Returns: + Callable[[~.ListModelsRequest], + Awaitable[~.ListModelsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_models" not in self._stubs: + self._stubs["list_models"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModels", + request_serializer=model_service.ListModelsRequest.serialize, + response_deserializer=model_service.ListModelsResponse.deserialize, + ) + return self._stubs["list_models"] + + @property + def update_model( + self, + ) -> Callable[[model_service.UpdateModelRequest], Awaitable[gca_model.Model]]: + r"""Return a callable for the update model method over gRPC. + + Updates a Model. + + Returns: + Callable[[~.UpdateModelRequest], + Awaitable[~.Model]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "update_model" not in self._stubs: + self._stubs["update_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/UpdateModel", + request_serializer=model_service.UpdateModelRequest.serialize, + response_deserializer=gca_model.Model.deserialize, + ) + return self._stubs["update_model"] + + @property + def delete_model( + self, + ) -> Callable[[model_service.DeleteModelRequest], Awaitable[operations.Operation]]: + r"""Return a callable for the delete model method over gRPC. + + Deletes a Model. + Note: Model can only be deleted if there are no + DeployedModels created from it. + + Returns: + Callable[[~.DeleteModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_model" not in self._stubs: + self._stubs["delete_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/DeleteModel", + request_serializer=model_service.DeleteModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_model"] + + @property + def export_model( + self, + ) -> Callable[[model_service.ExportModelRequest], Awaitable[operations.Operation]]: + r"""Return a callable for the export model method over gRPC. + + Exports a trained, exportable, Model to a location specified by + the user. A Model is considered to be exportable if it has at + least one [supported export + format][google.cloud.aiplatform.v1beta1.Model.supported_export_formats]. + + Returns: + Callable[[~.ExportModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "export_model" not in self._stubs: + self._stubs["export_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ExportModel", + request_serializer=model_service.ExportModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["export_model"] + + @property + def get_model_evaluation( + self, + ) -> Callable[ + [model_service.GetModelEvaluationRequest], + Awaitable[model_evaluation.ModelEvaluation], + ]: + r"""Return a callable for the get model evaluation method over gRPC. + + Gets a ModelEvaluation. + + Returns: + Callable[[~.GetModelEvaluationRequest], + Awaitable[~.ModelEvaluation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "get_model_evaluation" not in self._stubs: + self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluation", + request_serializer=model_service.GetModelEvaluationRequest.serialize, + response_deserializer=model_evaluation.ModelEvaluation.deserialize, + ) + return self._stubs["get_model_evaluation"] + + @property + def list_model_evaluations( + self, + ) -> Callable[ + [model_service.ListModelEvaluationsRequest], + Awaitable[model_service.ListModelEvaluationsResponse], + ]: + r"""Return a callable for the list model evaluations method over gRPC. + + Lists ModelEvaluations in a Model. + + Returns: + Callable[[~.ListModelEvaluationsRequest], + Awaitable[~.ListModelEvaluationsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_model_evaluations" not in self._stubs: + self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluations", + request_serializer=model_service.ListModelEvaluationsRequest.serialize, + response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, + ) + return self._stubs["list_model_evaluations"] + + @property + def get_model_evaluation_slice( + self, + ) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + Awaitable[model_evaluation_slice.ModelEvaluationSlice], + ]: + r"""Return a callable for the get model evaluation slice method over gRPC. + + Gets a ModelEvaluationSlice. + + Returns: + Callable[[~.GetModelEvaluationSliceRequest], + Awaitable[~.ModelEvaluationSlice]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_model_evaluation_slice" not in self._stubs: + self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/GetModelEvaluationSlice", + request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, + response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, + ) + return self._stubs["get_model_evaluation_slice"] + + @property + def list_model_evaluation_slices( + self, + ) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + Awaitable[model_service.ListModelEvaluationSlicesResponse], + ]: + r"""Return a callable for the list model evaluation slices method over gRPC. + + Lists ModelEvaluationSlices in a ModelEvaluation. + + Returns: + Callable[[~.ListModelEvaluationSlicesRequest], + Awaitable[~.ListModelEvaluationSlicesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_model_evaluation_slices" not in self._stubs: + self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.ModelService/ListModelEvaluationSlices", + request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, + response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, + ) + return self._stubs["list_model_evaluation_slices"] + + +__all__ = ("ModelServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py index 29ba95e15f..7f02b47358 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/__init__.py @@ -16,5 +16,9 @@ # from .client import PipelineServiceClient +from .async_client import PipelineServiceAsyncClient -__all__ = ("PipelineServiceClient",) +__all__ = ( + "PipelineServiceClient", + "PipelineServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py new file mode 100644 index 0000000000..d361b05e21 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -0,0 +1,596 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import pipeline_service +from google.cloud.aiplatform_v1beta1.types import pipeline_state +from google.cloud.aiplatform_v1beta1.types import training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + +from .transports.base import PipelineServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import PipelineServiceGrpcAsyncIOTransport +from .client import PipelineServiceClient + + +class PipelineServiceAsyncClient: + """A service for creating and managing AI Platform's pipelines.""" + + _client: PipelineServiceClient + + DEFAULT_ENDPOINT = PipelineServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = PipelineServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(PipelineServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(PipelineServiceClient.parse_endpoint_path) + model_path = staticmethod(PipelineServiceClient.model_path) + parse_model_path = staticmethod(PipelineServiceClient.parse_model_path) + training_pipeline_path = staticmethod(PipelineServiceClient.training_pipeline_path) + parse_training_pipeline_path = staticmethod( + PipelineServiceClient.parse_training_pipeline_path + ) + + common_billing_account_path = staticmethod( + PipelineServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PipelineServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(PipelineServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + PipelineServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + PipelineServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PipelineServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(PipelineServiceClient.common_project_path) + parse_common_project_path = staticmethod( + PipelineServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(PipelineServiceClient.common_location_path) + parse_common_location_path = staticmethod( + PipelineServiceClient.parse_common_location_path + ) + + from_service_account_file = PipelineServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> PipelineServiceTransport: + """Return the transport used by 
the client instance. + + Returns: + PipelineServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, PipelineServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the pipeline service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.PipelineServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = PipelineServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def create_training_pipeline( + self, + request: pipeline_service.CreateTrainingPipelineRequest = None, + *, + parent: str = None, + training_pipeline: gca_training_pipeline.TrainingPipeline = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_training_pipeline.TrainingPipeline: + r"""Creates a TrainingPipeline. A created + TrainingPipeline right away will be attempted to be run. + + Args: + request (:class:`~.pipeline_service.CreateTrainingPipelineRequest`): + The request object. Request message for + ``PipelineService.CreateTrainingPipeline``. + parent (:class:`str`): + Required. The resource name of the Location to create + the TrainingPipeline in. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + training_pipeline (:class:`~.gca_training_pipeline.TrainingPipeline`): + Required. The TrainingPipeline to + create. + This corresponds to the ``training_pipeline`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gca_training_pipeline.TrainingPipeline: + The TrainingPipeline orchestrates tasks associated with + training a Model. It always executes the training task, + and optionally may also export data from AI Platform's + Dataset which becomes the training input, + ``upload`` + the Model to AI Platform, and evaluate the Model. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, training_pipeline]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.CreateTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if training_pipeline is not None: + request.training_pipeline = training_pipeline + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_training_pipeline( + self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: + r"""Gets a TrainingPipeline. + + Args: + request (:class:`~.pipeline_service.GetTrainingPipelineRequest`): + The request object. Request message for + ``PipelineService.GetTrainingPipeline``. + name (:class:`str`): + Required. The name of the TrainingPipeline resource. + Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.training_pipeline.TrainingPipeline: + The TrainingPipeline orchestrates tasks associated with + training a Model. It always executes the training task, + and optionally may also export data from AI Platform's + Dataset which becomes the training input, + ``upload`` + the Model to AI Platform, and evaluate the Model. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.GetTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_training_pipelines( + self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesAsyncPager: + r"""Lists TrainingPipelines in a Location. + + Args: + request (:class:`~.pipeline_service.ListTrainingPipelinesRequest`): + The request object. Request message for + ``PipelineService.ListTrainingPipelines``. + parent (:class:`str`): + Required. The resource name of the Location to list the + TrainingPipelines from. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListTrainingPipelinesAsyncPager: + Response message for + ``PipelineService.ListTrainingPipelines`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.ListTrainingPipelinesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_training_pipelines, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. 
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListTrainingPipelinesAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_training_pipeline( + self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a TrainingPipeline. + + Args: + request (:class:`~.pipeline_service.DeleteTrainingPipelineRequest`): + The request object. Request message for + ``PipelineService.DeleteTrainingPipeline``. + name (:class:`str`): + Required. The name of the TrainingPipeline resource to + be deleted. Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.DeleteTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. 
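+        # (The AsyncOperation future resolves to ``empty.Empty`` once the
+        # server-side delete completes; its metadata is typed as
+        # ``gca_operation.DeleteOperationMetadata``, per the arguments above.)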
+ return response + + async def cancel_training_pipeline( + self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on + the TrainingPipeline. The server makes a best effort to cancel + the pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetTrainingPipeline`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. On + successful cancellation, the TrainingPipeline is not deleted; + instead it becomes a pipeline with a + ``TrainingPipeline.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``TrainingPipeline.state`` + is set to ``CANCELLED``. + + Args: + request (:class:`~.pipeline_service.CancelTrainingPipelineRequest`): + The request object. Request message for + ``PipelineService.CancelTrainingPipeline``. + name (:class:`str`): + Required. The name of the TrainingPipeline to cancel. + Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.CancelTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ await rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("PipelineServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index 2530414b9a..e3e7d6aeda 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -16,17 +16,24 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import operation as gca_operation @@ -41,8 +48,9 @@ from google.protobuf import timestamp_pb2 as timestamp # type: ignore from google.rpc import status_pb2 as status # type: ignore -from .transports.base import PipelineServiceTransport +from .transports.base import PipelineServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import PipelineServiceGrpcTransport +from .transports.grpc_asyncio import PipelineServiceGrpcAsyncIOTransport class PipelineServiceClientMeta(type): @@ -57,6 +65,7 @@ class PipelineServiceClientMeta(type): OrderedDict() ) # type: Dict[str, Type[PipelineServiceTransport]] _transport_registry["grpc"] = PipelineServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTransport]: """Return an appropriate transport class. @@ -80,8 +89,38 @@ def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTranspor class PipelineServiceClient(metaclass=PipelineServiceClientMeta): """A service for creating and managing AI Platform's pipelines.""" - DEFAULT_OPTIONS = ClientOptions.ClientOptions( - api_endpoint="aiplatform.googleapis.com" + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. 
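+
+        Example:
+            Illustrative endpoints only, a hedged sketch of the conversion
+            described above::
+
+                _get_default_mtls_endpoint("aiplatform.googleapis.com")
+                # -> "aiplatform.mtls.googleapis.com"
+                _get_default_mtls_endpoint("aiplatform.sandbox.googleapis.com")
+                # -> "aiplatform.mtls.sandbox.googleapis.com"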
+ """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT ) @classmethod @@ -104,6 +143,31 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> PipelineServiceTransport: + """Return the transport used by the client instance. + + Returns: + PipelineServiceTransport: The transport used by the client instance. + """ + return self._transport + + @staticmethod + def endpoint_path(project: str, location: str, endpoint: str,) -> str: + """Return a fully-qualified endpoint string.""" + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) + + @staticmethod + def parse_endpoint_path(path: str) -> Dict[str, str]: + """Parse a endpoint path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def model_path(project: str, location: str, model: str,) -> str: """Return a fully-qualified model string.""" @@ -111,6 +175,15 @@ def model_path(project: str, location: str, model: str,) -> str: project=project, location=location, model=model, ) + @staticmethod + def parse_model_path(path: str) -> Dict[str, str]: + """Parse a model path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def training_pipeline_path( project: str, location: str, training_pipeline: str, @@ -120,12 +193,81 @@ def training_pipeline_path( project=project, location=location, training_pipeline=training_pipeline, ) + @staticmethod + def parse_training_pipeline_path(path: str) -> Dict[str, str]: + """Parse a training_pipeline path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization 
string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + def __init__( self, *, - credentials: credentials.Credentials = None, - transport: Union[str, PipelineServiceTransport] = None, - client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PipelineServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the pipeline service client. @@ -138,26 +280,102 @@ def __init__( transport (Union[str, ~.PipelineServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. 
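+
+        Example:
+            A hedged sketch of overriding the endpoint via client options
+            (the regional endpoint shown is illustrative)::
+
+                from google.api_core.client_options import ClientOptions
+
+                options = ClientOptions(
+                    api_endpoint="us-central1-aiplatform.googleapis.com",
+                )
+                client = PipelineServiceClient(client_options=options)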
""" if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, PipelineServiceTransport): - if credentials: + # transport is a PipelineServiceTransport instance. + if credentials or client_options.credentials_file: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def create_training_pipeline( @@ -210,28 +428,36 @@ def create_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, training_pipeline]): + has_flattened_params = any([parent, training_pipeline]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = pipeline_service.CreateTrainingPipelineRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.CreateTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.CreateTrainingPipelineRequest): + request = pipeline_service.CreateTrainingPipelineRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
+ # If we have keyword arguments corresponding to fields on the + # request, apply these. - if parent is not None: - request.parent = parent - if training_pipeline is not None: - request.training_pipeline = training_pipeline + if parent is not None: + request.parent = parent + if training_pipeline is not None: + request.training_pipeline = training_pipeline # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_training_pipeline, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_training_pipeline] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -283,27 +509,29 @@ def get_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = pipeline_service.GetTrainingPipelineRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.GetTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.GetTrainingPipelineRequest): + request = pipeline_service.GetTrainingPipelineRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_training_pipeline, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_training_pipeline] # Certain fields should be provided within the metadata header; # add these here. @@ -358,27 +586,29 @@ def list_training_pipelines( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = pipeline_service.ListTrainingPipelinesRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.ListTrainingPipelinesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.ListTrainingPipelinesRequest): + request = pipeline_service.ListTrainingPipelinesRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_training_pipelines, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_training_pipelines] # Certain fields should be provided within the metadata header; # add these here. @@ -392,7 +622,7 @@ def list_training_pipelines( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListTrainingPipelinesPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -452,26 +682,34 @@ def delete_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = pipeline_service.DeleteTrainingPipelineRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.DeleteTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.DeleteTrainingPipelineRequest): + request = pipeline_service.DeleteTrainingPipelineRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_training_pipeline, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_training_pipeline] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -533,26 +771,34 @@ def cancel_training_pipeline( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = pipeline_service.CancelTrainingPipelineRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.CancelTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.CancelTrainingPipelineRequest): + request = pipeline_service.CancelTrainingPipelineRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
+ # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.cancel_training_pipeline, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.cancel_training_pipeline] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -562,13 +808,13 @@ def cancel_training_pipeline( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("PipelineServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py index 0db54250ef..98e5a51a17 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import pipeline_service from google.cloud.aiplatform_v1beta1.types import training_pipeline @@ -41,12 +41,11 @@ class ListTrainingPipelinesPager: def __init__( self, - method: Callable[ - [pipeline_service.ListTrainingPipelinesRequest], - pipeline_service.ListTrainingPipelinesResponse, - ], + method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], request: pipeline_service.ListTrainingPipelinesRequest, response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -57,10 +56,13 @@ def __init__( The initial request object. response (:class:`~.pipeline_service.ListTrainingPipelinesResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. """ self._method = method self._request = pipeline_service.ListTrainingPipelinesRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -70,7 +72,7 @@ def pages(self) -> Iterable[pipeline_service.ListTrainingPipelinesResponse]: yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: @@ -79,3 +81,73 @@ def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTrainingPipelinesAsyncPager: + """A pager for iterating through ``list_training_pipelines`` requests. 
+ + This class thinly wraps an initial + :class:`~.pipeline_service.ListTrainingPipelinesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``training_pipelines`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListTrainingPipelines`` requests and continue to iterate + through the ``training_pipelines`` field on the + corresponding responses. + + All the usual :class:`~.pipeline_service.ListTrainingPipelinesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[pipeline_service.ListTrainingPipelinesResponse] + ], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.pipeline_service.ListTrainingPipelinesRequest`): + The initial request object. + response (:class:`~.pipeline_service.ListTrainingPipelinesResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = pipeline_service.ListTrainingPipelinesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[training_pipeline.TrainingPipeline]: + async def async_generator(): + async for page in self.pages: + for response in page.training_pipelines: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py index 615b2c1025..d9d71a892b 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py @@ -20,14 +20,17 @@ from .base import PipelineServiceTransport from .grpc import PipelineServiceGrpcTransport +from .grpc_asyncio import PipelineServiceGrpcAsyncIOTransport # Compile a registry of transports. 
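# A minimal sketch of how this registry is typically consumed (hypothetical
# helper; the generated clients actually do this lookup through their
# metaclass's ``get_transport_class``):
#
#     def pick_transport(label: str = "grpc"):
#         return _transport_registry[label]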
_transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] _transport_registry["grpc"] = PipelineServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport __all__ = ( "PipelineServiceTransport", "PipelineServiceGrpcTransport", + "PipelineServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py index 5696ede4d7..1b235635f1 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py @@ -17,8 +17,12 @@ import abc import typing +import pkg_resources -from google import auth +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -31,7 +35,17 @@ from google.protobuf import empty_pb2 as empty # type: ignore -class PipelineServiceTransport(metaclass=abc.ABCMeta): +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +class PipelineServiceTransport(abc.ABC): """Abstract transport class for PipelineService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -41,6 +55,11 @@ def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, ) -> None: """Instantiate the transport. @@ -51,6 +70,17 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -59,57 +89,115 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults.
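# Resolution order sketched below: passing both ``credentials`` and
# ``credentials_file`` raises ``DuplicateCredentialArgs``; otherwise a
# credentials file, when given, is loaded; failing that, application
# default credentials are used.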
- if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES) + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials + # Lifted into its own function so it can be stubbed out during tests. + self._prep_wrapped_messages(client_info) + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_training_pipeline: gapic_v1.method.wrap_method( + self.create_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.get_training_pipeline: gapic_v1.method.wrap_method( + self.get_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.list_training_pipelines: gapic_v1.method.wrap_method( + self.list_training_pipelines, + default_timeout=None, + client_info=client_info, + ), + self.delete_training_pipeline: gapic_v1.method.wrap_method( + self.delete_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.cancel_training_pipeline: gapic_v1.method.wrap_method( + self.cancel_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + } + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property def create_training_pipeline( self, ) -> typing.Callable[ [pipeline_service.CreateTrainingPipelineRequest], - gca_training_pipeline.TrainingPipeline, + typing.Union[ + gca_training_pipeline.TrainingPipeline, + typing.Awaitable[gca_training_pipeline.TrainingPipeline], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def get_training_pipeline( self, ) -> typing.Callable[ [pipeline_service.GetTrainingPipelineRequest], - training_pipeline.TrainingPipeline, + typing.Union[ + training_pipeline.TrainingPipeline, + typing.Awaitable[training_pipeline.TrainingPipeline], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def list_training_pipelines( self, ) -> typing.Callable[ [pipeline_service.ListTrainingPipelinesRequest], - pipeline_service.ListTrainingPipelinesResponse, + typing.Union[ + pipeline_service.ListTrainingPipelinesResponse, + typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def delete_training_pipeline( self, ) -> typing.Callable[ - [pipeline_service.DeleteTrainingPipelineRequest], operations.Operation + [pipeline_service.DeleteTrainingPipelineRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], ]: - raise NotImplementedError + raise NotImplementedError() @property def cancel_training_pipeline( self, - ) -> typing.Callable[[pipeline_service.CancelTrainingPipelineRequest], empty.Empty]: - raise NotImplementedError + ) -> typing.Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() __all__ = ("PipelineServiceTransport",) diff --git 
a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index 7ce95caab7..4fc6389449 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -15,11 +15,15 @@ # limitations under the License. # -from typing import Callable, Dict +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple from google.api_core import grpc_helpers # type: ignore from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore import grpc # type: ignore @@ -31,7 +35,7 @@ from google.longrunning import operations_pb2 as operations # type: ignore from google.protobuf import empty_pb2 as empty # type: ignore -from .base import PipelineServiceTransport +from .base import PipelineServiceTransport, DEFAULT_CLIENT_INFO class PipelineServiceGrpcTransport(PipelineServiceTransport): @@ -47,12 +51,21 @@ class PipelineServiceGrpcTransport(PipelineServiceTransport): top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] + def __init__( self, *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - channel: grpc.Channel = None + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the transport. @@ -64,28 +77,123 @@ def __init__( are specified, the client will attempt to ascertain the credentials from the environment. This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A list of scopes. This argument is + ignored if ``channel`` is provided. channel (Optional[grpc.Channel]): A ``Channel`` instance through which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library.
+ + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. """ - # Sanity check: Ensure that channel and credentials are not both - # provided. + self._ssl_channel_credentials = ssl_channel_credentials + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. credentials = False - # Run the base constructor. - super().__init__(host=host, credentials=credentials) + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._stubs = {} # type: Dict[str, Callable] - # If a channel was explicitly provided, set it. - if channel: - self._grpc_channel = channel + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) @classmethod def create_channel( cls, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, - **kwargs + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, ) -> grpc.Channel: """Create and return a gRPC channel object. Args: @@ -95,30 +203,37 @@ def create_channel( credentials identify this application to the service. If none are specified, the client will attempt to ascertain the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): An optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`.
+ quota_project_id (Optional[str]): An optional project to use for billing + and quota. kwargs (Optional[dict]): Keyword arguments, which are passed to the channel creation. Returns: grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. """ + scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..2e6f51e1a3 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py @@ -0,0 +1,417 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import pipeline_service +from google.cloud.aiplatform_v1beta1.types import training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) +from google.longrunning import operations_pb2 as operations # type: ignore +from google.protobuf import empty_pb2 as empty # type: ignore + +from .base import PipelineServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import PipelineServiceGrpcTransport + + +class PipelineServiceGrpcAsyncIOTransport(PipelineServiceTransport): + """gRPC AsyncIO backend transport for PipelineService. + + A service for creating and managing AI Platform's pipelines. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. 
+ + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): An optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): An optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated.
A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Return the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one.
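# The async operations client is created lazily on first access and kept
# in ``self.__dict__``; every later access returns that cached instance.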
+ if "operations_client" not in self.__dict__: + self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self.__dict__["operations_client"] + + @property + def create_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + Awaitable[gca_training_pipeline.TrainingPipeline], + ]: + r"""Return a callable for the create training pipeline method over gRPC. + + Creates a TrainingPipeline. A created + TrainingPipeline right away will be attempted to be run. + + Returns: + Callable[[~.CreateTrainingPipelineRequest], + Awaitable[~.TrainingPipeline]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_training_pipeline" not in self._stubs: + self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CreateTrainingPipeline", + request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, + response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, + ) + return self._stubs["create_training_pipeline"] + + @property + def get_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + Awaitable[training_pipeline.TrainingPipeline], + ]: + r"""Return a callable for the get training pipeline method over gRPC. + + Gets a TrainingPipeline. + + Returns: + Callable[[~.GetTrainingPipelineRequest], + Awaitable[~.TrainingPipeline]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_training_pipeline" not in self._stubs: + self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/GetTrainingPipeline", + request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, + response_deserializer=training_pipeline.TrainingPipeline.deserialize, + ) + return self._stubs["get_training_pipeline"] + + @property + def list_training_pipelines( + self, + ) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + Awaitable[pipeline_service.ListTrainingPipelinesResponse], + ]: + r"""Return a callable for the list training pipelines method over gRPC. + + Lists TrainingPipelines in a Location. + + Returns: + Callable[[~.ListTrainingPipelinesRequest], + Awaitable[~.ListTrainingPipelinesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_training_pipelines" not in self._stubs: + self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/ListTrainingPipelines", + request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, + response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, + ) + return self._stubs["list_training_pipelines"] + + @property + def delete_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the delete training pipeline method over gRPC. + + Deletes a TrainingPipeline. + + Returns: + Callable[[~.DeleteTrainingPipelineRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_training_pipeline" not in self._stubs: + self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/DeleteTrainingPipeline", + request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_training_pipeline"] + + @property + def cancel_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CancelTrainingPipelineRequest], Awaitable[empty.Empty] + ]: + r"""Return a callable for the cancel training pipeline method over gRPC. + + Cancels a TrainingPipeline. Starts asynchronous cancellation on + the TrainingPipeline. The server makes a best effort to cancel + the pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetTrainingPipeline`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. On + successful cancellation, the TrainingPipeline is not deleted; + instead it becomes a pipeline with a + ``TrainingPipeline.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``TrainingPipeline.state`` + is set to ``CANCELLED``. + + Returns: + Callable[[~.CancelTrainingPipelineRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "cancel_training_pipeline" not in self._stubs: + self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PipelineService/CancelTrainingPipeline", + request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_training_pipeline"] + + +__all__ = ("PipelineServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py index 9e3af89360..0c847693e0 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/__init__.py @@ -16,5 +16,9 @@ # from .client import PredictionServiceClient +from .async_client import PredictionServiceAsyncClient -__all__ = ("PredictionServiceClient",) +__all__ = ( + "PredictionServiceClient", + "PredictionServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py new file mode 100644 index 0000000000..c82146bafa --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py @@ -0,0 +1,386 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import prediction_service +from google.protobuf import struct_pb2 as struct # type: ignore + +from .transports.base import PredictionServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import PredictionServiceGrpcAsyncIOTransport +from .client import PredictionServiceClient + + +class PredictionServiceAsyncClient: + """A service for online predictions and explanations.""" + + _client: PredictionServiceClient + + DEFAULT_ENDPOINT = PredictionServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = PredictionServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(PredictionServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(PredictionServiceClient.parse_endpoint_path) + + common_billing_account_path = staticmethod( + PredictionServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PredictionServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(PredictionServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + PredictionServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + PredictionServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PredictionServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(PredictionServiceClient.common_project_path) + parse_common_project_path = staticmethod( + PredictionServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(PredictionServiceClient.common_location_path) + parse_common_location_path = staticmethod( + PredictionServiceClient.parse_common_location_path + ) + + from_service_account_file = PredictionServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> PredictionServiceTransport: + """Return the transport used by the client instance. + + Returns: + PredictionServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, PredictionServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the prediction service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.PredictionServiceTransport]): The + transport to use. 
If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = PredictionServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def predict( + self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: + r"""Perform an online prediction. + + Args: + request (:class:`~.prediction_service.PredictRequest`): + The request object. Request message for + ``PredictionService.Predict``. + endpoint (:class:`str`): + Required. The name of the Endpoint requested to serve + the prediction. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + instances (:class:`Sequence[~.struct.Value]`): + Required. The instances that are the input to the + prediction call. A DeployedModel may have an upper limit + on the number of instances it supports per request, and + when it is exceeded the prediction call errors in case + of AutoML Models, or, in case of customer created + Models, the behaviour is as documented by that Model. + The schema of any single instance may be specified via + Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``instance_schema_uri``. + This corresponds to the ``instances`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + parameters (:class:`~.struct.Value`): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``parameters_schema_uri``. + This corresponds to the ``parameters`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.prediction_service.PredictResponse: + Response message for + ``PredictionService.Predict``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([endpoint, instances, parameters]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = prediction_service.PredictRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if parameters is not None: + request.parameters = parameters + + if instances: + request.instances.extend(instances) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.predict, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def explain( + self, + request: prediction_service.ExplainRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + deployed_model_id: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.ExplainResponse: + r"""Perform an online explanation. + + If + ``deployed_model_id`` + is specified, the corresponding DeployedModel must have + ``explanation_spec`` + populated. If + ``deployed_model_id`` + is not specified, all DeployedModels must have + ``explanation_spec`` + populated. Only deployed AutoML tabular Models have + explanation_spec. + + Args: + request (:class:`~.prediction_service.ExplainRequest`): + The request object. Request message for + ``PredictionService.Explain``. + endpoint (:class:`str`): + Required. The name of the Endpoint requested to serve + the explanation. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + instances (:class:`Sequence[~.struct.Value]`): + Required. The instances that are the input to the + explanation call. A DeployedModel may have an upper + limit on the number of instances it supports per + request, and when it is exceeded the explanation call + errors in case of AutoML Models, or, in case of customer + created Models, the behaviour is as documented by that + Model. The schema of any single instance may be + specified via Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``instance_schema_uri``. + This corresponds to the ``instances`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set.
+ parameters (:class:`~.struct.Value`): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1beta1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] + ``parameters_schema_uri``. + This corresponds to the ``parameters`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model_id (:class:`str`): + If specified, this ExplainRequest will be served by the + chosen DeployedModel, overriding + ``Endpoint.traffic_split``. + This corresponds to the ``deployed_model_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.prediction_service.ExplainResponse: + Response message for + ``PredictionService.Explain``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = prediction_service.ExplainRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if parameters is not None: + request.parameters = parameters + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id + + if instances: + request.instances.extend(instances) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.explain, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. 
+ return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("PredictionServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index dbdf226471..9a5976d697 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -16,22 +16,29 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import prediction_service from google.protobuf import struct_pb2 as struct # type: ignore -from .transports.base import PredictionServiceTransport +from .transports.base import PredictionServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import PredictionServiceGrpcTransport +from .transports.grpc_asyncio import PredictionServiceGrpcAsyncIOTransport class PredictionServiceClientMeta(type): @@ -46,6 +53,7 @@ class PredictionServiceClientMeta(type): OrderedDict() ) # type: Dict[str, Type[PredictionServiceTransport]] _transport_registry["grpc"] = PredictionServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport def get_transport_class( cls, label: str = None, @@ -71,8 +79,38 @@ def get_transport_class( class PredictionServiceClient(metaclass=PredictionServiceClientMeta): """A service for online predictions and explanations.""" - DEFAULT_OPTIONS = ClientOptions.ClientOptions( - api_endpoint="aiplatform.googleapis.com" + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+ ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT ) @classmethod @@ -95,12 +133,97 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> PredictionServiceTransport: + """Return the transport used by the client instance. + + Returns: + PredictionServiceTransport: The transport used by the client instance. + """ + return self._transport + + @staticmethod + def endpoint_path(project: str, location: str, endpoint: str,) -> str: + """Return a fully-qualified endpoint string.""" + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) + + @staticmethod + def parse_endpoint_path(path: str) -> Dict[str, str]: + """Parse an endpoint path into its component segments.""" + m = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P<folder>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse an organization path into its component segments.""" + m = re.match(r"^organizations/(?P<organization>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P<project>.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path) + return
+
     def __init__(
         self,
         *,
-        credentials: credentials.Credentials = None,
-        transport: Union[str, PredictionServiceTransport] = None,
-        client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS,
+        credentials: Optional[credentials.Credentials] = None,
+        transport: Union[str, PredictionServiceTransport, None] = None,
+        client_options: Optional[client_options_lib.ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
     ) -> None:
         """Instantiate the prediction service client.

@@ -113,26 +236,102 @@ def __init__(
             transport (Union[str, ~.PredictionServiceTransport]): The
                 transport to use. If set to None, a transport is
                 chosen automatically.
-            client_options (ClientOptions): Custom options for the client.
+            client_options (client_options_lib.ClientOptions): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
         """
         if isinstance(client_options, dict):
-            client_options = ClientOptions.from_dict(client_options)
+            client_options = client_options_lib.from_dict(client_options)
+        if client_options is None:
+            client_options = client_options_lib.ClientOptions()
+
+        # Create SSL credentials for mutual TLS if needed.
+        use_client_cert = bool(
+            util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
+        )
+
+        ssl_credentials = None
+        is_mtls = False
+        if use_client_cert:
+            if client_options.client_cert_source:
+                import grpc  # type: ignore
+
+                cert, key = client_options.client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+                is_mtls = True
+            else:
+                creds = SslCredentials()
+                is_mtls = creds.is_mtls
+                ssl_credentials = creds.ssl_credentials if is_mtls else None
+
+        # Figure out which api endpoint to use.
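+        # An explicit api_endpoint in client_options always wins; otherwise the
+        # GOOGLE_API_USE_MTLS_ENDPOINT variable picks between the regular and
+        # mTLS endpoints, and "auto" (the default) selects mTLS only when a
+        # client certificate was configured above.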
+ if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, PredictionServiceTransport): - if credentials: + # transport is a PredictionServiceTransport instance. + if credentials or client_options.credentials_file: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def predict( @@ -200,28 +399,39 @@ def predict( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([endpoint, instances, parameters]): + has_flattened_params = any([endpoint, instances, parameters]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = prediction_service.PredictRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a prediction_service.PredictRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, prediction_service.PredictRequest): + request = prediction_service.PredictRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if endpoint is not None: - request.endpoint = endpoint - if instances is not None: - request.instances.extend(instances) - if parameters is not None: - request.parameters = parameters + if endpoint is not None: + request.endpoint = endpoint + if parameters is not None: + request.parameters = parameters + + if instances: + request.instances.extend(instances) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.predict, default_timeout=None, client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.predict] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. 
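
With the predict hunk complete, a minimal usage sketch may help. It assumes Application Default Credentials are available in the environment, the resource IDs are placeholders, and the import path follows the file being patched:

    from google.cloud.aiplatform_v1beta1.services.prediction_service.client import (
        PredictionServiceClient,
    )
    from google.protobuf import struct_pb2

    client = PredictionServiceClient()
    endpoint = client.endpoint_path("my-project", "us-central1", "1234")

    # Flattened arguments are copied onto a fresh PredictRequest, while a
    # ready-made PredictRequest is now forwarded without the extra copy; the
    # "endpoint" routing header is attached automatically in both cases.
    instance = struct_pb2.Value(string_value="example input")
    response = client.predict(endpoint=endpoint, instances=[instance])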
@@ -244,11 +454,13 @@ def explain( ) -> prediction_service.ExplainResponse: r"""Perform an online explanation. - If [ExplainRequest.deployed_model_id] is specified, the - corresponding DeployModel must have + If + ``deployed_model_id`` + is specified, the corresponding DeployModel must have ``explanation_spec`` - populated. If [ExplainRequest.deployed_model_id] is not - specified, all DeployedModels must have + populated. If + ``deployed_model_id`` + is not specified, all DeployedModels must have ``explanation_spec`` populated. Only deployed AutoML tabular Models have explanation_spec. @@ -312,32 +524,41 @@ def explain( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any( - [endpoint, instances, parameters, deployed_model_id] - ): + has_flattened_params = any([endpoint, instances, parameters, deployed_model_id]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = prediction_service.ExplainRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a prediction_service.ExplainRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, prediction_service.ExplainRequest): + request = prediction_service.ExplainRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if endpoint is not None: - request.endpoint = endpoint - if instances is not None: - request.instances.extend(instances) - if parameters is not None: - request.parameters = parameters - if deployed_model_id is not None: - request.deployed_model_id = deployed_model_id + if endpoint is not None: + request.endpoint = endpoint + if parameters is not None: + request.parameters = parameters + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id + + if instances: + request.instances.extend(instances) # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.explain, default_timeout=None, client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.explain] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), ) # Send the request. 
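
Which hostname these RPCs actually reach is decided by the constructor logic above, and the conversion rules in _get_default_mtls_endpoint can be checked in isolation. A small sketch of what those rules imply (derived from the regex, not captured from a live run):

    from google.cloud.aiplatform_v1beta1.services.prediction_service.client import (
        PredictionServiceClient,
    )

    to_mtls = PredictionServiceClient._get_default_mtls_endpoint
    assert to_mtls("aiplatform.googleapis.com") == "aiplatform.mtls.googleapis.com"
    assert (
        to_mtls("aiplatform.sandbox.googleapis.com")
        == "aiplatform.mtls.sandbox.googleapis.com"
    )
    # Endpoints that are already mTLS, or are not *.googleapis.com at all,
    # pass through unchanged.
    assert to_mtls("aiplatform.mtls.googleapis.com") == "aiplatform.mtls.googleapis.com"
    assert to_mtls("localhost:8080") == "localhost:8080"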
@@ -348,13 +569,13 @@ def explain(


 try:
-    _client_info = gapic_v1.client_info.ClientInfo(
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
         gapic_version=pkg_resources.get_distribution(
             "google-cloud-aiplatform",
         ).version,
     )
 except pkg_resources.DistributionNotFound:
-    _client_info = gapic_v1.client_info.ClientInfo()
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()


 __all__ = ("PredictionServiceClient",)
diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py
index 33eefca757..7eb32ea86d 100644
--- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py
@@ -20,14 +20,17 @@

 from .base import PredictionServiceTransport
 from .grpc import PredictionServiceGrpcTransport
+from .grpc_asyncio import PredictionServiceGrpcAsyncIOTransport


 # Compile a registry of transports.
 _transport_registry = OrderedDict()  # type: Dict[str, Type[PredictionServiceTransport]]
 _transport_registry["grpc"] = PredictionServiceGrpcTransport
+_transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport


 __all__ = (
     "PredictionServiceTransport",
     "PredictionServiceGrpcTransport",
+    "PredictionServiceGrpcAsyncIOTransport",
 )
diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py
index 58d508474a..cdec1c11e5 100644
--- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py
+++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py
@@ -17,14 +17,28 @@

 import abc
 import typing
+import pkg_resources

-from google import auth
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
 from google.auth import credentials  # type: ignore

 from google.cloud.aiplatform_v1beta1.types import prediction_service


-class PredictionServiceTransport(metaclass=abc.ABCMeta):
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            "google-cloud-aiplatform",
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+class PredictionServiceTransport(abc.ABC):
     """Abstract transport class for PredictionService."""

     AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
@@ -34,6 +48,11 @@ def __init__(
         *,
         host: str = "aiplatform.googleapis.com",
         credentials: credentials.Credentials = None,
+        credentials_file: typing.Optional[str] = None,
+        scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+        quota_project_id: typing.Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        **kwargs,
     ) -> None:
         """Instantiate the transport.

@@ -44,6 +63,17 @@ def __init__(
             credentials identify the application to the service; if none
             are specified, the client will attempt to ascertain the
             credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+ quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. if ":" not in host: @@ -52,27 +82,61 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. - if credentials is None: - credentials, _ = auth.default(scopes=self.AUTH_SCOPES) + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) # Save the credentials. self._credentials = credentials + # Lifted into its own function so it can be stubbed out during tests. + self._prep_wrapped_messages(client_info) + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.predict: gapic_v1.method.wrap_method( + self.predict, default_timeout=None, client_info=client_info, + ), + self.explain: gapic_v1.method.wrap_method( + self.explain, default_timeout=None, client_info=client_info, + ), + } + @property def predict( self, ) -> typing.Callable[ - [prediction_service.PredictRequest], prediction_service.PredictResponse + [prediction_service.PredictRequest], + typing.Union[ + prediction_service.PredictResponse, + typing.Awaitable[prediction_service.PredictResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def explain( self, ) -> typing.Callable[ - [prediction_service.ExplainRequest], prediction_service.ExplainResponse + [prediction_service.ExplainRequest], + typing.Union[ + prediction_service.ExplainResponse, + typing.Awaitable[prediction_service.ExplainResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() __all__ = ("PredictionServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index 55824a233c..1a102e1a61 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -15,16 +15,20 @@ # limitations under the License. 
 #

-from typing import Callable, Dict
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple

 from google.api_core import grpc_helpers  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google import auth  # type: ignore
 from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore

 import grpc  # type: ignore

 from google.cloud.aiplatform_v1beta1.types import prediction_service

-from .base import PredictionServiceTransport
+from .base import PredictionServiceTransport, DEFAULT_CLIENT_INFO


 class PredictionServiceGrpcTransport(PredictionServiceTransport):
@@ -40,12 +44,21 @@ class PredictionServiceGrpcTransport(PredictionServiceTransport):
     top of HTTP/2); the ``grpcio`` package must be installed.
     """

+    _stubs: Dict[str, Callable]
+
     def __init__(
         self,
         *,
         host: str = "aiplatform.googleapis.com",
         credentials: credentials.Credentials = None,
-        channel: grpc.Channel = None
+        credentials_file: str = None,
+        scopes: Sequence[str] = None,
+        channel: grpc.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
     ) -> None:
         """Instantiate the transport.

@@ -57,28 +70,123 @@ def __init__(
             are specified, the client will attempt to ascertain the
             credentials from the environment.
             This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
             channel (Optional[grpc.Channel]): A ``Channel`` instance through
                 which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
         """
-        # Sanity check: Ensure that channel and credentials are not both
-        # provided.
+        self._ssl_channel_credentials = ssl_channel_credentials
+
         if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
             credentials = False

-        # Run the base constructor.
-        super().__init__(host=host, credentials=credentials)
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+        elif api_mtls_endpoint:
+            warnings.warn(
+                "api_mtls_endpoint and client_cert_source are deprecated",
+                DeprecationWarning,
+            )
+
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+            self._ssl_channel_credentials = ssl_credentials
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+            )
+
         self._stubs = {}  # type: Dict[str, Callable]

-        # If a channel was explicitly provided, set it.
-        if channel:
-            self._grpc_channel = channel
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )

     @classmethod
     def create_channel(
         cls,
         host: str = "aiplatform.googleapis.com",
         credentials: credentials.Credentials = None,
-        **kwargs
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
     ) -> grpc.Channel:
         """Create and return a gRPC channel object.
         Args:
@@ -88,30 +196,37 @@ def create_channel(
             credentials identify this application to the service. If
             none are specified, the client will attempt to ascertain
             the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
             kwargs (Optional[dict]): Keyword arguments, which are passed to the
                 channel creation.
         Returns:
             grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
""" + scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property @@ -152,11 +267,13 @@ def explain( Perform an online explanation. - If [ExplainRequest.deployed_model_id] is specified, the - corresponding DeployModel must have + If + ``deployed_model_id`` + is specified, the corresponding DeployModel must have ``explanation_spec`` - populated. If [ExplainRequest.deployed_model_id] is not - specified, all DeployedModels must have + populated. If + ``deployed_model_id`` + is not specified, all DeployedModels must have ``explanation_spec`` populated. Only deployed AutoML tabular Models have explanation_spec. diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..a0785007db --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py @@ -0,0 +1,306 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import prediction_service + +from .base import PredictionServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import PredictionServiceGrpcTransport + + +class PredictionServiceGrpcAsyncIOTransport(PredictionServiceTransport): + """gRPC AsyncIO backend transport for PredictionService. + + A service for online predictions and explanations. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. 
+ """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. 
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + warnings.warn( + "api_mtls_endpoint and client_cert_source are deprecated", + DeprecationWarning, + ) + + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def predict( + self, + ) -> Callable[ + [prediction_service.PredictRequest], + Awaitable[prediction_service.PredictResponse], + ]: + r"""Return a callable for the predict method over gRPC. + + Perform an online prediction. + + Returns: + Callable[[~.PredictRequest], + Awaitable[~.PredictResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "predict" not in self._stubs: + self._stubs["predict"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PredictionService/Predict", + request_serializer=prediction_service.PredictRequest.serialize, + response_deserializer=prediction_service.PredictResponse.deserialize, + ) + return self._stubs["predict"] + + @property + def explain( + self, + ) -> Callable[ + [prediction_service.ExplainRequest], + Awaitable[prediction_service.ExplainResponse], + ]: + r"""Return a callable for the explain method over gRPC. + + Perform an online explanation. + + If + ``deployed_model_id`` + is specified, the corresponding DeployModel must have + ``explanation_spec`` + populated. If + ``deployed_model_id`` + is not specified, all DeployedModels must have + ``explanation_spec`` + populated. Only deployed AutoML tabular Models have + explanation_spec. + + Returns: + Callable[[~.ExplainRequest], + Awaitable[~.ExplainResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "explain" not in self._stubs: + self._stubs["explain"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.PredictionService/Explain", + request_serializer=prediction_service.ExplainRequest.serialize, + response_deserializer=prediction_service.ExplainResponse.deserialize, + ) + return self._stubs["explain"] + + +__all__ = ("PredictionServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py index 8f429cd5eb..49e9cdf0a0 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/__init__.py @@ -16,5 +16,9 @@ # from .client import SpecialistPoolServiceClient +from .async_client import SpecialistPoolServiceAsyncClient -__all__ = ("SpecialistPoolServiceClient",) +__all__ = ( + "SpecialistPoolServiceClient", + "SpecialistPoolServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py new file mode 100644 index 0000000000..77f40bd4ad --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py @@ -0,0 +1,637 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import specialist_pool +from google.cloud.aiplatform_v1beta1.types import specialist_pool as gca_specialist_pool +from google.cloud.aiplatform_v1beta1.types import specialist_pool_service +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + +from .transports.base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport +from .client import SpecialistPoolServiceClient + + +class SpecialistPoolServiceAsyncClient: + """A service for creating and managing Customer SpecialistPools. + When customers start Data Labeling jobs, they can reuse/create + Specialist Pools to bring their own Specialists to label the + data. Customers can add/remove Managers for the Specialist Pool + on Cloud console, then Managers will get email notifications to + manage Specialists and tasks on CrowdCompute console. + """ + + _client: SpecialistPoolServiceClient + + DEFAULT_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_MTLS_ENDPOINT + + specialist_pool_path = staticmethod( + SpecialistPoolServiceClient.specialist_pool_path + ) + parse_specialist_pool_path = staticmethod( + SpecialistPoolServiceClient.parse_specialist_pool_path + ) + + common_billing_account_path = staticmethod( + SpecialistPoolServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + SpecialistPoolServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(SpecialistPoolServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + SpecialistPoolServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + SpecialistPoolServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + SpecialistPoolServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(SpecialistPoolServiceClient.common_project_path) + parse_common_project_path = staticmethod( + SpecialistPoolServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod( + SpecialistPoolServiceClient.common_location_path + ) + parse_common_location_path = staticmethod( + SpecialistPoolServiceClient.parse_common_location_path + ) + + from_service_account_file = SpecialistPoolServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> SpecialistPoolServiceTransport: + """Return the transport used by the client instance. 
+ + Returns: + SpecialistPoolServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(SpecialistPoolServiceClient).get_transport_class, + type(SpecialistPoolServiceClient), + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, SpecialistPoolServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the specialist pool service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.SpecialistPoolServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = SpecialistPoolServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def create_specialist_pool( + self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a SpecialistPool. + + Args: + request (:class:`~.specialist_pool_service.CreateSpecialistPoolRequest`): + The request object. Request message for + ``SpecialistPoolService.CreateSpecialistPool``. + parent (:class:`str`): + Required. The parent Project name for the new + SpecialistPool. The form is + ``projects/{project}/locations/{location}``. + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + specialist_pool (:class:`~.gca_specialist_pool.SpecialistPool`): + Required. The SpecialistPool to + create. + This corresponds to the ``specialist_pool`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.gca_specialist_pool.SpecialistPool`: + SpecialistPool represents customers' own workforce to + work on their data labeling jobs. It includes a group of + specialist managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the jobs + using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, specialist_pool]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = specialist_pool_service.CreateSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if specialist_pool is not None: + request.specialist_pool = specialist_pool + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_specialist_pool.SpecialistPool, + metadata_type=specialist_pool_service.CreateSpecialistPoolOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_specialist_pool( + self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: + r"""Gets a SpecialistPool. + + Args: + request (:class:`~.specialist_pool_service.GetSpecialistPoolRequest`): + The request object. Request message for + ``SpecialistPoolService.GetSpecialistPool``. + name (:class:`str`): + Required. The name of the SpecialistPool resource. The + form is + + ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}``. + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.specialist_pool.SpecialistPool: + SpecialistPool represents customers' + own workforce to work on their data + labeling jobs. 
It includes a group of + specialist managers who are responsible + for managing the labelers in this pool + as well as customers' data labeling jobs + associated with this pool. + Customers create specialist pool as well + as start data labeling jobs on Cloud, + managers and labelers work with the jobs + using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = specialist_pool_service.GetSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_specialist_pools( + self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsAsyncPager: + r"""Lists SpecialistPools in a Location. + + Args: + request (:class:`~.specialist_pool_service.ListSpecialistPoolsRequest`): + The request object. Request message for + ``SpecialistPoolService.ListSpecialistPools``. + parent (:class:`str`): + Required. The name of the SpecialistPool's parent + resource. Format: + ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.pagers.ListSpecialistPoolsAsyncPager: + Response message for + ``SpecialistPoolService.ListSpecialistPools``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = specialist_pool_service.ListSpecialistPoolsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
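+        # ``gapic_v1.method_async.wrap_method`` mirrors the sync client's
+        # ``wrap_method``: it layers retry and timeout handling on the RPC,
+        # but returns an awaitable callable for the grpc_asyncio transport.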
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_specialist_pools, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListSpecialistPoolsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_specialist_pool( + self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a SpecialistPool as well as all Specialists + in the pool. + + Args: + request (:class:`~.specialist_pool_service.DeleteSpecialistPoolRequest`): + The request object. Request message for + ``SpecialistPoolService.DeleteSpecialistPool``. + name (:class:`str`): + Required. The resource name of the SpecialistPool to + delete. Format: + ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}`` + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.empty.Empty`: A generic empty message that + you can re-use to avoid defining duplicated empty + messages in your APIs. A typical example is to use it as + the request or the response type of an API method. For + instance: + + :: + + service Foo { + rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); + } + + The JSON representation for ``Empty`` is empty JSON + object ``{}``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = specialist_pool_service.DeleteSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def update_specialist_pool( + self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates a SpecialistPool. + + Args: + request (:class:`~.specialist_pool_service.UpdateSpecialistPoolRequest`): + The request object. Request message for + ``SpecialistPoolService.UpdateSpecialistPool``. + specialist_pool (:class:`~.gca_specialist_pool.SpecialistPool`): + Required. The SpecialistPool which + replaces the resource on the server. + This corresponds to the ``specialist_pool`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`~.field_mask.FieldMask`): + Required. The update mask applies to + the resource. + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`~.gca_specialist_pool.SpecialistPool`: + SpecialistPool represents customers' own workforce to + work on their data labeling jobs. It includes a group of + specialist managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the jobs + using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([specialist_pool, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = specialist_pool_service.UpdateSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if specialist_pool is not None: + request.specialist_pool = specialist_pool + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. 
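+        # Note the routing key here is the nested resource name,
+        # ``specialist_pool.name``, rather than a top-level request field.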
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_specialist_pool.SpecialistPool, + metadata_type=specialist_pool_service.UpdateSpecialistPoolOperationMetadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("SpecialistPoolServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index ddc9c26ab9..efc19eca12 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -16,17 +16,24 @@ # from collections import OrderedDict -from typing import Dict, Sequence, Tuple, Type, Union +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union import pkg_resources -import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import client_options as client_options_lib # type: ignore from google.api_core import exceptions # type: ignore from google.api_core import gapic_v1 # type: ignore from google.api_core import retry as retries # type: ignore from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore -from google.api_core import operation as ga_operation +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import specialist_pool @@ -35,8 +42,9 @@ from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from .transports.base import SpecialistPoolServiceTransport +from .transports.base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO from .transports.grpc import SpecialistPoolServiceGrpcTransport +from .transports.grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport class SpecialistPoolServiceClientMeta(type): @@ -51,6 +59,7 @@ class SpecialistPoolServiceClientMeta(type): OrderedDict() ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport + _transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport def get_transport_class( cls, label: str = None, @@ -82,8 +91,38 @@ class SpecialistPoolServiceClient(metaclass=SpecialistPoolServiceClientMeta): manage Specialists and tasks on CrowdCompute console. 
""" - DEFAULT_OPTIONS = ClientOptions.ClientOptions( - api_endpoint="aiplatform.googleapis.com" + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT ) @classmethod @@ -106,6 +145,15 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file + @property + def transport(self) -> SpecialistPoolServiceTransport: + """Return the transport used by the client instance. + + Returns: + SpecialistPoolServiceTransport: The transport used by the client instance. + """ + return self._transport + @staticmethod def specialist_pool_path(project: str, location: str, specialist_pool: str,) -> str: """Return a fully-qualified specialist_pool string.""" @@ -113,12 +161,81 @@ def specialist_pool_path(project: str, location: str, specialist_pool: str,) -> project=project, location=location, specialist_pool=specialist_pool, ) + @staticmethod + def parse_specialist_pool_path(path: str) -> Dict[str, str]: + """Parse a specialist_pool path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return 
"projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + def __init__( self, *, - credentials: credentials.Credentials = None, - transport: Union[str, SpecialistPoolServiceTransport] = None, - client_options: ClientOptions.ClientOptions = DEFAULT_OPTIONS, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, SpecialistPoolServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: """Instantiate the specialist pool service client. @@ -131,26 +248,102 @@ def __init__( transport (Union[str, ~.SpecialistPoolServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (ClientOptions): Custom options for the client. + client_options (client_options_lib.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. """ if isinstance(client_options, dict): - client_options = ClientOptions.from_dict(client_options) + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. 
+ use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + ssl_credentials = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + import grpc # type: ignore + + cert, key = client_options.client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + is_mtls = True + else: + creds = SslCredentials() + is_mtls = creds.is_mtls + ssl_credentials = creds.ssl_credentials if is_mtls else None + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport # instance provides an extensibility point for unusual situations. if isinstance(transport, SpecialistPoolServiceTransport): - if credentials: + # transport is a SpecialistPoolServiceTransport instance. + if credentials or client_options.credentials_file: raise ValueError( "When providing a transport instance, " "provide its credentials directly." ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) self._transport = transport else: Transport = type(self).get_transport_class(transport) self._transport = Transport( credentials=credentials, - host=client_options.api_endpoint or "aiplatform.googleapis.com", + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + ssl_channel_credentials=ssl_credentials, + quota_project_id=client_options.quota_project_id, + client_info=client_info, ) def create_specialist_pool( @@ -208,28 +401,36 @@ def create_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent, specialist_pool]): + has_flattened_params = any([parent, specialist_pool]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = specialist_pool_service.CreateSpecialistPoolRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.CreateSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.CreateSpecialistPoolRequest): + request = specialist_pool_service.CreateSpecialistPoolRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if parent is not None: - request.parent = parent - if specialist_pool is not None: - request.specialist_pool = specialist_pool + if parent is not None: + request.parent = parent + if specialist_pool is not None: + request.specialist_pool = specialist_pool # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.create_specialist_pool, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.create_specialist_pool] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -294,27 +495,29 @@ def get_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = specialist_pool_service.GetSpecialistPoolRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.GetSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.GetSpecialistPoolRequest): + request = specialist_pool_service.GetSpecialistPoolRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.get_specialist_pool, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.get_specialist_pool] # Certain fields should be provided within the metadata header; # add these here. @@ -369,27 +572,29 @@ def list_specialist_pools( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([parent]): + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = specialist_pool_service.ListSpecialistPoolsRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.ListSpecialistPoolsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.ListSpecialistPoolsRequest): + request = specialist_pool_service.ListSpecialistPoolsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
- if parent is not None: - request.parent = parent + if parent is not None: + request.parent = parent # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.list_specialist_pools, - default_timeout=None, - client_info=_client_info, - ) + rpc = self._transport._wrapped_methods[self._transport.list_specialist_pools] # Certain fields should be provided within the metadata header; # add these here. @@ -403,7 +608,7 @@ def list_specialist_pools( # This method is paged; wrap the response in a pager, which provides # an `__iter__` convenience method. response = pagers.ListSpecialistPoolsPager( - method=rpc, request=request, response=response, + method=rpc, request=request, response=response, metadata=metadata, ) # Done; return the response. @@ -463,26 +668,34 @@ def delete_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([name]): + has_flattened_params = any([name]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = specialist_pool_service.DeleteSpecialistPoolRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.DeleteSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.DeleteSpecialistPoolRequest): + request = specialist_pool_service.DeleteSpecialistPoolRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if name is not None: - request.name = name + if name is not None: + request.name = name # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.delete_specialist_pool, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.delete_specialist_pool] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -553,28 +766,38 @@ def update_specialist_pool( # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - if request is not None and any([specialist_pool, update_mask]): + has_flattened_params = any([specialist_pool, update_mask]) + if request is not None and has_flattened_params: raise ValueError( "If the `request` argument is set, then none of " "the individual field arguments should be set." ) - request = specialist_pool_service.UpdateSpecialistPoolRequest(request) + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.UpdateSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
+ if not isinstance(request, specialist_pool_service.UpdateSpecialistPoolRequest): + request = specialist_pool_service.UpdateSpecialistPoolRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. + # If we have keyword arguments corresponding to fields on the + # request, apply these. - if specialist_pool is not None: - request.specialist_pool = specialist_pool - if update_mask is not None: - request.update_mask = update_mask + if specialist_pool is not None: + request.specialist_pool = specialist_pool + if update_mask is not None: + request.update_mask = update_mask # Wrap the RPC method; this adds retry and timeout information, # and friendly error handling. - rpc = gapic_v1.method.wrap_method( - self._transport.update_specialist_pool, - default_timeout=None, - client_info=_client_info, + rpc = self._transport._wrapped_methods[self._transport.update_specialist_pool] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), ) # Send the request. @@ -593,13 +816,13 @@ def update_specialist_pool( try: - _client_info = gapic_v1.client_info.ClientInfo( + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=pkg_resources.get_distribution( "google-cloud-aiplatform", ).version, ) except pkg_resources.DistributionNotFound: - _client_info = gapic_v1.client_info.ClientInfo() + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() __all__ = ("SpecialistPoolServiceClient",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py index 012b76479b..ff2d84ac74 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py @@ -15,7 +15,7 @@ # limitations under the License. # -from typing import Any, Callable, Iterable +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple from google.cloud.aiplatform_v1beta1.types import specialist_pool from google.cloud.aiplatform_v1beta1.types import specialist_pool_service @@ -41,12 +41,11 @@ class ListSpecialistPoolsPager: def __init__( self, - method: Callable[ - [specialist_pool_service.ListSpecialistPoolsRequest], - specialist_pool_service.ListSpecialistPoolsResponse, - ], + method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], request: specialist_pool_service.ListSpecialistPoolsRequest, response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () ): """Instantiate the pager. @@ -57,10 +56,13 @@ def __init__( The initial request object. response (:class:`~.specialist_pool_service.ListSpecialistPoolsResponse`): The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
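+
+ Example (an illustrative sketch; resource names are placeholders)::
+
+ pager = client.list_specialist_pools(parent="projects/my-project/locations/us-central1")
+ for specialist_pool in pager:
+ print(specialist_pool.name)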
""" self._method = method self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) self._response = response + self._metadata = metadata def __getattr__(self, name: str) -> Any: return getattr(self._response, name) @@ -70,7 +72,7 @@ def pages(self) -> Iterable[specialist_pool_service.ListSpecialistPoolsResponse] yield self._response while self._response.next_page_token: self._request.page_token = self._response.next_page_token - self._response = self._method(self._request) + self._response = self._method(self._request, metadata=self._metadata) yield self._response def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: @@ -79,3 +81,73 @@ def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: def __repr__(self) -> str: return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListSpecialistPoolsAsyncPager: + """A pager for iterating through ``list_specialist_pools`` requests. + + This class thinly wraps an initial + :class:`~.specialist_pool_service.ListSpecialistPoolsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``specialist_pools`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListSpecialistPools`` requests and continue to iterate + through the ``specialist_pools`` field on the + corresponding responses. + + All the usual :class:`~.specialist_pool_service.ListSpecialistPoolsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] + ], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (:class:`~.specialist_pool_service.ListSpecialistPoolsRequest`): + The initial request object. + response (:class:`~.specialist_pool_service.ListSpecialistPoolsResponse`): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[specialist_pool.SpecialistPool]: + async def async_generator(): + async for page in self.pages: + for response in page.specialist_pools: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py index c77d2d31a3..711f7fd1cc 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py @@ -20,6 +20,7 @@ from .base import SpecialistPoolServiceTransport from .grpc import SpecialistPoolServiceGrpcTransport +from .grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport # Compile a registry of transports. @@ -27,9 +28,11 @@ OrderedDict() ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport +_transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport __all__ = ( "SpecialistPoolServiceTransport", "SpecialistPoolServiceGrpcTransport", + "SpecialistPoolServiceGrpcAsyncIOTransport", ) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py index effe36767e..30fbd3030f 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py @@ -17,8 +17,12 @@ import abc import typing +import pkg_resources -from google import auth +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore from google.api_core import operations_v1 # type: ignore from google.auth import credentials # type: ignore @@ -27,7 +31,17 @@ from google.longrunning import operations_pb2 as operations # type: ignore -class SpecialistPoolServiceTransport(metaclass=abc.ABCMeta): +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +class SpecialistPoolServiceTransport(abc.ABC): """Abstract transport class for SpecialistPoolService.""" AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) @@ -37,6 +51,11 @@ def __init__( *, host: str = "aiplatform.googleapis.com", credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, 
+ quota_project_id: typing.Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ **kwargs,
 ) -> None:
 """Instantiate the transport.
@@ -47,6 +66,17 @@ def __init__(
 credentials identify the application to the service; if none
 are specified, the client will attempt to ascertain the
 credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is mutually exclusive with credentials.
+ scopes (Optional[Sequence[str]]): A list of scopes.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
 """
 # Save the hostname. Default to port 443 (HTTPS) if none is specified.
 if ":" not in host:
@@ -55,58 +85,110 @@ def __init__(

 # If no credentials are provided, then determine the appropriate
 # defaults.
- if credentials is None:
- credentials, _ = auth.default(scopes=self.AUTH_SCOPES)
+ if credentials and credentials_file:
+ raise exceptions.DuplicateCredentialArgs(
+ "'credentials_file' and 'credentials' are mutually exclusive"
+ )
+
+ if credentials_file is not None:
+ credentials, _ = auth.load_credentials_from_file(
+ credentials_file, scopes=scopes, quota_project_id=quota_project_id
+ )
+
+ elif credentials is None:
+ credentials, _ = auth.default(
+ scopes=scopes, quota_project_id=quota_project_id
+ )

 # Save the credentials.
 self._credentials = credentials

+ # Lifted into its own function so it can be stubbed out during tests.
+ self._prep_wrapped_messages(client_info)
+
+ def _prep_wrapped_messages(self, client_info):
+ # Precompute the wrapped methods.
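+ # Each wrapper produced by `gapic_v1.method.wrap_method` layers the
+ # default retry/timeout settings and the `client_info` user-agent onto
+ # the raw transport method, so callers can use
+ # `self._wrapped_methods[...]` without re-wrapping on every request.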
+ self._wrapped_methods = { + self.create_specialist_pool: gapic_v1.method.wrap_method( + self.create_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + self.get_specialist_pool: gapic_v1.method.wrap_method( + self.get_specialist_pool, default_timeout=None, client_info=client_info, + ), + self.list_specialist_pools: gapic_v1.method.wrap_method( + self.list_specialist_pools, + default_timeout=None, + client_info=client_info, + ), + self.delete_specialist_pool: gapic_v1.method.wrap_method( + self.delete_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + self.update_specialist_pool: gapic_v1.method.wrap_method( + self.update_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + } + @property def operations_client(self) -> operations_v1.OperationsClient: """Return the client designed to process long-running operations.""" - raise NotImplementedError + raise NotImplementedError() @property def create_specialist_pool( self, ) -> typing.Callable[ - [specialist_pool_service.CreateSpecialistPoolRequest], operations.Operation + [specialist_pool_service.CreateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], ]: - raise NotImplementedError + raise NotImplementedError() @property def get_specialist_pool( self, ) -> typing.Callable[ [specialist_pool_service.GetSpecialistPoolRequest], - specialist_pool.SpecialistPool, + typing.Union[ + specialist_pool.SpecialistPool, + typing.Awaitable[specialist_pool.SpecialistPool], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def list_specialist_pools( self, ) -> typing.Callable[ [specialist_pool_service.ListSpecialistPoolsRequest], - specialist_pool_service.ListSpecialistPoolsResponse, + typing.Union[ + specialist_pool_service.ListSpecialistPoolsResponse, + typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], + ], ]: - raise NotImplementedError + raise NotImplementedError() @property def delete_specialist_pool( self, ) -> typing.Callable[ - [specialist_pool_service.DeleteSpecialistPoolRequest], operations.Operation + [specialist_pool_service.DeleteSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], ]: - raise NotImplementedError + raise NotImplementedError() @property def update_specialist_pool( self, ) -> typing.Callable[ - [specialist_pool_service.UpdateSpecialistPoolRequest], operations.Operation + [specialist_pool_service.UpdateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], ]: - raise NotImplementedError + raise NotImplementedError() __all__ = ("SpecialistPoolServiceTransport",) diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index 92cff5699c..2d1442ae33 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -15,11 +15,15 @@ # limitations under the License. 
 #

-from typing import Callable, Dict
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple

 from google.api_core import grpc_helpers # type: ignore
 from google.api_core import operations_v1 # type: ignore
+from google.api_core import gapic_v1 # type: ignore
+from google import auth # type: ignore
 from google.auth import credentials # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore

 import grpc # type: ignore

@@ -27,7 +31,7 @@
 from google.cloud.aiplatform_v1beta1.types import specialist_pool_service
 from google.longrunning import operations_pb2 as operations # type: ignore

-from .base import SpecialistPoolServiceTransport
+from .base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO


 class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport):
@@ -48,12 +52,21 @@ class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport):
 top of HTTP/2); the ``grpcio`` package must be installed.
 """

+ _stubs: Dict[str, Callable]
+
 def __init__(
 self,
 *,
 host: str = "aiplatform.googleapis.com",
 credentials: credentials.Credentials = None,
- channel: grpc.Channel = None
+ credentials_file: str = None,
+ scopes: Sequence[str] = None,
+ channel: grpc.Channel = None,
+ api_mtls_endpoint: str = None,
+ client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+ ssl_channel_credentials: grpc.ChannelCredentials = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
 ) -> None:
 """Instantiate the transport.
@@ -65,28 +78,123 @@ def __init__(
 are specified, the client will attempt to ascertain the
 credentials from the environment.
 This argument is ignored if ``channel`` is provided.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+ ignored if ``channel`` is provided.
 channel (Optional[grpc.Channel]): A ``Channel`` instance through
 which to make calls.
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+ If provided, it overrides the ``host`` argument and tries to create
+ a mutual TLS channel with client SSL credentials from
+ ``client_cert_source`` or application default SSL credentials.
+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ Deprecated. A callback to provide client SSL certificate bytes and
+ private key bytes, both in PEM format. It is ignored if
+ ``api_mtls_endpoint`` is None.
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+ for grpc channel. It is ignored if ``channel`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
 """
- # Sanity check: Ensure that channel and credentials are not both
- # provided.
+ self._ssl_channel_credentials = ssl_channel_credentials
+ if channel:
+ # Sanity check: Ensure that channel and credentials are not both
+ # provided.
 credentials = False

- # Run the base constructor.
- super().__init__(host=host, credentials=credentials)
+ # If a channel was explicitly provided, set it.
+ self._grpc_channel = channel
+ self._ssl_channel_credentials = None
+ elif api_mtls_endpoint:
+ warnings.warn(
+ "api_mtls_endpoint and client_cert_source are deprecated",
+ DeprecationWarning,
+ )
+
+ host = (
+ api_mtls_endpoint
+ if ":" in api_mtls_endpoint
+ else api_mtls_endpoint + ":443"
+ )
+
+ if credentials is None:
+ credentials, _ = auth.default(
+ scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+ )
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ ssl_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ ssl_credentials = SslCredentials().ssl_credentials
+
+ # create a new channel. The provided one is ignored.
+ self._grpc_channel = type(self).create_channel(
+ host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ ssl_credentials=ssl_credentials,
+ scopes=scopes or self.AUTH_SCOPES,
+ quota_project_id=quota_project_id,
+ )
+ self._ssl_channel_credentials = ssl_credentials
+ else:
+ host = host if ":" in host else host + ":443"
+
+ if credentials is None:
+ credentials, _ = auth.default(
+ scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+ )
+
+ # create a new channel. The provided one is ignored.
+ self._grpc_channel = type(self).create_channel(
+ host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ ssl_credentials=ssl_channel_credentials,
+ scopes=scopes or self.AUTH_SCOPES,
+ quota_project_id=quota_project_id,
+ )
+
+ self._stubs = {} # type: Dict[str, Callable]

- # If a channel was explicitly provided, set it.
- if channel:
- self._grpc_channel = channel
+ # Run the base constructor.
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes or self.AUTH_SCOPES,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ )

 @classmethod
 def create_channel(
 cls,
 host: str = "aiplatform.googleapis.com",
 credentials: credentials.Credentials = None,
- **kwargs
+ credentials_file: str = None,
+ scopes: Optional[Sequence[str]] = None,
+ quota_project_id: Optional[str] = None,
+ **kwargs,
 ) -> grpc.Channel:
 """Create and return a gRPC channel object.
 Args:
@@ -96,30 +204,37 @@ def create_channel(
 credentials identify this application to the service. If
 none are specified, the client will attempt to ascertain
 the credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is mutually exclusive with credentials.
+ scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
 kwargs (Optional[dict]): Keyword arguments, which are passed to the
 channel creation.
 Returns:
 grpc.Channel: A gRPC channel object.
+
+ Raises:
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
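+
+ For example, a minimal channel creation relying on application
+ default credentials might look like (an illustrative sketch)::
+
+ channel = SpecialistPoolServiceGrpcTransport.create_channel(
+ "aiplatform.googleapis.com",
+ )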
""" + scopes = scopes or cls.AUTH_SCOPES return grpc_helpers.create_channel( - host, credentials=credentials, scopes=cls.AUTH_SCOPES, **kwargs + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. + """Return the channel designed to connect to this service. """ - # Sanity check: Only create a new channel if we do not already - # have one. - if not hasattr(self, "_grpc_channel"): - self._grpc_channel = self.create_channel( - self._host, credentials=self._credentials, - ) - - # Return the channel from cache. return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..7d038edc4f --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -0,0 +1,407 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1beta1.types import specialist_pool +from google.cloud.aiplatform_v1beta1.types import specialist_pool_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import SpecialistPoolServiceGrpcTransport + + +class SpecialistPoolServiceGrpcAsyncIOTransport(SpecialistPoolServiceTransport): + """gRPC AsyncIO backend transport for SpecialistPoolService. + + A service for creating and managing Customer SpecialistPools. + When customers start Data Labeling jobs, they can reuse/create + Specialist Pools to bring their own Specialists to label the + data. Customers can add/remove Managers for the Specialist Pool + on Cloud console, then Managers will get email notifications to + manage Specialists and tasks on CrowdCompute console. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. 
+ """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. 
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+ for grpc channel. It is ignored if ``channel`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
+ """
+ self._ssl_channel_credentials = ssl_channel_credentials
+
+ if channel:
+ # Sanity check: Ensure that channel and credentials are not both
+ # provided.
+ credentials = False
+
+ # If a channel was explicitly provided, set it.
+ self._grpc_channel = channel
+ self._ssl_channel_credentials = None
+ elif api_mtls_endpoint:
+ warnings.warn(
+ "api_mtls_endpoint and client_cert_source are deprecated",
+ DeprecationWarning,
+ )
+
+ host = (
+ api_mtls_endpoint
+ if ":" in api_mtls_endpoint
+ else api_mtls_endpoint + ":443"
+ )
+
+ if credentials is None:
+ credentials, _ = auth.default(
+ scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+ )
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ ssl_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ ssl_credentials = SslCredentials().ssl_credentials
+
+ # create a new channel. The provided one is ignored.
+ self._grpc_channel = type(self).create_channel(
+ host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ ssl_credentials=ssl_credentials,
+ scopes=scopes or self.AUTH_SCOPES,
+ quota_project_id=quota_project_id,
+ )
+ self._ssl_channel_credentials = ssl_credentials
+ else:
+ host = host if ":" in host else host + ":443"
+
+ if credentials is None:
+ credentials, _ = auth.default(
+ scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+ )
+
+ # create a new channel. The provided one is ignored.
+ self._grpc_channel = type(self).create_channel(
+ host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ ssl_credentials=ssl_channel_credentials,
+ scopes=scopes or self.AUTH_SCOPES,
+ quota_project_id=quota_project_id,
+ )
+
+ # Run the base constructor.
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes or self.AUTH_SCOPES,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ )
+
+ self._stubs = {}
+
+ @property
+ def grpc_channel(self) -> aio.Channel:
+ """Create the channel designed to connect to this service.
+
+ This property caches on the instance; repeated calls return
+ the same channel.
+ """
+ # Return the channel from cache.
+ return self._grpc_channel
+
+ @property
+ def operations_client(self) -> operations_v1.OperationsAsyncClient:
+ """Create the client designed to process long-running operations.
+
+ This property caches on the instance; repeated calls return the same
+ client.
+ """
+ # Sanity check: Only create a new client if we do not already have one.
+ if "operations_client" not in self.__dict__:
+ self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient(
+ self.grpc_channel
+ )
+
+ # Return the client from cache.
+ return self.__dict__["operations_client"] + + @property + def create_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the create specialist pool method over gRPC. + + Creates a SpecialistPool. + + Returns: + Callable[[~.CreateSpecialistPoolRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_specialist_pool" not in self._stubs: + self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/CreateSpecialistPool", + request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["create_specialist_pool"] + + @property + def get_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + Awaitable[specialist_pool.SpecialistPool], + ]: + r"""Return a callable for the get specialist pool method over gRPC. + + Gets a SpecialistPool. + + Returns: + Callable[[~.GetSpecialistPoolRequest], + Awaitable[~.SpecialistPool]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_specialist_pool" not in self._stubs: + self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/GetSpecialistPool", + request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, + response_deserializer=specialist_pool.SpecialistPool.deserialize, + ) + return self._stubs["get_specialist_pool"] + + @property + def list_specialist_pools( + self, + ) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], + ]: + r"""Return a callable for the list specialist pools method over gRPC. + + Lists SpecialistPools in a Location. + + Returns: + Callable[[~.ListSpecialistPoolsRequest], + Awaitable[~.ListSpecialistPoolsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_specialist_pools" not in self._stubs: + self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/ListSpecialistPools", + request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, + response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, + ) + return self._stubs["list_specialist_pools"] + + @property + def delete_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the delete specialist pool method over gRPC. + + Deletes a SpecialistPool as well as all Specialists + in the pool. 
+ + Returns: + Callable[[~.DeleteSpecialistPoolRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_specialist_pool" not in self._stubs: + self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/DeleteSpecialistPool", + request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_specialist_pool"] + + @property + def update_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the update specialist pool method over gRPC. + + Updates a SpecialistPool. + + Returns: + Callable[[~.UpdateSpecialistPoolRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_specialist_pool" not in self._stubs: + self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1beta1.SpecialistPoolService/UpdateSpecialistPool", + request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["update_specialist_pool"] + + +__all__ = ("SpecialistPoolServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index 93508415dc..97e5625d20 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -37,6 +37,7 @@ AutomaticResources, BatchDedicatedResources, ResourcesConsumed, + DiskSpec, ) from .deployed_model_ref import DeployedModelRef from .env_var import EnvVar @@ -48,6 +49,10 @@ ExplanationSpec, ExplanationParameters, SampledShapleyAttribution, + IntegratedGradientsAttribution, + XraiAttribution, + SmoothGradConfig, + FeatureNoiseSigma, ) from .model import ( Model, @@ -64,6 +69,20 @@ TimestampSplit, ) from .model_evaluation import ModelEvaluation +from .migratable_resource import MigratableResource +from .operation import ( + GenericOperationMetadata, + DeleteOperationMetadata, +) +from .migration_service import ( + SearchMigratableResourcesRequest, + SearchMigratableResourcesResponse, + BatchMigrateResourcesRequest, + MigrateResourceRequest, + BatchMigrateResourcesResponse, + MigrateResourceResponse, + BatchMigrateResourcesOperationMetadata, +) from .batch_prediction_job import BatchPredictionJob from .custom_job import ( CustomJob, @@ -114,10 +133,6 @@ ) from .user_action_reference import UserActionReference from .annotation import Annotation -from .operation import ( - GenericOperationMetadata, - DeleteOperationMetadata, -) from .endpoint import ( Endpoint, DeployedModel, @@ -221,6 +236,7 @@ "AutomaticResources", "BatchDedicatedResources", "ResourcesConsumed", + "DiskSpec", "DeployedModelRef", "EnvVar", "ExplanationMetadata", @@ -230,6 +246,10 @@ "ExplanationSpec", "ExplanationParameters", 
"SampledShapleyAttribution", + "IntegratedGradientsAttribution", + "XraiAttribution", + "SmoothGradConfig", + "FeatureNoiseSigma", "Model", "PredictSchemata", "ModelContainerSpec", @@ -241,6 +261,16 @@ "PredefinedSplit", "TimestampSplit", "ModelEvaluation", + "MigratableResource", + "GenericOperationMetadata", + "DeleteOperationMetadata", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "BatchMigrateResourcesRequest", + "MigrateResourceRequest", + "BatchMigrateResourcesResponse", + "MigrateResourceResponse", + "BatchMigrateResourcesOperationMetadata", "BatchPredictionJob", "CustomJob", "CustomJobSpec", @@ -283,8 +313,6 @@ "CancelBatchPredictionJobRequest", "UserActionReference", "Annotation", - "GenericOperationMetadata", - "DeleteOperationMetadata", "Endpoint", "DeployedModel", "PredictRequest", diff --git a/google/cloud/aiplatform_v1beta1/types/annotation.py b/google/cloud/aiplatform_v1beta1/types/annotation.py index 34f3edfa5e..7734fcc512 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation.py @@ -88,14 +88,21 @@ class Annotation(proto.Message): """ name = proto.Field(proto.STRING, number=1) + payload_schema_uri = proto.Field(proto.STRING, number=2) + payload = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + etag = proto.Field(proto.STRING, number=8) + annotation_source = proto.Field( proto.MESSAGE, number=5, message=user_action_reference.UserActionReference, ) + labels = proto.MapField(proto.STRING, proto.STRING, number=6) diff --git a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py index 4719fb12ce..a5a4b3d489 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py @@ -52,9 +52,13 @@ class AnnotationSpec(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + etag = proto.Field(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index 332faaa6a9..625bf83155 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -21,6 +21,7 @@ from google.cloud.aiplatform_v1beta1.types import ( completion_stats as gca_completion_stats, ) +from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources @@ -98,30 +99,41 @@ class BatchPredictionJob(proto.Message): Generate explanation along with the batch prediction results. - This can only be set to true for AutoML tabular Models, and - only when the output destination is BigQuery. When it's - true, the batch prediction output will include a column - named ``feature_attributions``. - - For AutoML tabular Models, the value of the - ``feature_attributions`` column is a struct that maps from - string to number. 
The keys in the map are the names of the
- features. The values in the map are the how much the
- features contribute to the predicted result. Features are
- defined as follows:
-
- - A scalar column defines a feature of the same name as the
- column.
-
- - A struct column defines multiple features, one feature
- per leaf field. The feature name is the fully qualified
- path for the leaf field, separated by ".". For example a
- column ``key1`` in the format of {"value1": {"prop1":
- number}, "value2": number} defines two features:
- ``key1.value1.prop1`` and ``key1.value2``
-
- Attributions of each feature is represented as an extra
- column in the batch prediction output BigQuery table.
+ When it's true, the batch prediction output will change
+ based on the [output
+ format][BatchPredictionJob.output_config.predictions_format]:
+
+ - ``bigquery``: output will include a column named
+ ``explanation``. The value is a struct that conforms to
+ the
+ ``Explanation``
+ object.
+ - ``jsonl``: The JSON objects on each line will include an
+ additional entry keyed ``explanation``. The value of the
+ entry is a JSON object that conforms to the
+ ``Explanation``
+ object.
+ - ``csv``: Generating explanations for CSV format is not
+ supported.
+ explanation_spec (~.explanation.ExplanationSpec):
+ Explanation configuration for this BatchPredictionJob. Can
+ only be specified if
+ ``generate_explanation``
+ is set to ``true``. It's invalid to specify it with
+ generate_explanation set to false or unset.
+
+ This value overrides the value of
+ ``Model.explanation_spec``.
+ All fields of
+ ``explanation_spec``
+ are optional in the request. If a field of
+ ``explanation_spec``
+ is not populated, the value of the same field of
+ ``Model.explanation_spec``
+ is inherited. The corresponding
+ ``Model.explanation_spec``
+ must be populated, otherwise explanation for this Model is
+ not allowed.
output_info (~.batch_prediction_job.BatchPredictionJob.OutputInfo):
Output only. Information further describing the output of this job.
@@ -198,10 +210,14 @@ class InputConfig(proto.Message):
``supported_input_storage_formats``.
"""

- gcs_source = proto.Field(proto.MESSAGE, number=2, message=io.GcsSource,)
+ gcs_source = proto.Field(
+ proto.MESSAGE, number=2, oneof="source", message=io.GcsSource,
+ )
+
+ bigquery_source = proto.Field(
- proto.MESSAGE, number=3, message=io.BigQuerySource,
+ proto.MESSAGE, number=3, oneof="source", message=io.BigQuerySource,
)
+
+ instances_format = proto.Field(proto.STRING, number=1)

class OutputConfig(proto.Message):
@@ -271,11 +287,16 @@ class OutputConfig(proto.Message):
"""

gcs_destination = proto.Field(
- proto.MESSAGE, number=2, message=io.GcsDestination,
+ proto.MESSAGE, number=2, oneof="destination", message=io.GcsDestination,
)
+
+ bigquery_destination = proto.Field(
- proto.MESSAGE, number=3, message=io.BigQueryDestination,
+ proto.MESSAGE,
+ number=3,
+ oneof="destination",
+ message=io.BigQueryDestination,
)
+
+ predictions_format = proto.Field(proto.STRING, number=1)

class OutputInfo(proto.Message):
@@ -293,40 +314,68 @@ class OutputInfo(proto.Message):
prediction output is written.
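For orientation, a sketch of how the oneof-grouped input and output configs above might be populated (message and field names as defined in this file; the bucket and model paths are hypothetical):

from google.cloud.aiplatform_v1beta1 import types

# Exactly one member of the "source" oneof and one member of the
# "destination" oneof may be set at a time.
job = types.BatchPredictionJob(
    display_name="example-batch-job",
    model="projects/12345/locations/us-central1/models/67890",
    input_config=types.BatchPredictionJob.InputConfig(
        instances_format="jsonl",
        gcs_source=types.GcsSource(uris=["gs://example-bucket/instances.jsonl"]),
    ),
    output_config=types.BatchPredictionJob.OutputConfig(
        predictions_format="jsonl",
        gcs_destination=types.GcsDestination(
            output_uri_prefix="gs://example-bucket/predictions/"
        ),
    ),
    # With jsonl output, each line gains an "explanation" entry (see above).
    generate_explanation=True,
)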
""" - gcs_output_directory = proto.Field(proto.STRING, number=1) - bigquery_output_dataset = proto.Field(proto.STRING, number=2) + gcs_output_directory = proto.Field( + proto.STRING, number=1, oneof="output_location" + ) + + bigquery_output_dataset = proto.Field( + proto.STRING, number=2, oneof="output_location" + ) name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + model = proto.Field(proto.STRING, number=3) + input_config = proto.Field(proto.MESSAGE, number=4, message=InputConfig,) + model_parameters = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) + output_config = proto.Field(proto.MESSAGE, number=6, message=OutputConfig,) + dedicated_resources = proto.Field( proto.MESSAGE, number=7, message=machine_resources.BatchDedicatedResources, ) + manual_batch_tuning_parameters = proto.Field( proto.MESSAGE, number=8, message=gca_manual_batch_tuning_parameters.ManualBatchTuningParameters, ) + generate_explanation = proto.Field(proto.BOOL, number=23) + + explanation_spec = proto.Field( + proto.MESSAGE, number=25, message=explanation.ExplanationSpec, + ) + output_info = proto.Field(proto.MESSAGE, number=9, message=OutputInfo,) + state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) + error = proto.Field(proto.MESSAGE, number=11, message=status.Status,) + partial_failures = proto.RepeatedField( proto.MESSAGE, number=12, message=status.Status, ) + resources_consumed = proto.Field( proto.MESSAGE, number=13, message=machine_resources.ResourcesConsumed, ) + completion_stats = proto.Field( proto.MESSAGE, number=14, message=gca_completion_stats.CompletionStats, ) + create_time = proto.Field(proto.MESSAGE, number=15, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=16, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=17, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=18, message=timestamp.Timestamp,) + labels = proto.MapField(proto.STRING, proto.STRING, number=19) diff --git a/google/cloud/aiplatform_v1beta1/types/completion_stats.py b/google/cloud/aiplatform_v1beta1/types/completion_stats.py index 22f1fa7975..165be59634 100644 --- a/google/cloud/aiplatform_v1beta1/types/completion_stats.py +++ b/google/cloud/aiplatform_v1beta1/types/completion_stats.py @@ -46,7 +46,9 @@ class CompletionStats(proto.Message): """ successful_count = proto.Field(proto.INT64, number=1) + failed_count = proto.Field(proto.INT64, number=2) + incomplete_count = proto.Field(proto.INT64, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 3d466ab72f..2d8745538c 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -86,14 +86,23 @@ class CustomJob(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + job_spec = proto.Field(proto.MESSAGE, number=4, message="CustomJobSpec",) + state = proto.Field(proto.ENUM, number=5, enum=job_state.JobState,) + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) + labels = 
proto.MapField(proto.STRING, proto.STRING, number=11) @@ -106,6 +115,24 @@ class CustomJobSpec(proto.Message): including machine type and Docker image. scheduling (~.custom_job.Scheduling): Scheduling options for a CustomJob. + service_account (str): + Specifies the service account for workload + run-as account. Users submitting jobs must have + act-as permission on this run-as account. + network (str): + The full name of the Compute Engine + `network `__ + to which the Job should be peered. For example, + projects/12345/global/networks/myVPC. + + [Format](https://cloud.google.com/compute/docs/reference/rest/v1/networks/insert) + is of the form projects/{project}/global/networks/{network}. + Where {project} is a project number, as in '12345', and + {network} is network name. + + Private services access must already be configured for the + network. If left unspecified, the job is not peered with any + network. base_output_directory (~.io.GcsDestination): The Google Cloud Storage location to store the output of this CustomJob or HyperparameterTuningJob. For @@ -142,7 +169,13 @@ class CustomJobSpec(proto.Message): worker_pool_specs = proto.RepeatedField( proto.MESSAGE, number=1, message="WorkerPoolSpec", ) + scheduling = proto.Field(proto.MESSAGE, number=3, message="Scheduling",) + + service_account = proto.Field(proto.STRING, number=4) + + network = proto.Field(proto.STRING, number=5) + base_output_directory = proto.Field( proto.MESSAGE, number=6, message=io.GcsDestination, ) @@ -162,17 +195,28 @@ class WorkerPoolSpec(proto.Message): replica_count (int): Required. The number of worker replicas to use for this worker pool. + disk_spec (~.machine_resources.DiskSpec): + Disk spec. """ - container_spec = proto.Field(proto.MESSAGE, number=6, message="ContainerSpec",) + container_spec = proto.Field( + proto.MESSAGE, number=6, oneof="task", message="ContainerSpec", + ) + python_package_spec = proto.Field( - proto.MESSAGE, number=7, message="PythonPackageSpec", + proto.MESSAGE, number=7, oneof="task", message="PythonPackageSpec", ) + machine_spec = proto.Field( proto.MESSAGE, number=1, message=machine_resources.MachineSpec, ) + replica_count = proto.Field(proto.INT64, number=2) + disk_spec = proto.Field( + proto.MESSAGE, number=5, message=machine_resources.DiskSpec, + ) + class ContainerSpec(proto.Message): r"""The spec of a Container. 
@@ -192,7 +236,9 @@ class ContainerSpec(proto.Message): """ image_uri = proto.Field(proto.STRING, number=1) + command = proto.RepeatedField(proto.STRING, number=2) + args = proto.RepeatedField(proto.STRING, number=3) @@ -221,8 +267,11 @@ class PythonPackageSpec(proto.Message): """ executor_image_uri = proto.Field(proto.STRING, number=1) + package_uris = proto.RepeatedField(proto.STRING, number=2) + python_module = proto.Field(proto.STRING, number=3) + args = proto.RepeatedField(proto.STRING, number=4) @@ -242,6 +291,7 @@ class Scheduling(proto.Message): """ timeout = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) + restart_job_on_worker_restart = proto.Field(proto.BOOL, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/data_item.py b/google/cloud/aiplatform_v1beta1/types/data_item.py index 418e8cc739..e43a944d94 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_item.py +++ b/google/cloud/aiplatform_v1beta1/types/data_item.py @@ -69,10 +69,15 @@ class DataItem(proto.Message): """ name = proto.Field(proto.STRING, number=1) + create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + labels = proto.MapField(proto.STRING, proto.STRING, number=3) + payload = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + etag = proto.Field(proto.STRING, number=7) diff --git a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py index 19da27e6a9..af1bcdd871 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py @@ -21,6 +21,7 @@ from google.cloud.aiplatform_v1beta1.types import job_state from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore from google.type import money_pb2 as money # type: ignore @@ -98,6 +99,10 @@ class DataLabelingJob(proto.Message): update_time (~.timestamp.Timestamp): Output only. Timestamp when this DataLabelingJob was updated most recently. + error (~.status.Status): + Output only. DataLabelingJob errors. It is only populated + when job's state is ``JOB_STATE_FAILED`` or + ``JOB_STATE_CANCELLED``. labels (Sequence[~.data_labeling_job.DataLabelingJob.LabelsEntry]): The labels with user-defined metadata to organize your DataLabelingJobs. 
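Before the data labeling types, a sketch of how the CustomJob pieces above compose (hypothetical image, bucket, and service account; container_spec occupies the new "task" oneof, so python_package_spec stays unset):

from google.cloud.aiplatform_v1beta1 import types

worker_pool = types.WorkerPoolSpec(
    replica_count=1,
    machine_spec=types.MachineSpec(machine_type="n1-standard-4"),
    disk_spec=types.DiskSpec(boot_disk_type="pd-ssd", boot_disk_size_gb=100),
    # container_spec and python_package_spec share the "task" oneof.
    container_spec=types.ContainerSpec(
        image_uri="gcr.io/example-project/trainer:latest",
        command=["python", "-m", "trainer.task"],
        args=["--epochs=10"],
    ),
)

job = types.CustomJob(
    display_name="example-custom-job",
    job_spec=types.CustomJobSpec(
        worker_pool_specs=[worker_pool],
        # The run-as account; job submitters need act-as permission on it.
        service_account="trainer@example-project.iam.gserviceaccount.com",
        base_output_directory=types.GcsDestination(
            output_uri_prefix="gs://example-bucket/custom-job-output/"
        ),
    ),
)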
@@ -128,20 +133,37 @@ class DataLabelingJob(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + datasets = proto.RepeatedField(proto.STRING, number=3) + annotation_labels = proto.MapField(proto.STRING, proto.STRING, number=12) + labeler_count = proto.Field(proto.INT32, number=4) + instruction_uri = proto.Field(proto.STRING, number=5) + inputs_schema_uri = proto.Field(proto.STRING, number=6) + inputs = proto.Field(proto.MESSAGE, number=7, message=struct.Value,) + state = proto.Field(proto.ENUM, number=8, enum=job_state.JobState,) + labeling_progress = proto.Field(proto.INT32, number=13) + current_spend = proto.Field(proto.MESSAGE, number=14, message=money.Money,) + create_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) + + error = proto.Field(proto.MESSAGE, number=22, message=status.Status,) + labels = proto.MapField(proto.STRING, proto.STRING, number=11) + specialist_pools = proto.RepeatedField(proto.STRING, number=16) + active_learning_config = proto.Field( proto.MESSAGE, number=21, message="ActiveLearningConfig", ) @@ -172,9 +194,16 @@ class ActiveLearningConfig(proto.Message): select DataItems. """ - max_data_item_count = proto.Field(proto.INT64, number=1) - max_data_item_percentage = proto.Field(proto.INT32, number=2) + max_data_item_count = proto.Field( + proto.INT64, number=1, oneof="human_labeling_budget" + ) + + max_data_item_percentage = proto.Field( + proto.INT32, number=2, oneof="human_labeling_budget" + ) + sample_config = proto.Field(proto.MESSAGE, number=3, message="SampleConfig",) + training_config = proto.Field(proto.MESSAGE, number=4, message="TrainingConfig",) @@ -204,8 +233,14 @@ class SampleStrategy(proto.Enum): SAMPLE_STRATEGY_UNSPECIFIED = 0 UNCERTAINTY = 1 - initial_batch_sample_percentage = proto.Field(proto.INT32, number=1) - following_batch_sample_percentage = proto.Field(proto.INT32, number=3) + initial_batch_sample_percentage = proto.Field( + proto.INT32, number=1, oneof="initial_batch_sample_size" + ) + + following_batch_sample_percentage = proto.Field( + proto.INT32, number=3, oneof="following_batch_sample_size" + ) + sample_strategy = proto.Field(proto.ENUM, number=5, enum=SampleStrategy,) diff --git a/google/cloud/aiplatform_v1beta1/types/dataset.py b/google/cloud/aiplatform_v1beta1/types/dataset.py index 3675d8f42a..76f6462f40 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset.py @@ -83,12 +83,19 @@ class Dataset(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + metadata_schema_uri = proto.Field(proto.STRING, number=3) + metadata = proto.Field(proto.MESSAGE, number=8, message=struct.Value,) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + etag = proto.Field(proto.STRING, number=6) + labels = proto.MapField(proto.STRING, proto.STRING, number=7) @@ -124,8 +131,12 @@ class ImportDataConfig(proto.Message): Object `__. 
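The budget and sample-size fields above now sit in proto oneofs, so assigning one member clears its sibling. A small sketch of that behavior under proto-plus semantics (illustrative values):

from google.cloud.aiplatform_v1beta1 import types

config = types.ActiveLearningConfig(max_data_item_count=1000)

# Assigning the other member of the "human_labeling_budget" oneof
# clears the count; an unset int field then reads back as 0.
config.max_data_item_percentage = 50

assert config.max_data_item_count == 0
assert config.max_data_item_percentage == 50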
""" - gcs_source = proto.Field(proto.MESSAGE, number=1, message=io.GcsSource,) + gcs_source = proto.Field( + proto.MESSAGE, number=1, oneof="source", message=io.GcsSource, + ) + data_item_labels = proto.MapField(proto.STRING, proto.STRING, number=2) + import_schema_uri = proto.Field(proto.STRING, number=4) @@ -139,12 +150,13 @@ class ExportDataConfig(proto.Message): written to. In the given directory a new directory will be created with name: ``export-data--`` - where timestamp is in YYYYMMDDHHMMSS format. All export - output will be written into that directory. Inside that - directory, annotations with the same schema will be grouped - into sub directories which are named with the corresponding - annotations' schema title. Inside these sub directories, a - schema.yaml will be created to describe the output format. + where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 + format. All export output will be written into that + directory. Inside that directory, annotations with the same + schema will be grouped into sub directories which are named + with the corresponding annotations' schema title. Inside + these sub directories, a schema.yaml will be created to + describe the output format. annotations_filter (str): A filter on Annotations of the Dataset. Only Annotations on to-be-exported DataItems(specified by [data_items_filter][]) @@ -153,7 +165,10 @@ class ExportDataConfig(proto.Message): ``ListAnnotations``. """ - gcs_destination = proto.Field(proto.MESSAGE, number=1, message=io.GcsDestination,) + gcs_destination = proto.Field( + proto.MESSAGE, number=1, oneof="destination", message=io.GcsDestination, + ) + annotations_filter = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py index 56b51a97cb..7160b7b52f 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py @@ -64,6 +64,7 @@ class CreateDatasetRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + dataset = proto.Field(proto.MESSAGE, number=2, message=gca_dataset.Dataset,) @@ -93,6 +94,7 @@ class GetDatasetRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) + read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -117,6 +119,7 @@ class UpdateDatasetRequest(proto.Message): """ dataset = proto.Field(proto.MESSAGE, number=1, message=gca_dataset.Dataset,) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -147,10 +150,15 @@ class ListDatasetsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + order_by = proto.Field(proto.STRING, number=6) @@ -173,6 +181,7 @@ def raw_page(self): datasets = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_dataset.Dataset, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -205,6 +214,7 @@ class ImportDataRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) + import_configs = proto.RepeatedField( proto.MESSAGE, number=2, message=gca_dataset.ImportDataConfig, ) @@ -243,6 +253,7 @@ class ExportDataRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) + export_config = proto.Field( proto.MESSAGE, number=2, 
message=gca_dataset.ExportDataConfig, ) @@ -277,6 +288,7 @@ class ExportDataOperationMetadata(proto.Message): generic_metadata = proto.Field( proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) + gcs_output_directory = proto.Field(proto.STRING, number=2) @@ -304,10 +316,15 @@ class ListDataItemsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + order_by = proto.Field(proto.STRING, number=6) @@ -330,6 +347,7 @@ def raw_page(self): data_items = proto.RepeatedField( proto.MESSAGE, number=1, message=data_item.DataItem, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -347,6 +365,7 @@ class GetAnnotationSpecRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) + read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -375,10 +394,15 @@ class ListAnnotationsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + order_by = proto.Field(proto.STRING, number=6) @@ -401,6 +425,7 @@ def raw_page(self): annotations = proto.RepeatedField( proto.MESSAGE, number=1, message=annotation.Annotation, ) + next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py index 6a7f18850f..b0ec7010a2 100644 --- a/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py +++ b/google/cloud/aiplatform_v1beta1/types/deployed_model_ref.py @@ -35,6 +35,7 @@ class DeployedModelRef(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) + deployed_model_id = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index 315a9de179..f1ba6ed85d 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -82,15 +82,23 @@ class Endpoint(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + description = proto.Field(proto.STRING, number=3) + deployed_models = proto.RepeatedField( proto.MESSAGE, number=4, message="DeployedModel", ) + traffic_split = proto.MapField(proto.STRING, proto.INT32, number=5) + etag = proto.Field(proto.STRING, number=6) + labels = proto.MapField(proto.STRING, proto.STRING, number=7) + create_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) @@ -137,9 +145,16 @@ class DeployedModel(proto.Message): ``Model.explanation_spec`` must be populated, otherwise explanation for this Model is not allowed. - - Currently, only AutoML tabular Models support - explanation_spec. + service_account (str): + The service account that the DeployedModel's container runs + as. Specify the email address of the service account. If + this service account is not specified, the container runs as + a service account that doesn't have access to the resource + project. 
+ + Users deploying the Model must have the + ``iam.serviceAccounts.actAs`` permission on this service + account. enable_container_logging (bool): If true, the container of the DeployedModel instances will send ``stderr`` and ``stdout`` streams to Stackdriver @@ -159,19 +174,35 @@ class DeployedModel(proto.Message): """ dedicated_resources = proto.Field( - proto.MESSAGE, number=7, message=machine_resources.DedicatedResources, + proto.MESSAGE, + number=7, + oneof="prediction_resources", + message=machine_resources.DedicatedResources, ) + automatic_resources = proto.Field( - proto.MESSAGE, number=8, message=machine_resources.AutomaticResources, + proto.MESSAGE, + number=8, + oneof="prediction_resources", + message=machine_resources.AutomaticResources, ) + id = proto.Field(proto.STRING, number=1) + model = proto.Field(proto.STRING, number=2) + display_name = proto.Field(proto.STRING, number=3) + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + explanation_spec = proto.Field( proto.MESSAGE, number=9, message=explanation.ExplanationSpec, ) + + service_account = proto.Field(proto.STRING, number=11) + enable_container_logging = proto.Field(proto.BOOL, number=12) + enable_access_logging = proto.Field(proto.BOOL, number=13) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py index 43e8eacdfb..4bc9f35594 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py @@ -57,6 +57,7 @@ class CreateEndpointRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + endpoint = proto.Field(proto.MESSAGE, number=2, message=gca_endpoint.Endpoint,) @@ -135,9 +136,13 @@ class ListEndpointsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -161,6 +166,7 @@ def raw_page(self): endpoints = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_endpoint.Endpoint, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -178,6 +184,7 @@ class UpdateEndpointRequest(proto.Message): """ endpoint = proto.Field(proto.MESSAGE, number=1, message=gca_endpoint.Endpoint,) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -230,9 +237,11 @@ class DeployModelRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) + deployed_model = proto.Field( proto.MESSAGE, number=2, message=gca_endpoint.DeployedModel, ) + traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) @@ -289,7 +298,9 @@ class UndeployModelRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) + deployed_model_id = proto.Field(proto.STRING, number=2) + traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/env_var.py b/google/cloud/aiplatform_v1beta1/types/env_var.py index 0c22313d63..207e8275cd 100644 --- a/google/cloud/aiplatform_v1beta1/types/env_var.py +++ b/google/cloud/aiplatform_v1beta1/types/env_var.py @@ -43,6 +43,7 @@ class EnvVar(proto.Message): """ name = proto.Field(proto.STRING, number=1) + value = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation.py 
b/google/cloud/aiplatform_v1beta1/types/explanation.py index 5e20ef2699..7a495fff1e 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation.py @@ -31,17 +31,20 @@ "ExplanationSpec", "ExplanationParameters", "SampledShapleyAttribution", + "IntegratedGradientsAttribution", + "XraiAttribution", + "SmoothGradConfig", + "FeatureNoiseSigma", }, ) class Explanation(proto.Message): - r"""Explanation of a ``prediction`` produced - by the Model on a given + r"""Explanation of a prediction (provided in + ``PredictResponse.predictions``) + produced by the Model on a given ``instance``. - Currently, only AutoML tabular Models support explanation. - Attributes: attributions (Sequence[~.explanation.Attribution]): Output only. Feature attributions grouped by predicted @@ -56,6 +59,16 @@ class Explanation(proto.Message): ``Attribution.output_index`` can be used to identify which output this attribution is explaining. + + If users set + ``ExplanationParameters.top_k``, + the attributions are sorted by + ``instance_output_value`` + in descending order. If + ``ExplanationParameters.output_indices`` + is specified, the attributions are stored by + ``Attribution.output_index`` + in the same order as they appear in the output_indices. """ attributions = proto.RepeatedField(proto.MESSAGE, number=1, message="Attribution",) @@ -64,8 +77,6 @@ class Explanation(proto.Message): class ModelExplanation(proto.Message): r"""Aggregated explanation metrics for a Model over a set of instances. - Currently, only AutoML tabular Models support aggregated - explanation. Attributes: mean_attributions (Sequence[~.explanation.Attribution]): @@ -114,9 +125,8 @@ class Attribution(proto.Message): The field name of the output is determined by the key in ``ExplanationMetadata.outputs``. - If the Model predicted output is a tensor value (for - example, an ndarray), this is the value in the output - located by + If the Model's predicted output has multiple dimensions + (rank > 1), this is the value in the output located by ``output_index``. If there are multiple baselines, their output values are @@ -127,16 +137,15 @@ class Attribution(proto.Message): name of the output is determined by the key in ``ExplanationMetadata.outputs``. - If the Model predicted output is a tensor value (for - example, an ndarray), this is the value in the output - located by + If the Model predicted output has multiple dimensions, this + is the value in the output located by ``output_index``. feature_attributions (~.struct.Value): Output only. Attributions of each explained feature. Features are extracted from the [prediction instances][google.cloud.aiplatform.v1beta1.ExplainRequest.instances] - according to [explanation input - metadata][google.cloud.aiplatform.v1beta1.ExplanationMetadata.inputs]. + according to [explanation metadata for + inputs][google.cloud.aiplatform.v1beta1.ExplanationMetadata.inputs]. The value is a struct, whose keys are the name of the feature. The values are how much the feature in the @@ -174,11 +183,11 @@ class Attribution(proto.Message): output. If the prediction output is a scalar value, output_index is - not populated. If the prediction output is a tensor value - (for example, an ndarray), the length of output_index is the - same as the number of dimensions of the output. The i-th - element in output_index is the element index of the i-th - dimension of the output vector. Indexes start from 0. + not populated. 
If the prediction output has multiple + dimensions, the length of the output_index list is the same + as the number of dimensions of the output. The i-th element + in output_index is the element index of the i-th dimension + of the output vector. Indices start from 0. output_display_name (str): Output only. The display name of the output identified by ``output_index``, @@ -195,24 +204,48 @@ class Attribution(proto.Message): caused by approximation used in the explanation method. Lower value means more precise attributions. - For Sampled Shapley - ``attribution``, - increasing - ``path_count`` - might reduce the error. + - For [Sampled Shapley + attribution][ExplanationParameters.sampled_shapley_attribution], + increasing + ``path_count`` + may reduce the error. + - For [Integrated Gradients + attribution][ExplanationParameters.integrated_gradients_attribution], + increasing + ``step_count`` + may reduce the error. + - For [XRAI + attribution][ExplanationParameters.xrai_attribution], + increasing + ``step_count`` + may reduce the error. + + Refer to AI Explanations Whitepaper for more details: + + https://storage.googleapis.com/cloud-ai-whitepapers/AI%20Explainability%20Whitepaper.pdf + output_name (str): + Output only. Name of the explain output. Specified as the + key in + ``ExplanationMetadata.outputs``. """ baseline_output_value = proto.Field(proto.DOUBLE, number=1) + instance_output_value = proto.Field(proto.DOUBLE, number=2) + feature_attributions = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + output_index = proto.RepeatedField(proto.INT32, number=4) + output_display_name = proto.Field(proto.STRING, number=5) + approximation_error = proto.Field(proto.DOUBLE, number=6) + output_name = proto.Field(proto.STRING, number=7) + class ExplanationSpec(proto.Message): r"""Specification of Model explanation. - Currently, only AutoML tabular Models support explanation. Attributes: parameters (~.explanation.ExplanationParameters): @@ -224,6 +257,7 @@ class ExplanationSpec(proto.Message): """ parameters = proto.Field(proto.MESSAGE, number=1, message="ExplanationParameters",) + metadata = proto.Field( proto.MESSAGE, number=2, message=explanation_metadata.ExplanationMetadata, ) @@ -238,13 +272,69 @@ class ExplanationParameters(proto.Message): Shapley values for features that contribute to the label being predicted. A sampling strategy is used to approximate the value rather than - considering all subsets of features. + considering all subsets of features. Refer to + this paper for model details: + https://arxiv.org/abs/1306.4265. + integrated_gradients_attribution (~.explanation.IntegratedGradientsAttribution): + An attribution method that computes Aumann- + hapley values taking advantage of the model's + fully differentiable structure. Refer to this + paper for more details: + https://arxiv.org/abs/1703.01365 + xrai_attribution (~.explanation.XraiAttribution): + An attribution method that redistributes + Integrated Gradients attribution to segmented + regions, taking advantage of the model's fully + differentiable structure. Refer to this paper + for more details: + https://arxiv.org/abs/1906.02825 + XRAI currently performs better on natural + images, like a picture of a house or an animal. + If the images are taken in artificial + environments, like a lab or manufacturing line, + or from diagnostic equipment, like x-rays or + quality-control cameras, use Integrated + Gradients instead. 
+ top_k (int):
+ If populated, returns attributions for top K
+ indices of outputs (defaults to 1). Only applies
+ to Models that predict more than one output
+ (e.g., multi-class Models). When set to -1,
+ returns explanations for all outputs.
+ output_indices (~.struct.ListValue):
+ If populated, only returns attributions that have
+ ``output_index`` contained in
+ output_indices. It must be an ndarray of integers, with the
+ same shape as the output it's explaining.
+
+ If not populated, returns attributions for
+ ``top_k``
+ indices of outputs. If neither top_k nor output_indices is
+ populated, returns the argmax index of the outputs.
+
+ Only applicable to Models that predict multiple outputs
+ (e.g., multi-class Models that predict multiple classes).
"""

sampled_shapley_attribution = proto.Field(
- proto.MESSAGE, number=1, message="SampledShapleyAttribution",
+ proto.MESSAGE, number=1, oneof="method", message="SampledShapleyAttribution",
)
+
+ integrated_gradients_attribution = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ oneof="method",
+ message="IntegratedGradientsAttribution",
+ )
+
+ xrai_attribution = proto.Field(
+ proto.MESSAGE, number=3, oneof="method", message="XraiAttribution",
+ )
+
+ top_k = proto.Field(proto.INT32, number=4)
+
+ output_indices = proto.Field(proto.MESSAGE, number=5, message=struct.ListValue,)
+

class SampledShapleyAttribution(proto.Message):
r"""An attribution method that approximates Shapley values for
@@ -263,4 +353,163 @@ class SampledShapleyAttribution(proto.Message):

path_count = proto.Field(proto.INT32, number=1)


+class IntegratedGradientsAttribution(proto.Message):
+ r"""An attribution method that computes the Aumann-Shapley value
+ taking advantage of the model's fully differentiable structure.
+ Refer to this paper for more details:
+ https://arxiv.org/abs/1703.01365
+
+ Attributes:
+ step_count (int):
+ Required. The number of steps for approximating the path
+ integral. A good value to start is 50 and gradually increase
+ until the sum to diff property is within the desired error
+ range.
+
+ Valid range of its value is [1, 100], inclusively.
+ smooth_grad_config (~.explanation.SmoothGradConfig):
+ Config for SmoothGrad approximation of
+ gradients.
+ When enabled, the gradients are approximated by
+ averaging the gradients from noisy samples in
+ the vicinity of the inputs. Adding noise can
Adding noise can + help improve the computed gradients. Refer to + this paper for more details: + https://arxiv.org/pdf/1706.03825.pdf + """ + + step_count = proto.Field(proto.INT32, number=1) + + smooth_grad_config = proto.Field( + proto.MESSAGE, number=2, message="SmoothGradConfig", + ) + + +class SmoothGradConfig(proto.Message): + r"""Config for SmoothGrad approximation of gradients. + When enabled, the gradients are approximated by averaging the + gradients from noisy samples in the vicinity of the inputs. + Adding noise can help improve the computed gradients. Refer to + this paper for more details: + https://arxiv.org/pdf/1706.03825.pdf + + Attributes: + noise_sigma (float): + This is a single float value and will be used to add noise + to all the features. Use this field when all features are + normalized to have the same distribution: scale to range [0, + 1], [-1, 1] or z-scoring, where features are normalized to + have 0-mean and 1-variance. Refer to this doc for more + details about normalization: + + https://developers.google.com/machine-learning/data-prep/transform/normalization. + + For best results the recommended value is about 10% - 20% of + the standard deviation of the input feature. Refer to + section 3.2 of the SmoothGrad paper: + https://arxiv.org/pdf/1706.03825.pdf. Defaults to 0.1. + + If the distribution is different per feature, set + ``feature_noise_sigma`` + instead for each feature. + feature_noise_sigma (~.explanation.FeatureNoiseSigma): + This is similar to + ``noise_sigma``, + but provides additional flexibility. A separate noise sigma + can be provided for each feature, which is useful if their + distributions are different. No noise is added to features + that are not set. If this field is unset, + ``noise_sigma`` + will be used for all features. + noisy_sample_count (int): + The number of gradient samples to use for approximation. The + higher this number, the more accurate the gradient is, but + the runtime complexity increases by this factor as well. + Valid range of its value is [1, 50]. Defaults to 3. + """ + + noise_sigma = proto.Field(proto.FLOAT, number=1, oneof="GradientNoiseSigma") + + feature_noise_sigma = proto.Field( + proto.MESSAGE, + number=2, + oneof="GradientNoiseSigma", + message="FeatureNoiseSigma", + ) + + noisy_sample_count = proto.Field(proto.INT32, number=3) + + +class FeatureNoiseSigma(proto.Message): + r"""Noise sigma by features. Noise sigma represents the standard + deviation of the gaussian kernel that will be used to add noise + to interpolated inputs prior to computing gradients. + + Attributes: + noise_sigma (Sequence[~.explanation.FeatureNoiseSigma.NoiseSigmaForFeature]): + Noise sigma per feature. No noise is added to + features that are not set. + """ + + class NoiseSigmaForFeature(proto.Message): + r"""Noise sigma for a single feature. + + Attributes: + name (str): + The name of the input feature for which noise sigma is + provided. The features are defined in [explanation metadata + inputs][google.cloud.aiplatform.v1beta1.ExplanationMetadata.inputs]. + sigma (float): + This represents the standard deviation of the Gaussian + kernel that will be used to add noise to the feature prior + to computing gradients. Similar to + ``noise_sigma`` + but represents the noise added to the current feature. + Defaults to 0.1. 
+ """ + + name = proto.Field(proto.STRING, number=1) + + sigma = proto.Field(proto.FLOAT, number=2) + + noise_sigma = proto.RepeatedField( + proto.MESSAGE, number=1, message=NoiseSigmaForFeature, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index 1b9f005857..7261c064f8 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -40,12 +40,24 @@ class ExplanationMetadata(proto.Message): which has the name specified as the key in ``ExplanationMetadata.inputs``. The baseline of the empty feature is chosen by AI Platform. + + For AI Platform provided Tensorflow images, the key can be + any friendly name of the feature . Once specified, [ + featureAttributions][Attribution.feature_attributions] will + be keyed by this key (if not grouped with another feature). + + For custom images, the key must match with the key in + ``instance``. outputs (Sequence[~.explanation_metadata.ExplanationMetadata.OutputsEntry]): Required. Map from output names to output metadata. - Keys are the name of the output field in the - prediction to be explained. Currently only one - key is allowed. + For AI Platform provided Tensorflow images, keys + can be any string user defines. + + For custom images, keys are the name of the + output field in the prediction to be explained. + + Currently only one key is allowed. feature_attributions_schema_uri (str): Points to a YAML file stored on Google Cloud Storage describing the format of the [feature @@ -62,6 +74,11 @@ class ExplanationMetadata(proto.Message): class InputMetadata(proto.Message): r"""Metadata of the input of a feature. + Fields other than + ``InputMetadata.input_baselines`` + are applicable only for Models that are using AI Platform-provided + images for Tensorflow. + Attributes: input_baselines (Sequence[~.struct.Value]): Baseline inputs for this feature. @@ -71,20 +88,271 @@ class InputMetadata(proto.Message): specified, AI Platform returns the average attributions across them in [Attributions.baseline_attribution][]. - The element of the baselines must be in the same format as - the feature's input in the + For AI Platform provided Tensorflow images (both 1.x and + 2.x), the shape of each baseline must match the shape of the + input tensor. If a scalar is provided, we broadcast to the + same shape as the input tensor. + + For custom images, the element of the baselines must be in + the same format as the feature's input in the ``instance``[]. The schema of any single instance may be specified via Endpoint's DeployedModels' [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``instance_schema_uri``. + input_tensor_name (str): + Name of the input tensor for this feature. + Required and is only applicable to AI Platform + provided images for Tensorflow. + encoding (~.explanation_metadata.ExplanationMetadata.InputMetadata.Encoding): + Defines how the feature is encoded into the + input tensor. Defaults to IDENTITY. + modality (str): + Modality of the feature. Valid values are: + numeric, image. Defaults to numeric. + feature_value_domain (~.explanation_metadata.ExplanationMetadata.InputMetadata.FeatureValueDomain): + The domain details of the input feature + value. Like min/max, original mean or standard + deviation if normalized. 
+ indices_tensor_name (str): + Specifies the index of the values of the input tensor. + Required when the input tensor is a sparse representation. + Refer to Tensorflow documentation for more details: + https://www.tensorflow.org/api_docs/python/tf/sparse/SparseTensor. + dense_shape_tensor_name (str): + Specifies the shape of the values of the input if the input + is a sparse representation. Refer to Tensorflow + documentation for more details: + https://www.tensorflow.org/api_docs/python/tf/sparse/SparseTensor. + index_feature_mapping (Sequence[str]): + A list of feature names for each index in the input tensor. + Required when the input + ``InputMetadata.encoding`` + is BAG_OF_FEATURES, BAG_OF_FEATURES_SPARSE, INDICATOR. + encoded_tensor_name (str): + Encoded tensor is a transformation of the input tensor. Must + be provided if choosing [Integrated Gradients + attribution][ExplanationParameters.integrated_gradients_attribution] + or [XRAI + attribution][google.cloud.aiplatform.v1beta1.ExplanationParameters.xrai_attribution] + and the input tensor is not differentiable. + + An encoded tensor is generated if the input tensor is + encoded by a lookup table. + encoded_baselines (Sequence[~.struct.Value]): + A list of baselines for the encoded tensor. + The shape of each baseline should match the + shape of the encoded tensor. If a scalar is + provided, AI Platform broadcast to the same + shape as the encoded tensor. + visualization (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization): + Visualization configurations for image + explanation. + group_name (str): + Name of the group that the input belongs to. Features with + the same group name will be treated as one feature when + computing attributions. Features grouped together can have + different shapes in value. If provided, there will be one + single attribution generated in [ + featureAttributions][Attribution.feature_attributions], + keyed by the group name. """ + class Encoding(proto.Enum): + r"""Defines how the feature is encoded to [encoded_tensor][]. Defaults + to IDENTITY. + """ + ENCODING_UNSPECIFIED = 0 + IDENTITY = 1 + BAG_OF_FEATURES = 2 + BAG_OF_FEATURES_SPARSE = 3 + INDICATOR = 4 + COMBINED_EMBEDDING = 5 + CONCAT_EMBEDDING = 6 + + class FeatureValueDomain(proto.Message): + r"""Domain details of the input feature value. Provides numeric + information about the feature, such as its range (min, max). If the + feature has been pre-processed, for example with z-scoring, then it + provides information about how to recover the original feature. For + example, if the input feature is an image and it has been + pre-processed to obtain 0-mean and stddev = 1 values, then + original_mean, and original_stddev refer to the mean and stddev of + the original feature (e.g. image tensor) from which input feature + (with mean = 0 and stddev = 1) was obtained. + + Attributes: + min_value (float): + The minimum permissible value for this + feature. + max_value (float): + The maximum permissible value for this + feature. + original_mean (float): + If this input feature has been normalized to a mean value of + 0, the original_mean specifies the mean value of the domain + prior to normalization. + original_stddev (float): + If this input feature has been normalized to a standard + deviation of 1.0, the original_stddev specifies the standard + deviation of the domain prior to normalization. 
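A short illustration of how a consumer might undo the normalization that FeatureValueDomain records (values hypothetical):

from google.cloud.aiplatform_v1beta1 import types

domain = types.ExplanationMetadata.InputMetadata.FeatureValueDomain(
    min_value=0.0,
    max_value=100.0,
    original_mean=42.0,
    original_stddev=7.5,
)

# Recover the pre-normalization value of a z-scored feature:
# x_original = x_normalized * original_stddev + original_mean
x_normalized = 1.2
x_original = x_normalized * domain.original_stddev + domain.original_mean  # 51.0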
+ """ + + min_value = proto.Field(proto.FLOAT, number=1) + + max_value = proto.Field(proto.FLOAT, number=2) + + original_mean = proto.Field(proto.FLOAT, number=3) + + original_stddev = proto.Field(proto.FLOAT, number=4) + + class Visualization(proto.Message): + r"""Visualization configurations for image explanation. + + Attributes: + type_ (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.Type): + Type of the image visualization. Only applicable to + [Integrated Gradients attribution] + [ExplanationParameters.integrated_gradients_attribution]. + OUTLINES shows regions of attribution, while PIXELS shows + per-pixel attribution. Defaults to OUTLINES. + polarity (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.Polarity): + Whether to only highlight pixels with + positive contributions, negative or both. + Defaults to POSITIVE. + color_map (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.ColorMap): + The color scheme used for the highlighted areas. + + Defaults to PINK_GREEN for [Integrated Gradients + attribution][ExplanationParameters.integrated_gradients_attribution], + which shows positive attributions in green and negative in + pink. + + Defaults to VIRIDIS for [XRAI + attribution][google.cloud.aiplatform.v1beta1.ExplanationParameters.xrai_attribution], + which highlights the most influential regions in yellow and + the least influential in blue. + clip_percent_upperbound (float): + Excludes attributions above the specified percentile from + the highlighted areas. Using the clip_percent_upperbound and + clip_percent_lowerbound together can be useful for filtering + out noise and making it easier to see areas of strong + attribution. Defaults to 99.9. + clip_percent_lowerbound (float): + Excludes attributions below the specified + percentile, from the highlighted areas. Defaults + to 35. + overlay_type (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.OverlayType): + How the original image is displayed in the + visualization. Adjusting the overlay can help + increase visual clarity if the original image + makes it difficult to view the visualization. + Defaults to NONE. + """ + + class Type(proto.Enum): + r"""Type of the image visualization. Only applicable to [Integrated + Gradients attribution] + [ExplanationParameters.integrated_gradients_attribution]. + """ + TYPE_UNSPECIFIED = 0 + PIXELS = 1 + OUTLINES = 2 + + class Polarity(proto.Enum): + r"""Whether to only highlight pixels with positive contributions, + negative or both. Defaults to POSITIVE. 
+ """ + POLARITY_UNSPECIFIED = 0 + POSITIVE = 1 + NEGATIVE = 2 + BOTH = 3 + + class ColorMap(proto.Enum): + r"""The color scheme used for highlighting areas.""" + COLOR_MAP_UNSPECIFIED = 0 + PINK_GREEN = 1 + VIRIDIS = 2 + RED = 3 + GREEN = 4 + RED_GREEN = 6 + PINK_WHITE_GREEN = 5 + + class OverlayType(proto.Enum): + r"""How the original image is displayed in the visualization.""" + OVERLAY_TYPE_UNSPECIFIED = 0 + NONE = 1 + ORIGINAL = 2 + GRAYSCALE = 3 + MASK_BLACK = 4 + + type_ = proto.Field( + proto.ENUM, + number=1, + enum="ExplanationMetadata.InputMetadata.Visualization.Type", + ) + + polarity = proto.Field( + proto.ENUM, + number=2, + enum="ExplanationMetadata.InputMetadata.Visualization.Polarity", + ) + + color_map = proto.Field( + proto.ENUM, + number=3, + enum="ExplanationMetadata.InputMetadata.Visualization.ColorMap", + ) + + clip_percent_upperbound = proto.Field(proto.FLOAT, number=4) + + clip_percent_lowerbound = proto.Field(proto.FLOAT, number=5) + + overlay_type = proto.Field( + proto.ENUM, + number=6, + enum="ExplanationMetadata.InputMetadata.Visualization.OverlayType", + ) + input_baselines = proto.RepeatedField( proto.MESSAGE, number=1, message=struct.Value, ) + input_tensor_name = proto.Field(proto.STRING, number=2) + + encoding = proto.Field( + proto.ENUM, number=3, enum="ExplanationMetadata.InputMetadata.Encoding", + ) + + modality = proto.Field(proto.STRING, number=4) + + feature_value_domain = proto.Field( + proto.MESSAGE, + number=5, + message="ExplanationMetadata.InputMetadata.FeatureValueDomain", + ) + + indices_tensor_name = proto.Field(proto.STRING, number=6) + + dense_shape_tensor_name = proto.Field(proto.STRING, number=7) + + index_feature_mapping = proto.RepeatedField(proto.STRING, number=8) + + encoded_tensor_name = proto.Field(proto.STRING, number=9) + + encoded_baselines = proto.RepeatedField( + proto.MESSAGE, number=10, message=struct.Value, + ) + + visualization = proto.Field( + proto.MESSAGE, + number=11, + message="ExplanationMetadata.InputMetadata.Visualization", + ) + + group_name = proto.Field(proto.STRING, number=12) + class OutputMetadata(proto.Message): r"""Metadata of the prediction output to be explained. @@ -116,19 +384,30 @@ class OutputMetadata(proto.Message): of the outputs, so that it can be located by ``Attribution.output_index`` for a specific output. + output_tensor_name (str): + Name of the output tensor. Required and is + only applicable to AI Platform provided images + for Tensorflow. 
""" index_display_name_mapping = proto.Field( - proto.MESSAGE, number=1, message=struct.Value, + proto.MESSAGE, number=1, oneof="display_name_mapping", message=struct.Value, + ) + + display_name_mapping_key = proto.Field( + proto.STRING, number=2, oneof="display_name_mapping" ) - display_name_mapping_key = proto.Field(proto.STRING, number=2) + + output_tensor_name = proto.Field(proto.STRING, number=3) inputs = proto.MapField( proto.STRING, proto.MESSAGE, number=1, message=InputMetadata, ) + outputs = proto.MapField( proto.STRING, proto.MESSAGE, number=2, message=OutputMetadata, ) + feature_attributions_schema_uri = proto.Field(proto.STRING, number=3) diff --git a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py index 171e37ad09..78af635e79 100644 --- a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py +++ b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py @@ -96,21 +96,35 @@ class HyperparameterTuningJob(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + study_spec = proto.Field(proto.MESSAGE, number=4, message=study.StudySpec,) + max_trial_count = proto.Field(proto.INT32, number=5) + parallel_trial_count = proto.Field(proto.INT32, number=6) + max_failed_trial_count = proto.Field(proto.INT32, number=7) + trial_job_spec = proto.Field( proto.MESSAGE, number=8, message=custom_job.CustomJobSpec, ) + trials = proto.RepeatedField(proto.MESSAGE, number=9, message=study.Trial,) + state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) + create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) + start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + error = proto.Field(proto.MESSAGE, number=15, message=status.Status,) + labels = proto.MapField(proto.STRING, proto.STRING, number=16) diff --git a/google/cloud/aiplatform_v1beta1/types/job_service.py b/google/cloud/aiplatform_v1beta1/types/job_service.py index 98e80c19a2..f64f07cbe3 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_service.py +++ b/google/cloud/aiplatform_v1beta1/types/job_service.py @@ -76,6 +76,7 @@ class CreateCustomJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + custom_job = proto.Field(proto.MESSAGE, number=2, message=gca_custom_job.CustomJob,) @@ -132,9 +133,13 @@ class ListCustomJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -158,6 +163,7 @@ def raw_page(self): custom_jobs = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_custom_job.CustomJob, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -201,6 +207,7 @@ class CreateDataLabelingJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + data_labeling_job = proto.Field( proto.MESSAGE, number=2, message=gca_data_labeling_job.DataLabelingJob, ) @@ -261,10 +268,15 @@ class ListDataLabelingJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = 
proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + order_by = proto.Field(proto.STRING, number=6) @@ -287,6 +299,7 @@ def raw_page(self): data_labeling_jobs = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_data_labeling_job.DataLabelingJob, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -334,6 +347,7 @@ class CreateHyperparameterTuningJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + hyperparameter_tuning_job = proto.Field( proto.MESSAGE, number=2, @@ -396,9 +410,13 @@ class ListHyperparameterTuningJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -426,6 +444,7 @@ def raw_page(self): number=1, message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -473,6 +492,7 @@ class CreateBatchPredictionJobRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + batch_prediction_job = proto.Field( proto.MESSAGE, number=2, message=gca_batch_prediction_job.BatchPredictionJob, ) @@ -533,9 +553,13 @@ class ListBatchPredictionJobsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -560,6 +584,7 @@ def raw_page(self): batch_prediction_jobs = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_batch_prediction_job.BatchPredictionJob, ) + next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py index 30b81e3efc..c71aca024e 100644 --- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py @@ -31,6 +31,7 @@ "AutomaticResources", "BatchDedicatedResources", "ResourcesConsumed", + "DiskSpec", }, ) @@ -89,9 +90,11 @@ class MachineSpec(proto.Message): """ machine_type = proto.Field(proto.STRING, number=1) + accelerator_type = proto.Field( proto.ENUM, number=2, enum=gca_accelerator_type.AcceleratorType, ) + accelerator_count = proto.Field(proto.INT32, number=3) @@ -128,8 +131,10 @@ class DedicatedResources(proto.Message): as the default value. """ - machine_spec = proto.Field(proto.MESSAGE, number=1, message=MachineSpec,) + machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) + min_replica_count = proto.Field(proto.INT32, number=2) + max_replica_count = proto.Field(proto.INT32, number=3) @@ -166,6 +171,7 @@ class AutomaticResources(proto.Message): """ min_replica_count = proto.Field(proto.INT32, number=1) + max_replica_count = proto.Field(proto.INT32, number=2) @@ -189,8 +195,10 @@ class BatchDedicatedResources(proto.Message): The default value is 10. 
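        A rough sketch of how the resource messages in this file compose
        (illustrative only; the machine type is a placeholder and the types are
        assumed to be importable from ``google.cloud.aiplatform_v1beta1.types``):

        .. code:: python

            from google.cloud.aiplatform_v1beta1.types import machine_resources

            resources = machine_resources.BatchDedicatedResources(
                machine_spec=machine_resources.MachineSpec(
                    machine_type="n1-standard-4",  # placeholder machine type
                ),
                starting_replica_count=1,
                max_replica_count=10,  # matches the documented default of 10
            )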
""" - machine_spec = proto.Field(proto.MESSAGE, number=1, message=MachineSpec,) + machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) + starting_replica_count = proto.Field(proto.INT32, number=2) + max_replica_count = proto.Field(proto.INT32, number=3) @@ -209,4 +217,23 @@ class ResourcesConsumed(proto.Message): replica_hours = proto.Field(proto.DOUBLE, number=1) +class DiskSpec(proto.Message): + r"""Represents the spec of disk options. + + Attributes: + boot_disk_type (str): + Type of the boot disk (default is "pd- + tandard"). Valid values: "pd-ssd" (Persistent + Disk Solid State Drive) or "pd-standard" + (Persistent Disk Hard Disk Drive). + boot_disk_size_gb (int): + Size in GB of the boot disk (default is + 100GB). + """ + + boot_disk_type = proto.Field(proto.STRING, number=1) + + boot_disk_size_gb = proto.Field(proto.INT32, number=2) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py new file mode 100644 index 0000000000..99a6e65a42 --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", manifest={"MigratableResource",}, +) + + +class MigratableResource(proto.Message): + r"""Represents one resource that exists in automl.googleapis.com, + datalabeling.googleapis.com or ml.googleapis.com. + + Attributes: + ml_engine_model_version (~.migratable_resource.MigratableResource.MlEngineModelVersion): + Output only. Represents one Version in + ml.googleapis.com. + automl_model (~.migratable_resource.MigratableResource.AutomlModel): + Output only. Represents one Model in + automl.googleapis.com. + automl_dataset (~.migratable_resource.MigratableResource.AutomlDataset): + Output only. Represents one Dataset in + automl.googleapis.com. + data_labeling_dataset (~.migratable_resource.MigratableResource.DataLabelingDataset): + Output only. Represents one Dataset in + datalabeling.googleapis.com. + last_migrate_time (~.timestamp.Timestamp): + Output only. Timestamp when last migrate + attempt on this MigratableResource started. Will + not be set if there's no migrate attempt on this + MigratableResource. + last_update_time (~.timestamp.Timestamp): + Output only. Timestamp when this + MigratableResource was last updated. + """ + + class MlEngineModelVersion(proto.Message): + r"""Represents one model Version in ml.googleapis.com. + + Attributes: + endpoint (str): + The ml.googleapis.com endpoint that this model Version + currently lives in. 
Example values:
+
+                -  ml.googleapis.com
+                -  us-central1-ml.googleapis.com
+                -  europe-west4-ml.googleapis.com
+                -  asia-east1-ml.googleapis.com
+            version (str):
+                Full resource name of ml engine model Version. Format:
+                ``projects/{project}/models/{model}/versions/{version}``.
+        """
+
+        endpoint = proto.Field(proto.STRING, number=1)
+
+        version = proto.Field(proto.STRING, number=2)
+
+    class AutomlModel(proto.Message):
+        r"""Represents one Model in automl.googleapis.com.
+
+        Attributes:
+            model (str):
+                Full resource name of automl Model. Format:
+                ``projects/{project}/locations/{location}/models/{model}``.
+            model_display_name (str):
+                The Model's display name in
+                automl.googleapis.com.
+        """
+
+        model = proto.Field(proto.STRING, number=1)
+
+        model_display_name = proto.Field(proto.STRING, number=3)
+
+    class AutomlDataset(proto.Message):
+        r"""Represents one Dataset in automl.googleapis.com.
+
+        Attributes:
+            dataset (str):
+                Full resource name of automl Dataset. Format:
+                ``projects/{project}/locations/{location}/datasets/{dataset}``.
+            dataset_display_name (str):
+                The Dataset's display name in
+                automl.googleapis.com.
+        """
+
+        dataset = proto.Field(proto.STRING, number=1)
+
+        dataset_display_name = proto.Field(proto.STRING, number=4)
+
+    class DataLabelingDataset(proto.Message):
+        r"""Represents one Dataset in datalabeling.googleapis.com.
+
+        Attributes:
+            dataset (str):
+                Full resource name of data labeling Dataset. Format:
+                ``projects/{project}/datasets/{dataset}``.
+            dataset_display_name (str):
+                The Dataset's display name in
+                datalabeling.googleapis.com.
+            data_labeling_annotated_datasets (Sequence[~.migratable_resource.MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset]):
+                The migratable AnnotatedDatasets in
+                datalabeling.googleapis.com that belong to the
+                data labeling Dataset.
+        """
+
+        class DataLabelingAnnotatedDataset(proto.Message):
+            r"""Represents one AnnotatedDataset in
+            datalabeling.googleapis.com.
+
+            Attributes:
+                annotated_dataset (str):
+                    Full resource name of data labeling AnnotatedDataset.
+                    Format:
+
+                    ``projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}``.
+                annotated_dataset_display_name (str):
+                    The AnnotatedDataset's display name in
+                    datalabeling.googleapis.com.
+ """ + + annotated_dataset = proto.Field(proto.STRING, number=1) + + annotated_dataset_display_name = proto.Field(proto.STRING, number=3) + + dataset = proto.Field(proto.STRING, number=1) + + dataset_display_name = proto.Field(proto.STRING, number=4) + + data_labeling_annotated_datasets = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset", + ) + + ml_engine_model_version = proto.Field( + proto.MESSAGE, number=1, oneof="resource", message=MlEngineModelVersion, + ) + + automl_model = proto.Field( + proto.MESSAGE, number=2, oneof="resource", message=AutomlModel, + ) + + automl_dataset = proto.Field( + proto.MESSAGE, number=3, oneof="resource", message=AutomlDataset, + ) + + data_labeling_dataset = proto.Field( + proto.MESSAGE, number=4, oneof="resource", message=DataLabelingDataset, + ) + + last_migrate_time = proto.Field( + proto.MESSAGE, number=5, message=timestamp.Timestamp, + ) + + last_update_time = proto.Field( + proto.MESSAGE, number=6, message=timestamp.Timestamp, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/migration_service.py b/google/cloud/aiplatform_v1beta1/types/migration_service.py new file mode 100644 index 0000000000..46b0cdc66b --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/migration_service.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1beta1.types import ( + migratable_resource as gca_migratable_resource, +) +from google.cloud.aiplatform_v1beta1.types import operation + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", + manifest={ + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "BatchMigrateResourcesRequest", + "MigrateResourceRequest", + "BatchMigrateResourcesResponse", + "MigrateResourceResponse", + "BatchMigrateResourcesOperationMetadata", + }, +) + + +class SearchMigratableResourcesRequest(proto.Message): + r"""Request message for + ``MigrationService.SearchMigratableResources``. + + Attributes: + parent (str): + Required. The location that the migratable resources should + be searched from. It's the AI Platform location that the + resources can be migrated to, not the resources' original + location. Format: + ``projects/{project}/locations/{location}`` + page_size (int): + The standard page size. + The default and maximum value is 100. + page_token (str): + The standard page token. + """ + + parent = proto.Field(proto.STRING, number=1) + + page_size = proto.Field(proto.INT32, number=2) + + page_token = proto.Field(proto.STRING, number=3) + + +class SearchMigratableResourcesResponse(proto.Message): + r"""Response message for + ``MigrationService.SearchMigratableResources``. 
+
+    Attributes:
+        migratable_resources (Sequence[~.gca_migratable_resource.MigratableResource]):
+            All migratable resources that can be migrated
+            to the location specified in the request.
+        next_page_token (str):
+            The standard next-page token. The migratable_resources may
+            not fill page_size in SearchMigratableResourcesRequest even
+            when there are subsequent pages.
+    """
+
+    @property
+    def raw_page(self):
+        return self
+
+    migratable_resources = proto.RepeatedField(
+        proto.MESSAGE, number=1, message=gca_migratable_resource.MigratableResource,
+    )
+
+    next_page_token = proto.Field(proto.STRING, number=2)
+
+
+class BatchMigrateResourcesRequest(proto.Message):
+    r"""Request message for
+    ``MigrationService.BatchMigrateResources``.
+
+    Attributes:
+        parent (str):
+            Required. The location where the migrated resource will
+            live. Format: ``projects/{project}/locations/{location}``
+        migrate_resource_requests (Sequence[~.migration_service.MigrateResourceRequest]):
+            Required. The request messages specifying the
+            resources to migrate. They must be in the same
+            location as the destination. Up to 50 resources
+            can be migrated in one batch.
+    """
+
+    parent = proto.Field(proto.STRING, number=1)
+
+    migrate_resource_requests = proto.RepeatedField(
+        proto.MESSAGE, number=2, message="MigrateResourceRequest",
+    )
+
+
+class MigrateResourceRequest(proto.Message):
+    r"""Config for migrating one resource from automl.googleapis.com,
+    datalabeling.googleapis.com and ml.googleapis.com to AI
+    Platform.
+
+    Attributes:
+        migrate_ml_engine_model_version_config (~.migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig):
+            Config for migrating Version in
+            ml.googleapis.com to AI Platform's Model.
+        migrate_automl_model_config (~.migration_service.MigrateResourceRequest.MigrateAutomlModelConfig):
+            Config for migrating Model in
+            automl.googleapis.com to AI Platform's Model.
+        migrate_automl_dataset_config (~.migration_service.MigrateResourceRequest.MigrateAutomlDatasetConfig):
+            Config for migrating Dataset in
+            automl.googleapis.com to AI Platform's Dataset.
+        migrate_data_labeling_dataset_config (~.migration_service.MigrateResourceRequest.MigrateDataLabelingDatasetConfig):
+            Config for migrating Dataset in
+            datalabeling.googleapis.com to AI Platform's
+            Dataset.
+    """
+
+    class MigrateMlEngineModelVersionConfig(proto.Message):
+        r"""Config for migrating version in ml.googleapis.com to AI
+        Platform's Model.
+
+        Attributes:
+            endpoint (str):
+                Required. The ml.googleapis.com endpoint that this model
+                version should be migrated from. Example values:
+
+                -  ml.googleapis.com
+
+                -  us-central1-ml.googleapis.com
+
+                -  europe-west4-ml.googleapis.com
+
+                -  asia-east1-ml.googleapis.com
+            model_version (str):
+                Required. Full resource name of ml engine model version.
+                Format:
+                ``projects/{project}/models/{model}/versions/{version}``.
+            model_display_name (str):
+                Required. Display name of the model in AI
+                Platform. System will pick a display name if
+                unspecified.
+        """
+
+        endpoint = proto.Field(proto.STRING, number=1)
+
+        model_version = proto.Field(proto.STRING, number=2)
+
+        model_display_name = proto.Field(proto.STRING, number=3)
+
+    class MigrateAutomlModelConfig(proto.Message):
+        r"""Config for migrating Model in automl.googleapis.com to AI
+        Platform's Model.
+
+        Attributes:
+            model (str):
+                Required. Full resource name of automl Model. Format:
+                ``projects/{project}/locations/{location}/models/{model}``.
+            model_display_name (str):
+                Optional.
Display name of the model in AI + Platform. System will pick a display name if + unspecified. + """ + + model = proto.Field(proto.STRING, number=1) + + model_display_name = proto.Field(proto.STRING, number=2) + + class MigrateAutomlDatasetConfig(proto.Message): + r"""Config for migrating Dataset in automl.googleapis.com to AI + Platform's Dataset. + + Attributes: + dataset (str): + Required. Full resource name of automl Dataset. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}``. + dataset_display_name (str): + Required. Display name of the Dataset in AI + Platform. System will pick a display name if + unspecified. + """ + + dataset = proto.Field(proto.STRING, number=1) + + dataset_display_name = proto.Field(proto.STRING, number=2) + + class MigrateDataLabelingDatasetConfig(proto.Message): + r"""Config for migrating Dataset in datalabeling.googleapis.com + to AI Platform's Dataset. + + Attributes: + dataset (str): + Required. Full resource name of data labeling Dataset. + Format: ``projects/{project}/datasets/{dataset}``. + dataset_display_name (str): + Optional. Display name of the Dataset in AI + Platform. System will pick a display name if + unspecified. + migrate_data_labeling_annotated_dataset_configs (Sequence[~.migration_service.MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig]): + Optional. Configs for migrating + AnnotatedDataset in datalabeling.googleapis.com + to AI Platform's SavedQuery. The specified + AnnotatedDatasets have to belong to the + datalabeling Dataset. + """ + + class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): + r"""Config for migrating AnnotatedDataset in + datalabeling.googleapis.com to AI Platform's SavedQuery. + + Attributes: + annotated_dataset (str): + Required. Full resource name of data labeling + AnnotatedDataset. Format: + + ``projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}``. + """ + + annotated_dataset = proto.Field(proto.STRING, number=1) + + dataset = proto.Field(proto.STRING, number=1) + + dataset_display_name = proto.Field(proto.STRING, number=2) + + migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig", + ) + + migrate_ml_engine_model_version_config = proto.Field( + proto.MESSAGE, + number=1, + oneof="request", + message=MigrateMlEngineModelVersionConfig, + ) + + migrate_automl_model_config = proto.Field( + proto.MESSAGE, number=2, oneof="request", message=MigrateAutomlModelConfig, + ) + + migrate_automl_dataset_config = proto.Field( + proto.MESSAGE, number=3, oneof="request", message=MigrateAutomlDatasetConfig, + ) + + migrate_data_labeling_dataset_config = proto.Field( + proto.MESSAGE, + number=4, + oneof="request", + message=MigrateDataLabelingDatasetConfig, + ) + + +class BatchMigrateResourcesResponse(proto.Message): + r"""Response message for + ``MigrationService.BatchMigrateResources``. + + Attributes: + migrate_resource_responses (Sequence[~.migration_service.MigrateResourceResponse]): + Successfully migrated resources. + """ + + migrate_resource_responses = proto.RepeatedField( + proto.MESSAGE, number=1, message="MigrateResourceResponse", + ) + + +class MigrateResourceResponse(proto.Message): + r"""Describes a successfully migrated resource. + + Attributes: + dataset (str): + Migrated Dataset's resource name. + model (str): + Migrated Model's resource name. 
+ migratable_resource (~.gca_migratable_resource.MigratableResource): + Before migration, the identifier in + ml.googleapis.com, automl.googleapis.com or + datalabeling.googleapis.com. + """ + + dataset = proto.Field(proto.STRING, number=1, oneof="migrated_resource") + + model = proto.Field(proto.STRING, number=2, oneof="migrated_resource") + + migratable_resource = proto.Field( + proto.MESSAGE, number=3, message=gca_migratable_resource.MigratableResource, + ) + + +class BatchMigrateResourcesOperationMetadata(proto.Message): + r"""Runtime operation information for + ``MigrationService.BatchMigrateResources``. + + Attributes: + generic_metadata (~.operation.GenericOperationMetadata): + The common part of the operation metadata. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index 39cb44206e..21e8c41034 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -70,7 +70,7 @@ class Model(proto.Message): supported_export_formats (Sequence[~.model.Model.ExportFormat]): Output only. The formats in which this Model may be exported. If empty, this Model is not - avaiable for export. + available for export. training_pipeline (str): Output only. The resource name of the TrainingPipeline that uploaded this Model, if @@ -272,36 +272,55 @@ class ExportableContent(proto.Enum): IMAGE = 2 id = proto.Field(proto.STRING, number=1) + exportable_contents = proto.RepeatedField( proto.ENUM, number=2, enum="Model.ExportFormat.ExportableContent", ) name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + description = proto.Field(proto.STRING, number=3) + predict_schemata = proto.Field(proto.MESSAGE, number=4, message="PredictSchemata",) + metadata_schema_uri = proto.Field(proto.STRING, number=5) + metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) + supported_export_formats = proto.RepeatedField( proto.MESSAGE, number=20, message=ExportFormat, ) + training_pipeline = proto.Field(proto.STRING, number=7) + container_spec = proto.Field(proto.MESSAGE, number=9, message="ModelContainerSpec",) + artifact_uri = proto.Field(proto.STRING, number=26) + supported_deployment_resources_types = proto.RepeatedField( proto.ENUM, number=10, enum=DeploymentResourcesType, ) + supported_input_storage_formats = proto.RepeatedField(proto.STRING, number=11) + supported_output_storage_formats = proto.RepeatedField(proto.STRING, number=12) + create_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + deployed_models = proto.RepeatedField( proto.MESSAGE, number=15, message=deployed_model_ref.DeployedModelRef, ) + explanation_spec = proto.Field( proto.MESSAGE, number=23, message=explanation.ExplanationSpec, ) + etag = proto.Field(proto.STRING, number=16) + labels = proto.MapField(proto.STRING, proto.STRING, number=17) @@ -363,75 +382,253 @@ class PredictSchemata(proto.Message): """ instance_schema_uri = proto.Field(proto.STRING, number=1) + parameters_schema_uri = proto.Field(proto.STRING, number=2) + prediction_schema_uri = proto.Field(proto.STRING, number=3) class ModelContainerSpec(proto.Message): - r"""Specification of the container to be deployed for this Model. 
The - ModelContainerSpec is based on the Kubernetes Container - `specification `__. + r"""Specification of a container for serving predictions. This message + is a subset of the Kubernetes Container v1 core + `specification `__. Attributes: image_uri (str): - Required. Immutable. The URI of the Model serving container - file in the Container Registry. The container image is - ingested upon + Required. Immutable. URI of the Docker image to be used as + the custom container for serving predictions. This URI must + identify an image in Artifact Registry or Container + Registry. Learn more about the container publishing + requirements, including permissions requirements for the AI + Platform Service Agent, + `here `__. + + The container image is ingested upon ``ModelService.UploadModel``, stored internally, and this original path is afterwards not used. + + To learn about the requirements for the Docker image itself, + see `Custom container + requirements `__. command (Sequence[str]): - Immutable. The command with which the container is run. Not - executed within a shell. The Docker image's ENTRYPOINT is - used if this is not provided. Variable references - $(VAR_NAME) are expanded using the container's environment. - If a variable cannot be resolved, the reference in the input - string will be unchanged. The $(VAR_NAME) syntax can be - escaped with a double $$, ie: $$(VAR_NAME). Escaped - references will never be expanded, regardless of whether the - variable exists or not. More info: - https://tinyurl.com/y42hmlxe + Immutable. Specifies the command that runs when the + container starts. This overrides the container's + `ENTRYPOINT `__. + Specify this field as an array of executable and arguments, + similar to a Docker ``ENTRYPOINT``'s "exec" form, not its + "shell" form. + + If you do not specify this field, then the container's + ``ENTRYPOINT`` runs, in conjunction with the + ``args`` + field or the container's + ```CMD`` `__, + if either exists. If this field is not specified and the + container does not have an ``ENTRYPOINT``, then refer to the + Docker documentation about how ``CMD`` and ``ENTRYPOINT`` + `interact `__. + + If you specify this field, then you can also specify the + ``args`` field to provide additional arguments for this + command. However, if you specify this field, then the + container's ``CMD`` is ignored. See the `Kubernetes + documentation `__ about how + the ``command`` and ``args`` fields interact with a + container's ``ENTRYPOINT`` and ``CMD``. + + In this field, you can reference environment variables `set + by AI + Platform `__ + and environment variables set in the + ``env`` + field. You cannot reference environment variables set in the + Docker image. In order for environment variables to be + expanded, reference them by using the following syntax: + $(VARIABLE_NAME) Note that this differs from Bash variable + expansion, which does not use parentheses. If a variable + cannot be resolved, the reference in the input string is + used unchanged. To avoid variable expansion, you can escape + this syntax with ``$$``; for example: $$(VARIABLE_NAME) This + field corresponds to the ``command`` field of the Kubernetes + Containers `v1 core + API `__. args (Sequence[str]): - Immutable. The arguments to the command. The Docker image's - CMD is used if this is not provided. Variable references - $(VAR_NAME) are expanded using the container's environment. - If a variable cannot be resolved, the reference in the input - string will be unchanged. 
The $(VAR_NAME) syntax can be
-            escaped with a double $$, ie: $$(VAR_NAME). Escaped
-            references will never be expanded, regardless of whether the
-            variable exists or not. More info:
-            https://tinyurl.com/y42hmlxe
+            Immutable. Specifies arguments for the command that runs
+            when the container starts. This overrides the container's
+            ```CMD`` `__.
+            Specify this field as an array of executable and arguments,
+            similar to a Docker ``CMD``'s "default parameters" form.
+
+            If you don't specify this field but do specify the
+            ``command``
+            field, then the command from the ``command`` field runs
+            without any additional arguments. See the `Kubernetes
+            documentation `__ about how
+            the ``command`` and ``args`` fields interact with a
+            container's ``ENTRYPOINT`` and ``CMD``.
+
+            If you don't specify this field and don't specify the
+            ``command`` field, then the container's
+            ```ENTRYPOINT`` `__
+            and ``CMD`` determine what runs based on their default
+            behavior. See the Docker documentation about how ``CMD`` and
+            ``ENTRYPOINT`` `interact `__.
+
+            In this field, you can reference environment variables `set
+            by AI
+            Platform `__
+            and environment variables set in the
+            ``env``
+            field. You cannot reference environment variables set in the
+            Docker image. In order for environment variables to be
+            expanded, reference them by using the following syntax:
+            $(VARIABLE_NAME) Note that this differs from Bash variable
+            expansion, which does not use parentheses. If a variable
+            cannot be resolved, the reference in the input string is
+            used unchanged. To avoid variable expansion, you can escape
+            this syntax with ``$$``; for example: $$(VARIABLE_NAME) This
+            field corresponds to the ``args`` field of the Kubernetes
+            Containers `v1 core
+            API `__.
         env (Sequence[~.env_var.EnvVar]):
-            Immutable. The environment variables that are
-            to be present in the container.
+            Immutable. List of environment variables to set in the
+            container. After the container starts running, code running
+            in the container can read these environment variables.
+
+            Additionally, the
+            ``command``
+            and
+            ``args``
+            fields can reference these variables. Later entries in this
+            list can also reference earlier entries. For example, the
+            following example sets the variable ``VAR_2`` to have the
+            value ``foo bar``:
+
+            .. code:: json
+
+                [
+                    {
+                        "name": "VAR_1",
+                        "value": "foo"
+                    },
+                    {
+                        "name": "VAR_2",
+                        "value": "$(VAR_1) bar"
+                    }
+                ]
+
+            If you switch the order of the variables in the example,
+            then the expansion does not occur.
+
+            This field corresponds to the ``env`` field of the
+            Kubernetes Containers `v1 core
+            API `__.
         ports (Sequence[~.model.Port]):
-            Immutable. Declaration of ports that are
-            exposed by the container. This field is
-            primarily informational, it gives AI Platform
-            information about the network connections the
-            container uses. Listing or not a port here has
-            no impact on whether the port is actually
-            exposed, any port listening on the default
-            "0.0.0.0" address inside a container will be
-            accessible from the network.
+            Immutable. List of ports to expose from the container. AI
+            Platform sends any prediction requests that it receives to
+            the first port on this list. AI Platform also sends
+            `liveness and health
+            checks `__ to
+            this port.
+
+            If you do not specify this field, it defaults to the
+            following value:
+
+            .. code:: json
+
+                [
+                    {
+                        "containerPort": 8080
+                    }
+                ]
+
+            AI Platform does not use ports other than the first one
+            listed.
+            This field corresponds to the ``ports`` field of the
+            Kubernetes Containers `v1 core
+            API `__.
         predict_route (str):
-            Immutable. An HTTP path to send prediction
-            requests to the container, and which must be
-            supported by it. If not specified a default HTTP
-            path will be used by AI Platform.
+            Immutable. HTTP path on the container to send prediction
+            requests to. AI Platform forwards requests sent using
+            ``projects.locations.endpoints.predict``
+            to this path on the container's IP address and port. AI
+            Platform then returns the container's response in the API
+            response.
+
+            For example, if you set this field to ``/foo``, then when AI
+            Platform receives a prediction request, it forwards the
+            request body in a POST request to the following URL on the
+            container: localhost:PORT/foo PORT refers to the first value
+            of this ``ModelContainerSpec``'s
+            ``ports``
+            field.
+
+            If you don't specify this field, it defaults to the
+            following value when you [deploy this Model to an
+            Endpoint][google.cloud.aiplatform.v1beta1.EndpointService.DeployModel]:
+            /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict
+            The placeholders in this value are replaced as follows:
+
+            -  ENDPOINT: The last segment (following ``endpoints/``) of
+               the ``Endpoint.name`` field of the Endpoint where this
+               Model has been deployed. (AI Platform makes this value
+               available to your container code as the
+               ```AIP_ENDPOINT_ID`` `__
+               environment variable.)
+
+            -  DEPLOYED_MODEL:
+               ``DeployedModel.id``
+               of the ``DeployedModel``. (AI Platform makes this value
+               available to your container code as the
+               ```AIP_DEPLOYED_MODEL_ID`` environment
+               variable `__.)
         health_route (str):
-            Immutable. An HTTP path to send health check
-            requests to the container, and which must be
-            supported by it. If not specified a standard
-            HTTP path will be used by AI Platform.
+            Immutable. HTTP path on the container to send health checks
+            to. AI Platform intermittently sends GET requests to this
+            path on the container's IP address and port to check that
+            the container is healthy. Read more about `health
+            checks `__.
+
+            For example, if you set this field to ``/bar``, then AI
+            Platform intermittently sends a GET request to the following
+            URL on the container: localhost:PORT/bar PORT refers to the
+            first value of this ``ModelContainerSpec``'s
+            ``ports``
+            field.
+
+            If you don't specify this field, it defaults to the
+            following value when you [deploy this Model to an
+            Endpoint][google.cloud.aiplatform.v1beta1.EndpointService.DeployModel]:
+            /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict
+            The placeholders in this value are replaced as follows:
+
+            -  ENDPOINT: The last segment (following ``endpoints/``) of
+               the ``Endpoint.name`` field of the Endpoint where this
+               Model has been deployed. (AI Platform makes this value
+               available to your container code as the
+               ```AIP_ENDPOINT_ID`` `__
+               environment variable.)
+
+            -  DEPLOYED_MODEL:
+               ``DeployedModel.id``
+               of the ``DeployedModel``. (AI Platform makes this value
+               available to your container code as the
+               ```AIP_DEPLOYED_MODEL_ID`` `__
+               environment variable.)
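        Pulling the fields documented above together, a hedged sketch of a
        container spec (the image URI, routes, and variable values are
        hypothetical; only fields shown in this hunk are used):

        .. code:: python

            from google.cloud.aiplatform_v1beta1.types import env_var
            from google.cloud.aiplatform_v1beta1.types import model as gca_model

            container_spec = gca_model.ModelContainerSpec(
                image_uri="gcr.io/my-project/my-serving-image:latest",  # hypothetical image
                command=["python3", "server.py"],  # overrides the image ENTRYPOINT
                args=["--port", "8080"],           # extra arguments for the command
                env=[
                    env_var.EnvVar(name="VAR_1", value="foo"),
                    env_var.EnvVar(name="VAR_2", value="$(VAR_1) bar"),  # expands to "foo bar"
                ],
                ports=[gca_model.Port(container_port=8080)],  # first port receives traffic
                predict_route="/foo",
                health_route="/bar",
            )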
""" image_uri = proto.Field(proto.STRING, number=1) + command = proto.RepeatedField(proto.STRING, number=2) + args = proto.RepeatedField(proto.STRING, number=3) + env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,) + ports = proto.RepeatedField(proto.MESSAGE, number=5, message="Port",) + predict_route = proto.Field(proto.STRING, number=6) + health_route = proto.Field(proto.STRING, number=7) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py index 13f49e963e..b768ed978e 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py @@ -68,10 +68,15 @@ class ModelEvaluation(proto.Message): """ name = proto.Field(proto.STRING, number=1) + metrics_schema_uri = proto.Field(proto.STRING, number=2) + metrics = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + slice_dimensions = proto.RepeatedField(proto.STRING, number=5) + model_explanation = proto.Field( proto.MESSAGE, number=8, message=explanation.ModelExplanation, ) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py index fe8dc19754..1039d32c1f 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py @@ -36,7 +36,7 @@ class ModelEvaluationSlice(proto.Message): name (str): Output only. The resource name of the ModelEvaluationSlice. - slice (~.model_evaluation_slice.ModelEvaluationSlice.Slice): + slice_ (~.model_evaluation_slice.ModelEvaluationSlice.Slice): Output only. The slice of the test data that is used to evaluate the Model. 
metrics_schema_uri (str): @@ -74,12 +74,17 @@ class Slice(proto.Message): """ dimension = proto.Field(proto.STRING, number=1) + value = proto.Field(proto.STRING, number=2) name = proto.Field(proto.STRING, number=1) - slice = proto.Field(proto.MESSAGE, number=2, message=Slice,) + + slice_ = proto.Field(proto.MESSAGE, number=2, message=Slice,) + metrics_schema_uri = proto.Field(proto.STRING, number=3) + metrics = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) diff --git a/google/cloud/aiplatform_v1beta1/types/model_service.py b/google/cloud/aiplatform_v1beta1/types/model_service.py index 42706490c1..3cfb17ad2c 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_service.py +++ b/google/cloud/aiplatform_v1beta1/types/model_service.py @@ -64,6 +64,7 @@ class UploadModelRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + model = proto.Field(proto.MESSAGE, number=2, message=gca_model.Model,) @@ -133,9 +134,13 @@ class ListModelsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -157,6 +162,7 @@ def raw_page(self): return self models = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_model.Model,) + next_page_token = proto.Field(proto.STRING, number=2) @@ -176,6 +182,7 @@ class UpdateModelRequest(proto.Message): """ model = proto.Field(proto.MESSAGE, number=1, message=gca_model.Model,) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -236,14 +243,17 @@ class OutputConfig(proto.Message): """ export_format_id = proto.Field(proto.STRING, number=1) + artifact_destination = proto.Field( proto.MESSAGE, number=3, message=io.GcsDestination, ) + image_destination = proto.Field( proto.MESSAGE, number=4, message=io.ContainerRegistryDestination, ) name = proto.Field(proto.STRING, number=1) + output_config = proto.Field(proto.MESSAGE, number=2, message=OutputConfig,) @@ -278,11 +288,13 @@ class OutputInfo(proto.Message): """ artifact_output_uri = proto.Field(proto.STRING, number=2) + image_output_uri = proto.Field(proto.STRING, number=3) generic_metadata = proto.Field( proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) + output_info = proto.Field(proto.MESSAGE, number=2, message=OutputInfo,) @@ -331,9 +343,13 @@ class ListModelEvaluationsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -358,6 +374,7 @@ def raw_page(self): model_evaluations = proto.RepeatedField( proto.MESSAGE, number=1, message=model_evaluation.ModelEvaluation, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -403,9 +420,13 @@ class ListModelEvaluationSlicesRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -430,6 +451,7 @@ def raw_page(self): model_evaluation_slices = proto.RepeatedField( 
proto.MESSAGE, number=1, message=model_evaluation_slice.ModelEvaluationSlice, ) + next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/operation.py b/google/cloud/aiplatform_v1beta1/types/operation.py index 3451fb9c8c..68fb0daead 100644 --- a/google/cloud/aiplatform_v1beta1/types/operation.py +++ b/google/cloud/aiplatform_v1beta1/types/operation.py @@ -43,13 +43,17 @@ class GenericOperationMetadata(proto.Message): created. update_time (~.timestamp.Timestamp): Output only. Time when the operation was - updated for the last time. + updated for the last time. If the operation has + finished (successfully or not), this is the + finish time. """ partial_failures = proto.RepeatedField( proto.MESSAGE, number=1, message=status.Status, ) + create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) + update_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) @@ -62,7 +66,7 @@ class DeleteOperationMetadata(proto.Message): """ generic_metadata = proto.Field( - proto.MESSAGE, number=1, message=GenericOperationMetadata, + proto.MESSAGE, number=1, message="GenericOperationMetadata", ) diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py index 7ba3638c51..9f0856732d 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py @@ -51,6 +51,7 @@ class CreateTrainingPipelineRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + training_pipeline = proto.Field( proto.MESSAGE, number=2, message=gca_training_pipeline.TrainingPipeline, ) @@ -108,9 +109,13 @@ class ListTrainingPipelinesRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + filter = proto.Field(proto.STRING, number=2) + page_size = proto.Field(proto.INT32, number=3) + page_token = proto.Field(proto.STRING, number=4) + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) @@ -135,6 +140,7 @@ def raw_page(self): training_pipelines = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_training_pipeline.TrainingPipeline, ) + next_page_token = proto.Field(proto.STRING, number=2) diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index e937a6d8e4..b000f88bf8 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -64,7 +64,9 @@ class PredictRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) + instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) + parameters = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) @@ -86,6 +88,7 @@ class PredictResponse(proto.Message): """ predictions = proto.RepeatedField(proto.MESSAGE, number=1, message=struct.Value,) + deployed_model_id = proto.Field(proto.STRING, number=2) @@ -124,8 +127,11 @@ class ExplainRequest(proto.Message): """ endpoint = proto.Field(proto.STRING, number=1) + instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) + parameters = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + deployed_model_id = proto.Field(proto.STRING, number=3) @@ -135,8 +141,8 @@ class ExplainResponse(proto.Message): Attributes: explanations (Sequence[~.explanation.Explanation]): - The explanations of the [Model's - 
predictions][PredictionResponse.predictions][]. + The explanations of the Model's + ``PredictResponse.predictions``. It has the same number of elements as ``instances`` @@ -144,12 +150,19 @@ class ExplainResponse(proto.Message): deployed_model_id (str): ID of the Endpoint's DeployedModel that served this explanation. + predictions (Sequence[~.struct.Value]): + The predictions that are the output of the predictions call. + Same as + ``PredictResponse.predictions``. """ explanations = proto.RepeatedField( proto.MESSAGE, number=1, message=explanation.Explanation, ) + deployed_model_id = proto.Field(proto.STRING, number=2) + predictions = proto.RepeatedField(proto.MESSAGE, number=3, message=struct.Value,) + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py index 21ab5f9c47..4ac8c6a709 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool.py @@ -55,9 +55,13 @@ class SpecialistPool(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + specialist_managers_count = proto.Field(proto.INT32, number=3) + specialist_manager_emails = proto.RepeatedField(proto.STRING, number=4) + pending_data_labeling_jobs = proto.RepeatedField(proto.STRING, number=5) diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py index 02f0dac96f..724f7165a6 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py @@ -52,6 +52,7 @@ class CreateSpecialistPoolRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + specialist_pool = proto.Field( proto.MESSAGE, number=2, message=gca_specialist_pool.SpecialistPool, ) @@ -108,8 +109,11 @@ class ListSpecialistPoolsRequest(proto.Message): """ parent = proto.Field(proto.STRING, number=1) + page_size = proto.Field(proto.INT32, number=2) + page_token = proto.Field(proto.STRING, number=3) + read_mask = proto.Field(proto.MESSAGE, number=4, message=field_mask.FieldMask,) @@ -132,6 +136,7 @@ def raw_page(self): specialist_pools = proto.RepeatedField( proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, ) + next_page_token = proto.Field(proto.STRING, number=2) @@ -152,6 +157,7 @@ class DeleteSpecialistPoolRequest(proto.Message): """ name = proto.Field(proto.STRING, number=1) + force = proto.Field(proto.BOOL, number=2) @@ -171,6 +177,7 @@ class UpdateSpecialistPoolRequest(proto.Message): specialist_pool = proto.Field( proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, ) + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) @@ -189,6 +196,7 @@ class UpdateSpecialistPoolOperationMetadata(proto.Message): """ specialist_pool = proto.Field(proto.STRING, number=1) + generic_metadata = proto.Field( proto.MESSAGE, number=2, message=operation.GenericOperationMetadata, ) diff --git a/google/cloud/aiplatform_v1beta1/types/study.py b/google/cloud/aiplatform_v1beta1/types/study.py index 55f20127ef..2d6f4ae8c3 100644 --- a/google/cloud/aiplatform_v1beta1/types/study.py +++ b/google/cloud/aiplatform_v1beta1/types/study.py @@ -81,14 +81,21 @@ class Parameter(proto.Message): """ parameter_id = proto.Field(proto.STRING, number=1) + value = proto.Field(proto.MESSAGE, number=2, 
message=struct.Value,) id = proto.Field(proto.STRING, number=2) + state = proto.Field(proto.ENUM, number=3, enum=State,) + parameters = proto.RepeatedField(proto.MESSAGE, number=4, message=Parameter,) + final_measurement = proto.Field(proto.MESSAGE, number=5, message="Measurement",) + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + custom_job = proto.Field(proto.STRING, number=11) @@ -130,6 +137,7 @@ class GoalType(proto.Enum): MINIMIZE = 2 metric_id = proto.Field(proto.STRING, number=1) + goal = proto.Field(proto.ENUM, number=2, enum="StudySpec.MetricSpec.GoalType",) class ParameterSpec(proto.Message): @@ -151,6 +159,12 @@ class ParameterSpec(proto.Message): scale_type (~.study.StudySpec.ParameterSpec.ScaleType): How the parameter should be scaled. Leave unset for ``CATEGORICAL`` parameters. + conditional_parameter_specs (Sequence[~.study.StudySpec.ParameterSpec.ConditionalParameterSpec]): + A conditional parameter node is active if the parameter's + value matches the conditional node's parent_value_condition. + + If two items in conditional_parameter_specs have the same + name, they must have disjoint parent_value_condition. """ class ScaleType(proto.Enum): @@ -173,6 +187,7 @@ class DoubleValueSpec(proto.Message): """ min_value = proto.Field(proto.DOUBLE, number=1) + max_value = proto.Field(proto.DOUBLE, number=2) class IntegerValueSpec(proto.Message): @@ -188,6 +203,7 @@ class IntegerValueSpec(proto.Message): """ min_value = proto.Field(proto.INT64, number=1) + max_value = proto.Field(proto.INT64, number=2) class CategoricalValueSpec(proto.Message): @@ -215,29 +231,135 @@ class DiscreteValueSpec(proto.Message): values = proto.RepeatedField(proto.DOUBLE, number=1) + class ConditionalParameterSpec(proto.Message): + r"""Represents a parameter spec with condition from its parent + parameter. + + Attributes: + parent_discrete_values (~.study.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition): + The spec for matching values from a parent parameter of + ``DISCRETE`` type. + parent_int_values (~.study.StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition): + The spec for matching values from a parent parameter of + ``INTEGER`` type. + parent_categorical_values (~.study.StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition): + The spec for matching values from a parent parameter of + ``CATEGORICAL`` type. + parameter_spec (~.study.StudySpec.ParameterSpec): + Required. The spec for a conditional + parameter. + """ + + class DiscreteValueCondition(proto.Message): + r"""Represents the spec to match discrete values from parent + parameter. + + Attributes: + values (Sequence[float]): + Required. Matches values of the parent parameter of + 'DISCRETE' type. All values must exist in + ``discrete_value_spec`` of parent parameter. + + The Epsilon of the value matching is 1e-10. + """ + + values = proto.RepeatedField(proto.DOUBLE, number=1) + + class IntValueCondition(proto.Message): + r"""Represents the spec to match integer values from parent + parameter. + + Attributes: + values (Sequence[int]): + Required. Matches values of the parent parameter of + 'INTEGER' type. All values must lie in + ``integer_value_spec`` of parent parameter. + """ + + values = proto.RepeatedField(proto.INT64, number=1) + + class CategoricalValueCondition(proto.Message): + r"""Represents the spec to match categorical values from parent + parameter. 
+ + Attributes: + values (Sequence[str]): + Required. Matches values of the parent parameter of + 'CATEGORICAL' type. All values must exist in + ``categorical_value_spec`` of parent parameter. + """ + + values = proto.RepeatedField(proto.STRING, number=1) + + parent_discrete_values = proto.Field( + proto.MESSAGE, + number=2, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition", + ) + + parent_int_values = proto.Field( + proto.MESSAGE, + number=3, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition", + ) + + parent_categorical_values = proto.Field( + proto.MESSAGE, + number=4, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition", + ) + + parameter_spec = proto.Field( + proto.MESSAGE, number=1, message="StudySpec.ParameterSpec", + ) + double_value_spec = proto.Field( - proto.MESSAGE, number=2, message="StudySpec.ParameterSpec.DoubleValueSpec", + proto.MESSAGE, + number=2, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.DoubleValueSpec", ) + integer_value_spec = proto.Field( - proto.MESSAGE, number=3, message="StudySpec.ParameterSpec.IntegerValueSpec", + proto.MESSAGE, + number=3, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.IntegerValueSpec", ) + categorical_value_spec = proto.Field( proto.MESSAGE, number=4, + oneof="parameter_value_spec", message="StudySpec.ParameterSpec.CategoricalValueSpec", ) + discrete_value_spec = proto.Field( proto.MESSAGE, number=5, + oneof="parameter_value_spec", message="StudySpec.ParameterSpec.DiscreteValueSpec", ) + parameter_id = proto.Field(proto.STRING, number=1) + scale_type = proto.Field( proto.ENUM, number=6, enum="StudySpec.ParameterSpec.ScaleType", ) + conditional_parameter_specs = proto.RepeatedField( + proto.MESSAGE, + number=10, + message="StudySpec.ParameterSpec.ConditionalParameterSpec", + ) + metrics = proto.RepeatedField(proto.MESSAGE, number=1, message=MetricSpec,) + parameters = proto.RepeatedField(proto.MESSAGE, number=2, message=ParameterSpec,) + algorithm = proto.Field(proto.ENUM, number=3, enum=Algorithm,) @@ -270,9 +392,11 @@ class Metric(proto.Message): """ metric_id = proto.Field(proto.STRING, number=1) + value = proto.Field(proto.DOUBLE, number=2) step_count = proto.Field(proto.INT64, number=2) + metrics = proto.RepeatedField(proto.MESSAGE, number=3, message=Metric,) diff --git a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py index bb32b7b787..f1f0debaf9 100644 --- a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py @@ -143,18 +143,31 @@ class TrainingPipeline(proto.Message): """ name = proto.Field(proto.STRING, number=1) + display_name = proto.Field(proto.STRING, number=2) + input_data_config = proto.Field(proto.MESSAGE, number=3, message="InputDataConfig",) + training_task_definition = proto.Field(proto.STRING, number=4) + training_task_inputs = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) + training_task_metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) + model_to_upload = proto.Field(proto.MESSAGE, number=7, message=model.Model,) + state = proto.Field(proto.ENUM, number=9, enum=pipeline_state.PipelineState,) + error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) + create_time = proto.Field(proto.MESSAGE, 
number=11, message=timestamp.Timestamp,)
+
    start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,)
+
    end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,)
+
    update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,)
+
    labels = proto.MapField(proto.STRING, proto.STRING, number=15)

@@ -177,17 +190,53 @@ class InputDataConfig(proto.Message):
             Split based on the timestamp of the input
             data pieces.
         gcs_destination (~.io.GcsDestination):
-            The Google Cloud Storage location.
+            The Google Cloud Storage location where the training data is
+            to be written. In the given directory a new directory
+            will be created with name:
+            ``dataset-<dataset-id>-<annotation-type>-<timestamp>``
+            where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601
+            format. All training input data will be written into that
+            directory.

             The AI Platform environment variables representing Google
             Cloud Storage data URIs will always be represented in the
             Google Cloud Storage wildcard format to support sharded
-            data. e.g.: "gs://.../training-\*
+            data. e.g.: "gs://.../training-*.jsonl"
+
+            -  AIP_DATA_FORMAT = "jsonl" for non-tabular data, "csv" for
+               tabular data
+            -  AIP_TRAINING_DATA_URI =
+
+               "gcs_destination/dataset-<dataset-id>-<annotation-type>-<timestamp>/training-*.${AIP_DATA_FORMAT}"
+
+            -  AIP_VALIDATION_DATA_URI =
+
+               "gcs_destination/dataset-<dataset-id>-<annotation-type>-<timestamp>/validation-*.${AIP_DATA_FORMAT}"
+
+            -  AIP_TEST_DATA_URI =
+
+               "gcs_destination/dataset-<dataset-id>-<annotation-type>-<timestamp>/test-*.${AIP_DATA_FORMAT}".
+        bigquery_destination (~.io.BigQueryDestination):
+            The BigQuery project location where the training data is to
+            be written. In the given project a new dataset is created
+            with name
+            ``dataset_<dataset-id>_<annotation-type>_<timestamp>``
+            where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All
+            training input data will be written into that dataset. In
+            the dataset three tables will be created, ``training``,
+            ``validation`` and ``test``.

-            -  AIP_DATA_FORMAT = "jsonl".
-            -  AIP_TRAINING_DATA_URI = "gcs_destination/training-*"
-            -  AIP_VALIDATION_DATA_URI = "gcs_destination/validation-*"
-            -  AIP_TEST_DATA_URI = "gcs_destination/test-*".
+            -  AIP_DATA_FORMAT = "bigquery".
+            -  AIP_TRAINING_DATA_URI =
+
+               "bigquery_destination.dataset_<dataset-id>_<annotation-type>_<timestamp>.training"
+
+            -  AIP_VALIDATION_DATA_URI =
+
+               "bigquery_destination.dataset_<dataset-id>_<annotation-type>_<timestamp>.validation"
+
+            -  AIP_TEST_DATA_URI =
+               "bigquery_destination.dataset_<dataset-id>_<annotation-type>_<timestamp>.test".
         dataset_id (str):
             Required. The ID of the Dataset in the same Project and
             Location whose data will be used to train the Model. The
@@ -237,13 +286,34 @@ class InputDataConfig(proto.Message):
             ``annotation_schema_uri``.
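        A sketch of an ``InputDataConfig`` that picks one member from each of
        the oneofs described above, a fraction split plus a Cloud Storage
        destination (illustrative only; the bucket and Dataset ID are
        hypothetical):

        .. code:: python

            from google.cloud.aiplatform_v1beta1.types import io, training_pipeline

            input_config = training_pipeline.InputDataConfig(
                dataset_id="1234567890",  # hypothetical Dataset ID
                # "split" oneof: e.g., an 80/10/10 fraction split.
                fraction_split=training_pipeline.FractionSplit(
                    training_fraction=0.8,
                    validation_fraction=0.1,
                    test_fraction=0.1,
                ),
                # "destination" oneof: where the prepared training data is written.
                gcs_destination=io.GcsDestination(
                    output_uri_prefix="gs://my-bucket/training/",  # hypothetical bucket
                ),
            )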
""" - fraction_split = proto.Field(proto.MESSAGE, number=2, message="FractionSplit",) - filter_split = proto.Field(proto.MESSAGE, number=3, message="FilterSplit",) - predefined_split = proto.Field(proto.MESSAGE, number=4, message="PredefinedSplit",) - timestamp_split = proto.Field(proto.MESSAGE, number=5, message="TimestampSplit",) - gcs_destination = proto.Field(proto.MESSAGE, number=8, message=io.GcsDestination,) + fraction_split = proto.Field( + proto.MESSAGE, number=2, oneof="split", message="FractionSplit", + ) + + filter_split = proto.Field( + proto.MESSAGE, number=3, oneof="split", message="FilterSplit", + ) + + predefined_split = proto.Field( + proto.MESSAGE, number=4, oneof="split", message="PredefinedSplit", + ) + + timestamp_split = proto.Field( + proto.MESSAGE, number=5, oneof="split", message="TimestampSplit", + ) + + gcs_destination = proto.Field( + proto.MESSAGE, number=8, oneof="destination", message=io.GcsDestination, + ) + + bigquery_destination = proto.Field( + proto.MESSAGE, number=10, oneof="destination", message=io.BigQueryDestination, + ) + dataset_id = proto.Field(proto.STRING, number=1) + annotations_filter = proto.Field(proto.STRING, number=6) + annotation_schema_uri = proto.Field(proto.STRING, number=9) @@ -269,7 +339,9 @@ class FractionSplit(proto.Message): """ training_fraction = proto.Field(proto.DOUBLE, number=1) + validation_fraction = proto.Field(proto.DOUBLE, number=2) + test_fraction = proto.Field(proto.DOUBLE, number=3) @@ -312,7 +384,9 @@ class FilterSplit(proto.Message): """ training_filter = proto.Field(proto.STRING, number=1) + validation_filter = proto.Field(proto.STRING, number=2) + test_filter = proto.Field(proto.STRING, number=3) @@ -363,8 +437,11 @@ class TimestampSplit(proto.Message): """ training_fraction = proto.Field(proto.DOUBLE, number=1) + validation_fraction = proto.Field(proto.DOUBLE, number=2) + test_fraction = proto.Field(proto.DOUBLE, number=3) + key = proto.Field(proto.STRING, number=4) diff --git a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py index ce868edc27..710e4a6d16 100644 --- a/google/cloud/aiplatform_v1beta1/types/user_action_reference.py +++ b/google/cloud/aiplatform_v1beta1/types/user_action_reference.py @@ -44,8 +44,10 @@ class UserActionReference(proto.Message): "/google.cloud.aiplatform.v1alpha1.DatasetService.CreateDataset". """ - operation = proto.Field(proto.STRING, number=1) - data_labeling_job = proto.Field(proto.STRING, number=2) + operation = proto.Field(proto.STRING, number=1, oneof="reference") + + data_labeling_job = proto.Field(proto.STRING, number=2, oneof="reference") + method = proto.Field(proto.STRING, number=3) diff --git a/mypy.ini b/mypy.ini index f23e6b533a..4505b48543 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,3 @@ [mypy] -python_version = 3.5 +python_version = 3.6 namespace_packages = True diff --git a/noxfile.py b/noxfile.py index 615e2c6793..1797beebfd 100644 --- a/noxfile.py +++ b/noxfile.py @@ -72,7 +72,9 @@ def default(session): # Install all test dependencies, then install this package in-place. session.install("asyncmock", "pytest-asyncio") - session.install("mock", "pytest", "pytest-cov") + session.install( + "mock", "pytest", "pytest-cov", + ) session.install("-e", ".") # Run py.test against the unit tests. 
diff --git a/scripts/fixup_aiplatform_v1beta1_keywords.py b/scripts/fixup_aiplatform_v1beta1_keywords.py deleted file mode 100644 index 26a51a02c1..0000000000 --- a/scripts/fixup_aiplatform_v1beta1_keywords.py +++ /dev/null @@ -1,236 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import argparse -import os -import libcst as cst -import pathlib -import sys -from typing import (Any, Callable, Dict, List, Sequence, Tuple) - - -def partition( - predicate: Callable[[Any], bool], - iterator: Sequence[Any] -) -> Tuple[List[Any], List[Any]]: - """A stable, out-of-place partition.""" - results = ([], []) - - for i in iterator: - results[int(predicate(i))].append(i) - - # Returns trueList, falseList - return results[1], results[0] - - -class aiplatformCallTransformer(cst.CSTTransformer): - CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') - METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { - 'cancel_batch_prediction_job': ('name', ), - 'cancel_custom_job': ('name', ), - 'cancel_data_labeling_job': ('name', ), - 'cancel_hyperparameter_tuning_job': ('name', ), - 'cancel_training_pipeline': ('name', ), - 'create_batch_prediction_job': ('parent', 'batch_prediction_job', ), - 'create_custom_job': ('parent', 'custom_job', ), - 'create_data_labeling_job': ('parent', 'data_labeling_job', ), - 'create_dataset': ('parent', 'dataset', ), - 'create_endpoint': ('parent', 'endpoint', ), - 'create_hyperparameter_tuning_job': ('parent', 'hyperparameter_tuning_job', ), - 'create_specialist_pool': ('parent', 'specialist_pool', ), - 'create_training_pipeline': ('parent', 'training_pipeline', ), - 'delete_batch_prediction_job': ('name', ), - 'delete_custom_job': ('name', ), - 'delete_data_labeling_job': ('name', ), - 'delete_dataset': ('name', ), - 'delete_endpoint': ('name', ), - 'delete_hyperparameter_tuning_job': ('name', ), - 'delete_model': ('name', ), - 'delete_specialist_pool': ('name', 'force', ), - 'delete_training_pipeline': ('name', ), - 'deploy_model': ('endpoint', 'deployed_model', 'traffic_split', ), - 'explain': ('endpoint', 'instances', 'parameters', 'deployed_model_id', ), - 'export_data': ('name', 'export_config', ), - 'export_model': ('name', 'output_config', ), - 'get_annotation_spec': ('name', 'read_mask', ), - 'get_batch_prediction_job': ('name', ), - 'get_custom_job': ('name', ), - 'get_data_labeling_job': ('name', ), - 'get_dataset': ('name', 'read_mask', ), - 'get_endpoint': ('name', ), - 'get_hyperparameter_tuning_job': ('name', ), - 'get_model': ('name', ), - 'get_model_evaluation': ('name', ), - 'get_model_evaluation_slice': ('name', ), - 'get_specialist_pool': ('name', ), - 'get_training_pipeline': ('name', ), - 'import_data': ('name', 'import_configs', ), - 'list_annotations': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', 'order_by', ), - 'list_batch_prediction_jobs': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_custom_jobs': ('parent', 'filter', 'page_size', 'page_token', 
'read_mask', ), - 'list_data_items': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', 'order_by', ), - 'list_data_labeling_jobs': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', 'order_by', ), - 'list_datasets': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', 'order_by', ), - 'list_endpoints': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_hyperparameter_tuning_jobs': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_model_evaluations': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_model_evaluation_slices': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_models': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'list_specialist_pools': ('parent', 'page_size', 'page_token', 'read_mask', ), - 'list_training_pipelines': ('parent', 'filter', 'page_size', 'page_token', 'read_mask', ), - 'predict': ('endpoint', 'instances', 'parameters', ), - 'undeploy_model': ('endpoint', 'deployed_model_id', 'traffic_split', ), - 'update_dataset': ('dataset', 'update_mask', ), - 'update_endpoint': ('endpoint', 'update_mask', ), - 'update_model': ('model', 'update_mask', ), - 'update_specialist_pool': ('specialist_pool', 'update_mask', ), - 'upload_model': ('parent', 'model', ), - - } - - def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: - try: - key = original.func.attr.value - kword_params = self.METHOD_TO_PARAMS[key] - except (AttributeError, KeyError): - # Either not a method from the API or too convoluted to be sure. - return updated - - # If the existing code is valid, keyword args come after positional args. - # Therefore, all positional args must map to the first parameters. - args, kwargs = partition(lambda a: not bool(a.keyword), updated.args) - if any(k.keyword.value == "request" for k in kwargs): - # We've already fixed this file, don't fix it again. - return updated - - kwargs, ctrl_kwargs = partition( - lambda a: not a.keyword.value in self.CTRL_PARAMS, - kwargs - ) - - args, ctrl_args = args[:len(kword_params)], args[len(kword_params):] - ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl)) - for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS)) - - request_arg = cst.Arg( - value=cst.Dict([ - cst.DictElement( - cst.SimpleString("'{}'".format(name)), - cst.Element(value=arg.value) - ) - # Note: the args + kwargs looks silly, but keep in mind that - # the control parameters had to be stripped out, and that - # those could have been passed positionally or by keyword. - for name, arg in zip(kword_params, args + kwargs)]), - keyword=cst.Name("request") - ) - - return updated.with_changes( - args=[request_arg] + ctrl_kwargs - ) - - -def fix_files( - in_dir: pathlib.Path, - out_dir: pathlib.Path, - *, - transformer=aiplatformCallTransformer(), -): - """Duplicate the input dir to the output dir, fixing file method calls. - - Preconditions: - * in_dir is a real directory - * out_dir is a real, empty directory - """ - pyfile_gen = ( - pathlib.Path(os.path.join(root, f)) - for root, _, files in os.walk(in_dir) - for f in files if os.path.splitext(f)[1] == ".py" - ) - - for fpath in pyfile_gen: - with open(fpath, 'r') as f: - src = f.read() - - # Parse the code and insert method call fixes. - tree = cst.parse_module(src) - updated = tree.visit(transformer) - - # Create the path and directory structure for the new file. 
- updated_path = out_dir.joinpath(fpath.relative_to(in_dir)) - updated_path.parent.mkdir(parents=True, exist_ok=True) - - # Generate the updated source file at the corresponding path. - with open(updated_path, 'w') as f: - f.write(updated.code) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser( - description="""Fix up source that uses the aiplatform client library. - -The existing sources are NOT overwritten but are copied to output_dir with changes made. - -Note: This tool operates at a best-effort level at converting positional - parameters in client method calls to keyword based parameters. - Cases where it WILL FAIL include - A) * or ** expansion in a method call. - B) Calls via function or method alias (includes free function calls) - C) Indirect or dispatched calls (e.g. the method is looked up dynamically) - - These all constitute false negatives. The tool will also detect false - positives when an API method shares a name with another method. -""") - parser.add_argument( - '-d', - '--input-directory', - required=True, - dest='input_dir', - help='the input directory to walk for python files to fix up', - ) - parser.add_argument( - '-o', - '--output-directory', - required=True, - dest='output_dir', - help='the directory to output files fixed via un-flattening', - ) - args = parser.parse_args() - input_dir = pathlib.Path(args.input_dir) - output_dir = pathlib.Path(args.output_dir) - if not input_dir.is_dir(): - print( - f"input directory '{input_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if not output_dir.is_dir(): - print( - f"output directory '{output_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if os.listdir(output_dir): - print( - f"output directory '{output_dir}' is not empty", - file=sys.stderr, - ) - sys.exit(-1) - - fix_files(input_dir, output_dir) diff --git a/setup.py b/setup.py index 7cf86480e5..1b0066a279 100644 --- a/setup.py +++ b/setup.py @@ -15,37 +15,29 @@ # limitations under the License. 
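
[Editor's note] For context on the scripts/fixup_aiplatform_v1beta1_keywords.py deletion above: the script used libcst to rewrite call sites from the old flattened positional style into the explicit request-object style (synth.py now excludes the generated copy instead of deleting it by hand). A rough sketch of its effect, assuming the deleted aiplatformCallTransformer class is in scope; the sample source line is hypothetical.

import libcst as cst

# Hypothetical call site in the old flattened style.
src = "client.create_dataset(parent, dataset, retry=retry)\n"

tree = cst.parse_module(src)
updated = tree.visit(aiplatformCallTransformer())

# leave_Call folds the leading positional args into a single `request`
# dict keyed by METHOD_TO_PARAMS; control parameters (retry, timeout,
# metadata) stay as keyword arguments, so this prints roughly:
#   client.create_dataset(request={'parent': parent, 'dataset': dataset}, retry=retry)
print(updated.code)
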
# -import io -import os import setuptools # type: ignore -version = "0.2.0" - -package_root = os.path.abspath(os.path.dirname(__file__)) - -readme_filename = os.path.join(package_root, "README.rst") -with io.open(readme_filename, encoding="utf-8") as readme_file: - readme = readme_file.read() - setuptools.setup( name="google-cloud-aiplatform", - version=version, - long_description=readme, + version="0.3.0", + packages=setuptools.PEP420PackageFinder.find(), + namespace_packages=("google", "google.cloud"), author="Google LLC", author_email="googleapis-packages@google.com", license="Apache 2.0", - url="https://github.com/googleapis/python-documentai", - packages=setuptools.PEP420PackageFinder.find(), - namespace_packages=("google", "google.cloud"), + url="https://github.com/googleapis/python-aiplatform", platforms="Posix; MacOS X; Windows", include_package_data=True, install_requires=( "google-api-core[grpc] >= 1.22.2, < 2.0.0dev", "libcst >= 0.2.5", - "proto-plus >= 1.4.0", + "proto-plus >= 1.10.1", + "mock >= 4.0.2", + "google-cloud-storage >= 1.26.0, < 2.0.0dev", ), python_requires=">=3.6", + scripts=[], classifiers=[ "Development Status :: 4 - Beta", "Intended Audience :: Developers", diff --git a/synth.metadata b/synth.metadata index 2f581d2593..9399d8c2e3 100644 --- a/synth.metadata +++ b/synth.metadata @@ -4,14 +4,14 @@ "git": { "name": ".", "remote": "https://github.com/dizcology/python-aiplatform.git", - "sha": "7e83ff65457e88aa155e68ddd959933a68da46af" + "sha": "81da030c0af8902fd54c8e7b5e92255a532d0efb" } }, { "git": { "name": "synthtool", "remote": "https://github.com/googleapis/synthtool.git", - "sha": "487eba79f8260e34205d8ceb1ebcc65685085e19" + "sha": "ba9918cd22874245b55734f57470c719b577e591" } } ], diff --git a/synth.py b/synth.py index 31c3e11493..935bf20fce 100644 --- a/synth.py +++ b/synth.py @@ -35,17 +35,17 @@ # version="v1beta1", # bazel_target="//google/cloud/aiplatform/v1beta1:aiplatform-v1beta1-py", # ) -library = gapic.py_library("aiplatform", "v1beta1", generator_version="0.20") +library = gapic.py_library("aiplatform", "v1beta1") s.move( library, excludes=[ - ".kokoro", "setup.py", "README.rst", "docs/index.rst", + "scripts/fixup_aiplatform_v1beta1_keywords.py", "google/cloud/aiplatform/__init__.py", - "tests/unit/aiplatform_v1beta1/test_prediction_service.py", + "tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py", ], ) @@ -53,9 +53,6 @@ # Patch the library # ---------------------------------------------------------------------------- -# https://github.com/googleapis/gapic-generator-python/issues/336 -s.replace("**/client.py", " operation.from_gapic", " ga_operation.from_gapic") - s.replace( "**/client.py", "client_options: ClientOptions = ", @@ -69,6 +66,13 @@ "request.instances.extend(instances)", ) +# https://github.com/googleapis/gapic-generator-python/issues/672 +s.replace( + "google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py", + "request.traffic_split.extend\(traffic_split\)", + "request.traffic_split = traffic_split", +) + # post processing to fix the generated reference doc from synthtool import transforms as st import re @@ -125,7 +129,11 @@ templated_files = common.py_library(cov_level=99, microgenerator=True) s.move( - templated_files, excludes=[".coveragerc"] + templated_files, + excludes=[ + ".coveragerc", + ".kokoro/samples/**" + ] ) # the microgenerator has a good coveragerc file # Don't treat docs warnings as errors diff --git a/tests/unit/gapic/aiplatform_v1beta1/__init__.py 
b/tests/unit/gapic/aiplatform_v1beta1/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py new file mode 100644 index 0000000000..51022d9fb7 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -0,0 +1,3498 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + DatasetServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.dataset_service import ( + DatasetServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers +from google.cloud.aiplatform_v1beta1.services.dataset_service import transports +from google.cloud.aiplatform_v1beta1.types import annotation +from google.cloud.aiplatform_v1beta1.types import annotation_spec +from google.cloud.aiplatform_v1beta1.types import data_item +from google.cloud.aiplatform_v1beta1.types import dataset +from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1beta1.types import dataset_service +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
+def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert DatasetServiceClient._get_default_mtls_endpoint(None) is None + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient] +) +def test_dataset_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_dataset_service_client_get_transport_class(): + transport = DatasetServiceClient.get_transport_class() + assert transport == transports.DatasetServiceGrpcTransport + + transport = DatasetServiceClient.get_transport_class("grpc") + assert transport == transports.DatasetServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) +def test_dataset_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_dataset_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. 
Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_dataset_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = DatasetServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_dataset( + transport: str = "grpc", request_type=dataset_service.CreateDatasetRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+ request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.CreateDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_dataset_from_dict(): + test_create_dataset(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.CreateDatasetRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.CreateDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_dataset_async_from_dict(): + await test_create_dataset_async(request_type=dict) + + +def test_create_dataset_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.CreateDatasetRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_dataset_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.CreateDatasetRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_dataset_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_dataset( + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].dataset == gca_dataset.Dataset(name="name_value") + + +def test_create_dataset_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_dataset( + dataset_service.CreateDatasetRequest(), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_dataset_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_dataset( + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].dataset == gca_dataset.Dataset(name="name_value") + + +@pytest.mark.asyncio +async def test_create_dataset_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.create_dataset( + dataset_service.CreateDatasetRequest(), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), + ) + + +def test_get_dataset( + transport: str = "grpc", request_type=dataset_service.GetDatasetRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + + response = client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetDatasetRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, dataset.Dataset) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.etag == "etag_value" + + +def test_get_dataset_from_dict(): + test_get_dataset(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.GetDatasetRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) + + response = await client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, dataset.Dataset) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_get_dataset_async_from_dict(): + await test_get_dataset_async(request_type=dict) + + +def test_get_dataset_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetDatasetRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + call.return_value = dataset.Dataset() + + client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_dataset_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetDatasetRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) + + await client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_dataset_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset.Dataset() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_dataset(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_dataset_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_dataset( + dataset_service.GetDatasetRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_dataset_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset.Dataset() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_dataset(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_dataset_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_dataset( + dataset_service.GetDatasetRequest(), name="name_value", + ) + + +def test_update_dataset( + transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + + response = client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.UpdateDatasetRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_dataset.Dataset) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.etag == "etag_value" + + +def test_update_dataset_from_dict(): + test_update_dataset(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.UpdateDatasetRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) + + response = await client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.UpdateDatasetRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, gca_dataset.Dataset) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_update_dataset_async_from_dict(): + await test_update_dataset_async(request_type=dict) + + +def test_update_dataset_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.UpdateDatasetRequest() + request.dataset.name = "dataset.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + call.return_value = gca_dataset.Dataset() + + client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] + + +@pytest.mark.asyncio +async def test_update_dataset_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.UpdateDatasetRequest() + request.dataset.name = "dataset.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) + + await client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] + + +def test_update_dataset_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_dataset.Dataset() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_dataset( + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].dataset == gca_dataset.Dataset(name="name_value") + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_dataset_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.update_dataset( + dataset_service.UpdateDatasetRequest(), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_dataset_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_dataset.Dataset() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_dataset( + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].dataset == gca_dataset.Dataset(name="name_value") + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_dataset_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_dataset( + dataset_service.UpdateDatasetRequest(), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_list_datasets( + transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDatasetsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDatasetsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListDatasetsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_datasets_from_dict(): + test_list_datasets(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_datasets_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListDatasetsRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDatasetsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDatasetsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_datasets_async_from_dict(): + await test_list_datasets_async(request_type=dict) + + +def test_list_datasets_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDatasetsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value = dataset_service.ListDatasetsResponse() + + client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_datasets_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDatasetsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) + + await client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_datasets_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDatasetsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_datasets(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_datasets_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_datasets( + dataset_service.ListDatasetsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_datasets_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDatasetsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_datasets(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_datasets_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_datasets( + dataset_service.ListDatasetsRequest(), parent="parent_value", + ) + + +def test_list_datasets_pager(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", + ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(),], next_page_token="ghi", + ), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(),], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_datasets(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, dataset.Dataset) for i in results) + + +def test_list_datasets_pages(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", + ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(),], next_page_token="ghi", + ), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(),], + ), + RuntimeError, + ) + pages = list(client.list_datasets(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_datasets_async_pager(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", + ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(),], next_page_token="ghi", + ), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(),], + ), + RuntimeError, + ) + async_pager = await client.list_datasets(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, dataset.Dataset) for i in responses) + + +@pytest.mark.asyncio +async def test_list_datasets_async_pages(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", + ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(),], next_page_token="ghi", + ), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(),], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_datasets(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_dataset( + transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + # Designate an appropriate return value for the call. 
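The pager tests work because the generated pagers.ListDatasetsPager keeps re-invoking the transport while a next_page_token is present, so a four-page side_effect yields six datasets and the trailing RuntimeError sentinel is never consumed. A simplified sketch of that loop; the attribute names mirror the generated pagers but are an assumption here:

    class SketchPager:
        """Simplified stand-in for the generated ListDatasetsPager."""

        def __init__(self, method, request, response):
            self._method = method      # bound transport call
            self._request = request    # ListDatasetsRequest
            self._response = response  # first ListDatasetsResponse

        @property
        def pages(self):
            yield self._response
            while self._response.next_page_token:
                self._request.page_token = self._response.next_page_token
                self._response = self._method(self._request)
                yield self._response

        def __iter__(self):
            for page in self.pages:
                yield from page.datasets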
+ call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.DeleteDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_dataset_from_dict(): + test_delete_dataset(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.DeleteDatasetRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.DeleteDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_dataset_async_from_dict(): + await test_delete_dataset_async(request_type=dict) + + +def test_delete_dataset_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.DeleteDatasetRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_dataset_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.DeleteDatasetRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
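delete_dataset and the other LRO methods return a google.api_core future wrapped around the raw operations_pb2.Operation, which is what the isinstance(response, future.Future) assertions verify. Hypothetical unmocked usage (the resource name is a placeholder):

    operation = client.delete_dataset(
        name="projects/p/locations/us-central1/datasets/123"  # placeholder
    )
    # Blocks until the server reports the operation done; DeleteDataset
    # resolves to google.protobuf.Empty, so the result value is unused.
    operation.result(timeout=300)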
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_delete_dataset_flattened():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.delete_dataset(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_delete_dataset_flattened_error():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.delete_dataset(
+            dataset_service.DeleteDatasetRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_dataset_flattened_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_dataset(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_delete_dataset_flattened_error_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_dataset(
+            dataset_service.DeleteDatasetRequest(), name="name_value",
+        )
+
+
+def test_import_data(
+    transport: str = "grpc", request_type=dataset_service.ImportDataRequest
+):
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.import_data), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.import_data(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == dataset_service.ImportDataRequest()
+
+    # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future) + + +def test_import_data_from_dict(): + test_import_data(request_type=dict) + + +@pytest.mark.asyncio +async def test_import_data_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ImportDataRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.import_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ImportDataRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_import_data_async_from_dict(): + await test_import_data_async(request_type=dict) + + +def test_import_data_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ImportDataRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.import_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_import_data_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ImportDataRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.import_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_import_data_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + # Designate an appropriate return value for the call. 
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.import_data(
+            name="name_value",
+            import_configs=[
+                dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"]))
+            ],
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+        assert args[0].import_configs == [
+            dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"]))
+        ]
+
+
+def test_import_data_flattened_error():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.import_data(
+            dataset_service.ImportDataRequest(),
+            name="name_value",
+            import_configs=[
+                dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"]))
+            ],
+        )
+
+
+@pytest.mark.asyncio
+async def test_import_data_flattened_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.import_data), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.import_data(
+            name="name_value",
+            import_configs=[
+                dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"]))
+            ],
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+        assert args[0].import_configs == [
+            dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"]))
+        ]
+
+
+@pytest.mark.asyncio
+async def test_import_data_flattened_error_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.import_data(
+            dataset_service.ImportDataRequest(),
+            name="name_value",
+            import_configs=[
+                dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"]))
+            ],
+        )
+
+
+def test_export_data(
+    transport: str = "grpc", request_type=dataset_service.ExportDataRequest
+):
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_data), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.export_data(request)
+
+        # Establish that the underlying gRPC stub method was called.
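The flattened-call assertions can compare freshly constructed messages against the captured request because proto-plus messages compare by field value, not identity. For example (reusing this module's imports):

    a = dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"]))
    b = dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"]))

    assert a == b       # equal field values compare equal
    assert a is not b   # despite being distinct objects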
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ExportDataRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_export_data_from_dict(): + test_export_data(request_type=dict) + + +@pytest.mark.asyncio +async def test_export_data_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.export_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ExportDataRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_export_data_async_from_dict(): + await test_export_data_async(request_type=dict) + + +def test_export_data_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ExportDataRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.export_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_export_data_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ExportDataRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.export_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_export_data_flattened():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_data), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.export_data(
+            name="name_value",
+            export_config=dataset.ExportDataConfig(
+                gcs_destination=io.GcsDestination(
+                    output_uri_prefix="output_uri_prefix_value"
+                )
+            ),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+        assert args[0].export_config == dataset.ExportDataConfig(
+            gcs_destination=io.GcsDestination(
+                output_uri_prefix="output_uri_prefix_value"
+            )
+        )
+
+
+def test_export_data_flattened_error():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.export_data(
+            dataset_service.ExportDataRequest(),
+            name="name_value",
+            export_config=dataset.ExportDataConfig(
+                gcs_destination=io.GcsDestination(
+                    output_uri_prefix="output_uri_prefix_value"
+                )
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_export_data_flattened_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_data), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.export_data(
+            name="name_value",
+            export_config=dataset.ExportDataConfig(
+                gcs_destination=io.GcsDestination(
+                    output_uri_prefix="output_uri_prefix_value"
+                )
+            ),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+        assert args[0].export_config == dataset.ExportDataConfig(
+            gcs_destination=io.GcsDestination(
+                output_uri_prefix="output_uri_prefix_value"
+            )
+        )
+
+
+@pytest.mark.asyncio
+async def test_export_data_flattened_error_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+ with pytest.raises(ValueError): + await client.export_data( + dataset_service.ExportDataRequest(), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), + ) + + +def test_list_data_items( + transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDataItemsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDataItemsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListDataItemsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_data_items_from_dict(): + test_list_data_items(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_data_items_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDataItemsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDataItemsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_data_items_async_from_dict(): + await test_list_data_items_async(request_type=dict) + + +def test_list_data_items_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDataItemsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + call.return_value = dataset_service.ListDataItemsResponse() + + client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. 
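The recurring "everything is optional in proto3" comment is what lets each test start from a bare request_type(): unset scalar fields read back as zero values rather than raising. For example:

    request = dataset_service.ListDataItemsRequest()

    assert request.parent == ""     # unset strings read as ""
    assert request.page_size == 0   # unset ints read as 0
    assert request.page_token == ""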
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_list_data_items_field_headers_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = dataset_service.ListDataItemsRequest()
+    request.parent = "parent/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_data_items), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            dataset_service.ListDataItemsResponse()
+        )
+
+        await client.list_data_items(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_data_items_flattened():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_data_items), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = dataset_service.ListDataItemsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_data_items(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_data_items_flattened_error():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_data_items(
+            dataset_service.ListDataItemsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_data_items_flattened_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_data_items), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            dataset_service.ListDataItemsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_data_items(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_data_items_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_data_items( + dataset_service.ListDataItemsRequest(), parent="parent_value", + ) + + +def test_list_data_items_pager(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + data_item.DataItem(), + ], + next_page_token="abc", + ), + dataset_service.ListDataItemsResponse( + data_items=[], next_page_token="def", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(),], next_page_token="ghi", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(), data_item.DataItem(),], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_data_items(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, data_item.DataItem) for i in results) + + +def test_list_data_items_pages(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + data_item.DataItem(), + ], + next_page_token="abc", + ), + dataset_service.ListDataItemsResponse( + data_items=[], next_page_token="def", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(),], next_page_token="ghi", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(), data_item.DataItem(),], + ), + RuntimeError, + ) + pages = list(client.list_data_items(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_data_items_async_pager(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. 
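The new_callable=mock.AsyncMock in the async pager tests matters: the async client awaits the transport's __call__, so the patch must return awaitables; a plain MagicMock would fail at the await. A self-contained illustration (assumes Python 3.8+, where AsyncMock lives in unittest.mock):

    import asyncio
    from unittest import mock

    async def calls_transport(fn):
        return await fn()  # mirrors the async client awaiting __call__

    # AsyncMock() produces a coroutine when called, so this returns 42:
    assert asyncio.run(calls_transport(mock.AsyncMock(return_value=42))) == 42

    # A MagicMock returns a plain value, and the await raises TypeError:
    # asyncio.run(calls_transport(mock.MagicMock(return_value=42)))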
+ call.side_effect = ( + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + data_item.DataItem(), + ], + next_page_token="abc", + ), + dataset_service.ListDataItemsResponse( + data_items=[], next_page_token="def", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(),], next_page_token="ghi", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(), data_item.DataItem(),], + ), + RuntimeError, + ) + async_pager = await client.list_data_items(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, data_item.DataItem) for i in responses) + + +@pytest.mark.asyncio +async def test_list_data_items_async_pages(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDataItemsResponse( + data_items=[ + data_item.DataItem(), + data_item.DataItem(), + data_item.DataItem(), + ], + next_page_token="abc", + ), + dataset_service.ListDataItemsResponse( + data_items=[], next_page_token="def", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(),], next_page_token="ghi", + ), + dataset_service.ListDataItemsResponse( + data_items=[data_item.DataItem(), data_item.DataItem(),], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_data_items(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_get_annotation_spec( + transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = annotation_spec.AnnotationSpec( + name="name_value", display_name="display_name_value", etag="etag_value", + ) + + response = client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetAnnotationSpecRequest() + + # Establish that the response is the type that we expect. 
+ + assert isinstance(response, annotation_spec.AnnotationSpec) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.etag == "etag_value" + + +def test_get_annotation_spec_from_dict(): + test_get_annotation_spec(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_annotation_spec_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.GetAnnotationSpecRequest, +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec( + name="name_value", display_name="display_name_value", etag="etag_value", + ) + ) + + response = await client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetAnnotationSpecRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, annotation_spec.AnnotationSpec) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_get_annotation_spec_async_from_dict(): + await test_get_annotation_spec_async(request_type=dict) + + +def test_get_annotation_spec_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetAnnotationSpecRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + call.return_value = annotation_spec.AnnotationSpec() + + client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_annotation_spec_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetAnnotationSpecRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) + + await client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. 
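grpc_helpers_async.FakeUnaryUnaryCall, used as the async return value throughout, is api-core's test helper that wraps a plain message in an awaitable standing in for a real grpc.aio call object. Roughly how the async client consumes it (a sketch; it must run inside an event loop):

    async def example():
        fake = grpc_helpers_async.FakeUnaryUnaryCall(
            annotation_spec.AnnotationSpec(name="name_value")
        )
        response = await fake  # resolves immediately to the wrapped message
        assert response.name == "name_value"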
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_get_annotation_spec_flattened():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_annotation_spec), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = annotation_spec.AnnotationSpec()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.get_annotation_spec(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_get_annotation_spec_flattened_error():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_annotation_spec(
+            dataset_service.GetAnnotationSpecRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_annotation_spec_flattened_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_annotation_spec), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            annotation_spec.AnnotationSpec()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_annotation_spec(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_annotation_spec_flattened_error_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_annotation_spec(
+            dataset_service.GetAnnotationSpecRequest(), name="name_value",
+        )
+
+
+def test_list_annotations(
+    transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest
+):
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_annotations), "__call__") as call:
+        # Designate an appropriate return value for the call.
+ call.return_value = dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListAnnotationsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListAnnotationsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_annotations_from_dict(): + test_list_annotations(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_annotations_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListAnnotationsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListAnnotationsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_annotations_async_from_dict(): + await test_list_annotations_async(request_type=dict) + + +def test_list_annotations_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListAnnotationsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value = dataset_service.ListAnnotationsResponse() + + client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_annotations_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListAnnotationsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(type(client.transport.list_annotations), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            dataset_service.ListAnnotationsResponse()
+        )
+
+        await client.list_annotations(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_annotations_flattened():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_annotations), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = dataset_service.ListAnnotationsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_annotations(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_annotations_flattened_error():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_annotations(
+            dataset_service.ListAnnotationsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_annotations_flattened_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_annotations), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            dataset_service.ListAnnotationsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_annotations(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_annotations_flattened_error_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_annotations(
+            dataset_service.ListAnnotationsRequest(), parent="parent_value",
+        )
+
+
+def test_list_annotations_pager():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_annotations), "__call__") as call:
+        # Set the response to a series of pages.
+ call.side_effect = ( + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + annotation.Annotation(), + ], + next_page_token="abc", + ), + dataset_service.ListAnnotationsResponse( + annotations=[], next_page_token="def", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(),], next_page_token="ghi", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(), annotation.Annotation(),], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_annotations(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, annotation.Annotation) for i in results) + + +def test_list_annotations_pages(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + annotation.Annotation(), + ], + next_page_token="abc", + ), + dataset_service.ListAnnotationsResponse( + annotations=[], next_page_token="def", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(),], next_page_token="ghi", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(), annotation.Annotation(),], + ), + RuntimeError, + ) + pages = list(client.list_annotations(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_annotations_async_pager(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + annotation.Annotation(), + ], + next_page_token="abc", + ), + dataset_service.ListAnnotationsResponse( + annotations=[], next_page_token="def", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(),], next_page_token="ghi", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(), annotation.Annotation(),], + ), + RuntimeError, + ) + async_pager = await client.list_annotations(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, annotation.Annotation) for i in responses) + + +@pytest.mark.asyncio +async def test_list_annotations_async_pages(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListAnnotationsResponse( + annotations=[ + annotation.Annotation(), + annotation.Annotation(), + annotation.Annotation(), + ], + next_page_token="abc", + ), + dataset_service.ListAnnotationsResponse( + annotations=[], next_page_token="def", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(),], next_page_token="ghi", + ), + dataset_service.ListAnnotationsResponse( + annotations=[annotation.Annotation(), annotation.Annotation(),], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_annotations(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DatasetServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DatasetServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = DatasetServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.DatasetServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. 
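test_transport_adc only verifies that auth.default() is consulted; the equivalent unmocked wiring looks roughly like this, assuming Application Default Credentials are configured in the environment:

    import google.auth

    creds, _ = google.auth.default(
        scopes=("https://www.googleapis.com/auth/cloud-platform",)
    )
    transport = transports.DatasetServiceGrpcTransport(credentials=creds)
    client = DatasetServiceClient(transport=transport)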
+ client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.DatasetServiceGrpcTransport,) + + +def test_dataset_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.DatasetServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_dataset_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.DatasetServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_dataset", + "get_dataset", + "update_dataset", + "list_datasets", + "delete_dataset", + "import_data", + "export_data", + "list_data_items", + "get_annotation_spec", + "list_annotations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_dataset_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.DatasetServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_dataset_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.DatasetServiceTransport() + adc.assert_called_once() + + +def test_dataset_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + DatasetServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_dataset_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
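The credentials-file branch above corresponds to this unmocked flow; "credentials.json" is a placeholder path, exactly as in the test, and from_service_account_file is the generated convenience constructor over the same loader:

    creds, project_id = auth.load_credentials_from_file(
        "credentials.json",  # placeholder path, as in the test
        scopes=("https://www.googleapis.com/auth/cloud-platform",),
    )

    # The generated client exposes the same loader as a constructor:
    client = DatasetServiceClient.from_service_account_file("credentials.json")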
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transports.DatasetServiceGrpcTransport(
+            host="squid.clam.whelk", quota_project_id="octopus"
+        )
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            quota_project_id="octopus",
+        )
+
+
+def test_dataset_service_host_no_port():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_dataset_service_host_with_port():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_dataset_service_grpc_transport_channel():
+    channel = grpc.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.DatasetServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_dataset_service_grpc_asyncio_transport_channel():
+    channel = aio.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.DatasetServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.DatasetServiceGrpcTransport,
+        transports.DatasetServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_dataset_service_transport_channel_mtls_with_client_cert_source(
+    transport_class,
+):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.DatasetServiceGrpcTransport,
+        transports.DatasetServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_dataset_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_dataset_service_grpc_lro_client():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_dataset_service_grpc_lro_async_client():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_annotation_path():
+    project = "squid"
+    location = "clam"
+    dataset = "whelk"
+    data_item = "octopus"
+    annotation = "oyster"
+
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(
+        project=project,
+        location=location,
+        dataset=dataset,
+        data_item=data_item,
+        annotation=annotation,
+    )
+    actual = DatasetServiceClient.annotation_path(
+        project, location, dataset, data_item, annotation
+    )
+    assert expected == actual
+
+
+def test_parse_annotation_path():
+    expected = {
+        "project": "nudibranch",
+        "location": "cuttlefish",
+        "dataset": "mussel",
+        "data_item": "winkle",
+        "annotation": "nautilus",
+    }
+    path = DatasetServiceClient.annotation_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = DatasetServiceClient.parse_annotation_path(path)
+    assert expected == actual
+
+
+def test_annotation_spec_path():
+    project = "scallop"
+    location = "abalone"
+    dataset = "squid"
+    annotation_spec = "clam"
+
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format(
+        project=project,
+        location=location,
+        dataset=dataset,
+        annotation_spec=annotation_spec,
+    )
+    actual = DatasetServiceClient.annotation_spec_path(
+        project, location, dataset, annotation_spec
+    )
+    assert expected == actual
+
+
+def test_parse_annotation_spec_path():
+    expected = {
+        "project": "whelk",
+        "location": "octopus",
+        "dataset": "oyster",
+        "annotation_spec": "nudibranch",
+    }
+    path = DatasetServiceClient.annotation_spec_path(**expected)
+
+    # Check that the path construction is reversible.
+ actual = DatasetServiceClient.parse_annotation_spec_path(path) + assert expected == actual + + +def test_data_item_path(): + project = "cuttlefish" + location = "mussel" + dataset = "winkle" + data_item = "nautilus" + + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) + actual = DatasetServiceClient.data_item_path(project, location, dataset, data_item) + assert expected == actual + + +def test_parse_data_item_path(): + expected = { + "project": "scallop", + "location": "abalone", + "dataset": "squid", + "data_item": "clam", + } + path = DatasetServiceClient.data_item_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_data_item_path(path) + assert expected == actual + + +def test_dataset_path(): + project = "whelk" + location = "octopus" + dataset = "oyster" + + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) + actual = DatasetServiceClient.dataset_path(project, location, dataset) + assert expected == actual + + +def test_parse_dataset_path(): + expected = { + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + } + path = DatasetServiceClient.dataset_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_dataset_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "winkle" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = DatasetServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nautilus", + } + path = DatasetServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "scallop" + + expected = "folders/{folder}".format(folder=folder,) + actual = DatasetServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "abalone", + } + path = DatasetServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "squid" + + expected = "organizations/{organization}".format(organization=organization,) + actual = DatasetServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "clam", + } + path = DatasetServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. 
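+    # Parsing the formatted path must recover the exact component dict above.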
+ actual = DatasetServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "whelk" + + expected = "projects/{project}".format(project=project,) + actual = DatasetServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "octopus", + } + path = DatasetServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "oyster" + location = "nudibranch" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = DatasetServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "cuttlefish", + "location": "mussel", + } + path = DatasetServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = DatasetServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py new file mode 100644 index 0000000000..93c35a7a2a --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -0,0 +1,2582 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + EndpointServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( + EndpointServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers +from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports +from google.cloud.aiplatform_v1beta1.types import accelerator_type +from google.cloud.aiplatform_v1beta1.types import endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint +from google.cloud.aiplatform_v1beta1.types import endpoint_service +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import explanation_metadata +from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
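+# Its return value is patched in as DEFAULT_ENDPOINT on both client classes below.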
+def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert EndpointServiceClient._get_default_mtls_endpoint(None) is None + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient] +) +def test_endpoint_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_endpoint_service_client_get_transport_class(): + transport = EndpointServiceClient.get_transport_class() + assert transport == transports.EndpointServiceGrpcTransport + + transport = EndpointServiceClient.get_transport_class("grpc") + assert transport == transports.EndpointServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) +def test_endpoint_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
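+    # An explicit api_endpoint should be forwarded to the transport as the host.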
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
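+    # Only "true" and "false" are accepted for this variable; anything else
+    # is expected to raise ValueError.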
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "true", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "false", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_endpoint_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
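+    # Here the client cert comes from ADC (SslCredentials) rather than client_options.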
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
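+    # The path is forwarded verbatim as credentials_file; no credentials object
+    # is constructed at this layer.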
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_endpoint_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = EndpointServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_endpoint( + transport: str = "grpc", request_type=endpoint_service.CreateEndpointRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.CreateEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_endpoint_from_dict(): + test_create_endpoint(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.CreateEndpointRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.CreateEndpointRequest() + + # Establish that the response is the type that we expect. 
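+    # create_endpoint is a long-running operation, so the client surfaces a future.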
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_endpoint_async_from_dict(): + await test_create_endpoint_async(request_type=dict) + + +def test_create_endpoint_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.CreateEndpointRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.CreateEndpointRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_endpoint_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_endpoint( + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + + +def test_create_endpoint_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_endpoint( + endpoint_service.CreateEndpointRequest(), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_endpoint_flattened_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_endpoint( + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + + +@pytest.mark.asyncio +async def test_create_endpoint_flattened_error_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_endpoint( + endpoint_service.CreateEndpointRequest(), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), + ) + + +def test_get_endpoint( + transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + + response = client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.GetEndpointRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, endpoint.Endpoint) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +def test_get_endpoint_from_dict(): + test_get_endpoint(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.GetEndpointRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+ request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) + + response = await client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.GetEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, endpoint.Endpoint) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_get_endpoint_async_from_dict(): + await test_get_endpoint_async(request_type=dict) + + +def test_get_endpoint_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.GetEndpointRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + call.return_value = endpoint.Endpoint() + + client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.GetEndpointRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) + + await client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_endpoint_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint.Endpoint() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_endpoint(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_endpoint_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_endpoint( + endpoint_service.GetEndpointRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_endpoint_flattened_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint.Endpoint() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_endpoint(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_endpoint_flattened_error_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_endpoint( + endpoint_service.GetEndpointRequest(), name="name_value", + ) + + +def test_list_endpoints( + transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint_service.ListEndpointsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.ListEndpointsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListEndpointsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_endpoints_from_dict(): + test_list_endpoints(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_endpoints_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.ListEndpointsRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
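+    # The async stub must return an awaitable; FakeUnaryUnaryCall below provides one.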
+ with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.ListEndpointsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListEndpointsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_endpoints_async_from_dict(): + await test_list_endpoints_async(request_type=dict) + + +def test_list_endpoints_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.ListEndpointsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value = endpoint_service.ListEndpointsResponse() + + client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_endpoints_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.ListEndpointsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse() + ) + + await client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_endpoints_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint_service.ListEndpointsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_endpoints(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
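+        # The flattened "parent" kwarg should have been copied onto the request.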
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_endpoints_flattened_error():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_endpoints(
+            endpoint_service.ListEndpointsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_endpoints_flattened_async():
+    client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = endpoint_service.ListEndpointsResponse()
+
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            endpoint_service.ListEndpointsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_endpoints(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_endpoints_flattened_error_async():
+    client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_endpoints(
+            endpoint_service.ListEndpointsRequest(), parent="parent_value",
+        )
+
+
+def test_list_endpoints_pager():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                ],
+                next_page_token="abc",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[], next_page_token="def",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_endpoints(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, endpoint.Endpoint) for i in results)
+
+
+def test_list_endpoints_pages():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call:
+        # Set the response to a series of pages.
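+        # The trailing RuntimeError ensures the test fails loudly if the pager
+        # requests more pages than were supplied.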
+        call.side_effect = (
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                ],
+                next_page_token="abc",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[], next_page_token="def",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_endpoints(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_endpoints_async_pager():
+    client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                ],
+                next_page_token="abc",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[], next_page_token="def",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_endpoints(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, endpoint.Endpoint) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_endpoints_async_pages():
+    client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                ],
+                next_page_token="abc",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[], next_page_token="def",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_endpoints(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_update_endpoint(
+    transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest
+):
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call:
+        # Designate an appropriate return value for the call.
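+        # Unlike create/delete, update_endpoint returns the Endpoint itself, not an LRO.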
+ call.return_value = gca_endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + + response = client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UpdateEndpointRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_endpoint.Endpoint) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +def test_update_endpoint_from_dict(): + test_update_endpoint(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.UpdateEndpointRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) + + response = await client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UpdateEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_endpoint.Endpoint) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_update_endpoint_async_from_dict(): + await test_update_endpoint_async(request_type=dict) + + +def test_update_endpoint_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.UpdateEndpointRequest() + request.endpoint.name = "endpoint.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + call.return_value = gca_endpoint.Endpoint() + + client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ + "metadata" + ] + + +@pytest.mark.asyncio +async def test_update_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. 
+    # Set these to a non-empty value.
+    request = endpoint_service.UpdateEndpointRequest()
+    request.endpoint.name = "endpoint.name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_endpoint.Endpoint()
+        )
+
+        await client.update_endpoint(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[
+        "metadata"
+    ]
+
+
+def test_update_endpoint_flattened():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gca_endpoint.Endpoint()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.update_endpoint(
+            endpoint=gca_endpoint.Endpoint(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value")
+
+        assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"])
+
+
+def test_update_endpoint_flattened_error():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.update_endpoint(
+            endpoint_service.UpdateEndpointRequest(),
+            endpoint=gca_endpoint.Endpoint(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
+
+
+@pytest.mark.asyncio
+async def test_update_endpoint_flattened_async():
+    client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gca_endpoint.Endpoint()
+
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_endpoint.Endpoint()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.update_endpoint(
+            endpoint=gca_endpoint.Endpoint(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value")
+
+        assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"])
+
+
+@pytest.mark.asyncio
+async def test_update_endpoint_flattened_error_async():
+    client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
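+    # The flattened keyword arguments are sugar for building the request
+    # object; passing a request *and* flattened fields would be ambiguous,
+    # so the GAPIC surface rejects the combination with ValueError.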
+ with pytest.raises(ValueError): + await client.update_endpoint( + endpoint_service.UpdateEndpointRequest(), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_delete_endpoint( + transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeleteEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_endpoint_from_dict(): + test_delete_endpoint(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.DeleteEndpointRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeleteEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_endpoint_async_from_dict(): + await test_delete_endpoint_async(request_type=dict) + + +def test_delete_endpoint_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeleteEndpointRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
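+    # Resource fields that appear in the request URI are mirrored into the
+    # "x-goog-request-params" gRPC metadata entry (URL-style key=value pairs)
+    # so the backend can route the call by resource name.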
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeleteEndpointRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_endpoint_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_endpoint(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_endpoint_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_endpoint( + endpoint_service.DeleteEndpointRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_endpoint_flattened_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_endpoint(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_endpoint_flattened_error_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.delete_endpoint( + endpoint_service.DeleteEndpointRequest(), name="name_value", + ) + + +def test_deploy_model( + transport: str = "grpc", request_type=endpoint_service.DeployModelRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeployModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_deploy_model_from_dict(): + test_deploy_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_deploy_model_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.DeployModelRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeployModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_deploy_model_async_from_dict(): + await test_deploy_model_async(request_type=dict) + + +def test_deploy_model_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeployModelRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_deploy_model_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeployModelRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +def test_deploy_model_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.deploy_model( + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) + + assert args[0].traffic_split == {"key_value": 541} + + +def test_deploy_model_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.deploy_model( + endpoint_service.DeployModelRequest(), + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, + ) + + +@pytest.mark.asyncio +async def test_deploy_model_flattened_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + # Designate an appropriate return value for the call. 
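+        # The async stub must return an awaitable, so the plain Operation is
+        # wrapped in grpc_helpers_async.FakeUnaryUnaryCall, which resolves to
+        # the designated value when awaited.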
+ call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.deploy_model( + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) + + assert args[0].traffic_split == {"key_value": 541} + + +@pytest.mark.asyncio +async def test_deploy_model_flattened_error_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.deploy_model( + endpoint_service.DeployModelRequest(), + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, + ) + + +def test_undeploy_model( + transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UndeployModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_undeploy_model_from_dict(): + test_undeploy_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_undeploy_model_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.UndeployModelRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UndeployModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_undeploy_model_async_from_dict(): + await test_undeploy_model_async(request_type=dict) + + +def test_undeploy_model_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.UndeployModelRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_undeploy_model_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.UndeployModelRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +def test_undeploy_model_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.undeploy_model( + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].deployed_model_id == "deployed_model_id_value" + + assert args[0].traffic_split == {"key_value": 541} + + +def test_undeploy_model_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.undeploy_model( + endpoint_service.UndeployModelRequest(), + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, + ) + + +@pytest.mark.asyncio +async def test_undeploy_model_flattened_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.undeploy_model( + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].deployed_model_id == "deployed_model_id_value" + + assert args[0].traffic_split == {"key_value": 541} + + +@pytest.mark.asyncio +async def test_undeploy_model_flattened_error_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.undeploy_model( + endpoint_service.UndeployModelRequest(), + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EndpointServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EndpointServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. 
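+    # Supplying a transport is the escape hatch for advanced setups, e.g. a
+    # transport built around a pre-configured or intercepted gRPC channel.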
+ transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = EndpointServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.EndpointServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.EndpointServiceGrpcTransport,) + + +def test_endpoint_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.EndpointServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_endpoint_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.EndpointServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_endpoint", + "get_endpoint", + "list_endpoints", + "update_endpoint", + "delete_endpoint", + "deploy_model", + "undeploy_model", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_endpoint_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.EndpointServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_endpoint_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. 
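+    # google.auth.default() implements Application Default Credentials: it
+    # consults GOOGLE_APPLICATION_CREDENTIALS, then gcloud's well-known
+    # credentials file, then the GCE metadata server. It is mocked here so
+    # the test never performs a real lookup.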
+    with mock.patch.object(auth, "default") as adc, mock.patch(
+        "google.cloud.aiplatform_v1beta1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages"
+    ) as Transport:
+        Transport.return_value = None
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transport = transports.EndpointServiceTransport()
+        adc.assert_called_once()
+
+
+def test_endpoint_service_auth_adc():
+    # If no credentials are provided, we should use ADC credentials.
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        EndpointServiceClient()
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            quota_project_id=None,
+        )
+
+
+def test_endpoint_service_transport_auth_adc():
+    # If credentials and host are not provided, the transport class should use
+    # ADC credentials.
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transports.EndpointServiceGrpcTransport(
+            host="squid.clam.whelk", quota_project_id="octopus"
+        )
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            quota_project_id="octopus",
+        )
+
+
+def test_endpoint_service_host_no_port():
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_endpoint_service_host_with_port():
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_endpoint_service_grpc_transport_channel():
+    channel = grpc.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.EndpointServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_endpoint_service_grpc_asyncio_transport_channel():
+    channel = aio.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.EndpointServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.EndpointServiceGrpcTransport,
+        transports.EndpointServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_endpoint_service_transport_channel_mtls_with_client_cert_source(
+    transport_class,
+):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.EndpointServiceGrpcTransport,
+        transports.EndpointServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_endpoint_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_endpoint_service_grpc_lro_client():
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
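+    # The operations client is what polls the google.longrunning Operations
+    # returned by deploy/undeploy/delete; the transport caches a single
+    # instance, hence the identity (is) check below.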
+    assert transport.operations_client is transport.operations_client
+
+
+def test_endpoint_service_grpc_lro_async_client():
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_endpoint_path():
+    project = "squid"
+    location = "clam"
+    endpoint = "whelk"
+
+    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+        project=project, location=location, endpoint=endpoint,
+    )
+    actual = EndpointServiceClient.endpoint_path(project, location, endpoint)
+    assert expected == actual
+
+
+def test_parse_endpoint_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "endpoint": "nudibranch",
+    }
+    path = EndpointServiceClient.endpoint_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = EndpointServiceClient.parse_endpoint_path(path)
+    assert expected == actual
+
+
+def test_model_path():
+    project = "cuttlefish"
+    location = "mussel"
+    model = "winkle"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
+    actual = EndpointServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "nautilus",
+        "location": "scallop",
+        "model": "abalone",
+    }
+    path = EndpointServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = EndpointServiceClient.parse_model_path(path)
+    assert expected == actual
+
+
+def test_common_billing_account_path():
+    billing_account = "squid"
+
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
+    actual = EndpointServiceClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "clam",
+    }
+    path = EndpointServiceClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = EndpointServiceClient.parse_common_billing_account_path(path)
+    assert expected == actual
+
+
+def test_common_folder_path():
+    folder = "whelk"
+
+    expected = "folders/{folder}".format(folder=folder,)
+    actual = EndpointServiceClient.common_folder_path(folder)
+    assert expected == actual
+
+
+def test_parse_common_folder_path():
+    expected = {
+        "folder": "octopus",
+    }
+    path = EndpointServiceClient.common_folder_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = EndpointServiceClient.parse_common_folder_path(path)
+    assert expected == actual
+
+
+def test_common_organization_path():
+    organization = "oyster"
+
+    expected = "organizations/{organization}".format(organization=organization,)
+    actual = EndpointServiceClient.common_organization_path(organization)
+    assert expected == actual
+
+
+def test_parse_common_organization_path():
+    expected = {
+        "organization": "nudibranch",
+    }
+    path = EndpointServiceClient.common_organization_path(**expected)
+
+    # Check that the path construction is reversible.
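+    # e.g. "organizations/nudibranch" should parse back to
+    # {"organization": "nudibranch"}.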
+ actual = EndpointServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project,) + actual = EndpointServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + } + path = EndpointServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = EndpointServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = EndpointServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + } + path = EndpointServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = EndpointServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = EndpointServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py new file mode 100644 index 0000000000..f08d84bd2f --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -0,0 +1,6095 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceAsyncClient +from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient +from google.cloud.aiplatform_v1beta1.services.job_service import pagers +from google.cloud.aiplatform_v1beta1.services.job_service import transports +from google.cloud.aiplatform_v1beta1.types import accelerator_type +from google.cloud.aiplatform_v1beta1.types import batch_prediction_job +from google.cloud.aiplatform_v1beta1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) +from google.cloud.aiplatform_v1beta1.types import completion_stats +from google.cloud.aiplatform_v1beta1.types import custom_job +from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job +from google.cloud.aiplatform_v1beta1.types import data_labeling_job +from google.cloud.aiplatform_v1beta1.types import ( + data_labeling_job as gca_data_labeling_job, +) +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import explanation_metadata +from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job +from google.cloud.aiplatform_v1beta1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import job_service +from google.cloud.aiplatform_v1beta1.types import job_state +from google.cloud.aiplatform_v1beta1.types import machine_resources +from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import study +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import any_pb2 as gp_any # type: ignore +from google.protobuf import duration_pb2 as duration # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore +from google.type import money_pb2 as money # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
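+# For a production endpoint such as "example.googleapis.com" the default mTLS
+# endpoint is derived by inserting "mtls", e.g. "example.mtls.googleapis.com"
+# (exercised by test__get_default_mtls_endpoint below).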
+def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert JobServiceClient._get_default_mtls_endpoint(None) is None + assert ( + JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient]) +def test_job_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_job_service_client_get_transport_class(): + transport = JobServiceClient.get_transport_class() + assert transport == transports.JobServiceGrpcTransport + + transport = JobServiceClient.get_transport_class("grpc") + assert transport == transports.JobServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) +def test_job_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
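+    # An explicit api_endpoint takes precedence over DEFAULT_ENDPOINT and the
+    # mTLS autoswitch logic; it is passed through to the transport as host.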
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_job_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. 
+    # Endpoint is autoswitched to the default mtls endpoint, if
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.
+
+    # Check the case client_cert_source is provided. Whether client cert is used depends on
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        options = client_options.ClientOptions(
+            client_cert_source=client_cert_source_callback
+        )
+        with mock.patch.object(transport_class, "__init__") as patched:
+            ssl_channel_creds = mock.Mock()
+            with mock.patch(
+                "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
+            ):
+                patched.return_value = None
+                client = client_class(client_options=options)
+
+                if use_client_cert_env == "false":
+                    expected_ssl_channel_creds = None
+                    expected_host = client.DEFAULT_ENDPOINT
+                else:
+                    expected_ssl_channel_creds = ssl_channel_creds
+                    expected_host = client.DEFAULT_MTLS_ENDPOINT
+
+                patched.assert_called_once_with(
+                    credentials=None,
+                    credentials_file=None,
+                    host=expected_host,
+                    scopes=None,
+                    ssl_channel_credentials=expected_ssl_channel_creds,
+                    quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
+                )
+
+    # Check the case ADC client cert is provided. Whether client cert is used depends on
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+            ):
+                with mock.patch(
+                    "google.auth.transport.grpc.SslCredentials.is_mtls",
+                    new_callable=mock.PropertyMock,
+                ) as is_mtls_mock:
+                    with mock.patch(
+                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
+                        new_callable=mock.PropertyMock,
+                    ) as ssl_credentials_mock:
+                        if use_client_cert_env == "false":
+                            is_mtls_mock.return_value = False
+                            ssl_credentials_mock.return_value = None
+                            expected_host = client.DEFAULT_ENDPOINT
+                            expected_ssl_channel_creds = None
+                        else:
+                            is_mtls_mock.return_value = True
+                            ssl_credentials_mock.return_value = mock.Mock()
+                            expected_host = client.DEFAULT_MTLS_ENDPOINT
+                            expected_ssl_channel_creds = (
+                                ssl_credentials_mock.return_value
+                            )
+
+                        patched.return_value = None
+                        client = client_class()
+                        patched.assert_called_once_with(
+                            credentials=None,
+                            credentials_file=None,
+                            host=expected_host,
+                            scopes=None,
+                            ssl_channel_credentials=expected_ssl_channel_creds,
+                            quota_project_id=None,
+                            client_info=transports.base.DEFAULT_CLIENT_INFO,
+                        )
+
+    # Check the case client_cert_source and ADC client cert are not provided.
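+    # With no certificate available from either source, the client must fall
+    # back to the regular endpoint with ssl_channel_credentials=None, whatever
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE says.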
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_job_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_custom_job( + transport: str = "grpc", request_type=job_service.CreateCustomJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateCustomJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_custom_job.CustomJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_custom_job_from_dict(): + test_create_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.CreateCustomJobRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateCustomJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_custom_job.CustomJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_create_custom_job_async_from_dict(): + await test_create_custom_job_async(request_type=dict) + + +def test_create_custom_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateCustomJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + call.return_value = gca_custom_job.CustomJob() + + client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_custom_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateCustomJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob() + ) + + await client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_custom_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_custom_job.CustomJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_custom_job( + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") + + +def test_create_custom_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_custom_job( + job_service.CreateCustomJobRequest(), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_custom_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_custom_job.CustomJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_custom_job( + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") + + +@pytest.mark.asyncio +async def test_create_custom_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_custom_job( + job_service.CreateCustomJobRequest(), + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), + ) + + +def test_get_custom_job( + transport: str = "grpc", request_type=job_service.GetCustomJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetCustomJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, custom_job.CustomJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_custom_job_from_dict(): + test_get_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetCustomJobRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetCustomJobRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, custom_job.CustomJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_get_custom_job_async_from_dict(): + await test_get_custom_job_async(request_type=dict) + + +def test_get_custom_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetCustomJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + call.return_value = custom_job.CustomJob() + + client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_custom_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetCustomJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob() + ) + + await client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_custom_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = custom_job.CustomJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_custom_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_custom_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_custom_job( + job_service.GetCustomJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_custom_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = custom_job.CustomJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_custom_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_custom_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_custom_job( + job_service.GetCustomJobRequest(), name="name_value", + ) + + +def test_list_custom_jobs( + transport: str = "grpc", request_type=job_service.ListCustomJobsRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListCustomJobsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListCustomJobsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListCustomJobsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_custom_jobs_from_dict(): + test_list_custom_jobs(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_custom_jobs_async( + transport: str = "grpc_asyncio", request_type=job_service.ListCustomJobsRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse(next_page_token="next_page_token_value",) + ) + + response = await client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListCustomJobsRequest() + + # Establish that the response is the type that we expect. 
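+ # The async client wraps the paged response in an AsyncPager.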
+ assert isinstance(response, pagers.ListCustomJobsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_custom_jobs_async_from_dict(): + await test_list_custom_jobs_async(request_type=dict) + + +def test_list_custom_jobs_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListCustomJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + call.return_value = job_service.ListCustomJobsResponse() + + client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_custom_jobs_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListCustomJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse() + ) + + await client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_custom_jobs_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListCustomJobsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_custom_jobs(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_custom_jobs_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_custom_jobs( + job_service.ListCustomJobsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_custom_jobs_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = job_service.ListCustomJobsResponse()
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ job_service.ListCustomJobsResponse()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_custom_jobs(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_custom_jobs_flattened_error_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_custom_jobs(
+ job_service.ListCustomJobsRequest(), parent="parent_value",
+ )
+
+
+def test_list_custom_jobs_pager():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[
+ custom_job.CustomJob(),
+ custom_job.CustomJob(),
+ custom_job.CustomJob(),
+ ],
+ next_page_token="abc",
+ ),
+ job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
+ ),
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
+ ),
+ RuntimeError,
+ )
+
+ metadata = ()
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+ )
+ pager = client.list_custom_jobs(request={})
+
+ assert pager._metadata == metadata
+
+ results = [i for i in pager]
+ assert len(results) == 6
+ assert all(isinstance(i, custom_job.CustomJob) for i in results)
+
+
+def test_list_custom_jobs_pages():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[
+ custom_job.CustomJob(),
+ custom_job.CustomJob(),
+ custom_job.CustomJob(),
+ ],
+ next_page_token="abc",
+ ),
+ job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
+ ),
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
+ ),
+ RuntimeError,
+ )
+ pages = list(client.list_custom_jobs(request={}).pages)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_custom_jobs_async_pager():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
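+ # new_callable=mock.AsyncMock makes the patched transport method awaitable,
+ # so the staged pages are returned when the pager awaits the call.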
+ with mock.patch.object(
+ type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[
+ custom_job.CustomJob(),
+ custom_job.CustomJob(),
+ custom_job.CustomJob(),
+ ],
+ next_page_token="abc",
+ ),
+ job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
+ ),
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
+ ),
+ RuntimeError,
+ )
+ async_pager = await client.list_custom_jobs(request={},)
+ assert async_pager.next_page_token == "abc"
+ responses = []
+ async for response in async_pager:
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(isinstance(i, custom_job.CustomJob) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_custom_jobs_async_pages():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[
+ custom_job.CustomJob(),
+ custom_job.CustomJob(),
+ custom_job.CustomJob(),
+ ],
+ next_page_token="abc",
+ ),
+ job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
+ ),
+ job_service.ListCustomJobsResponse(
+ custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
+ ),
+ RuntimeError,
+ )
+ pages = []
+ async for page_ in (await client.list_custom_jobs(request={})).pages:
+ pages.append(page_)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+def test_delete_custom_job(
+ transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest
+):
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.delete_custom_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/spam")
+
+ response = client.delete_custom_job(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == job_service.DeleteCustomJobRequest()
+
+ # Establish that the response is the type that we expect.
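+ # Delete is a long-running operation, so the client wraps the returned
+ # operations_pb2.Operation in a future rather than handing it back directly.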
+ assert isinstance(response, future.Future) + + +def test_delete_custom_job_from_dict(): + test_delete_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.DeleteCustomJobRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteCustomJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_custom_job_async_from_dict(): + await test_delete_custom_job_async(request_type=dict) + + +def test_delete_custom_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteCustomJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_custom_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteCustomJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_custom_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.delete_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_custom_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_custom_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_custom_job( + job_service.DeleteCustomJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_custom_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_custom_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_custom_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_custom_job( + job_service.DeleteCustomJobRequest(), name="name_value", + ) + + +def test_cancel_custom_job( + transport: str = "grpc", request_type=job_service.CancelCustomJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelCustomJobRequest() + + # Establish that the response is the type that we expect. 
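+ # Cancel has an Empty response, which the client surfaces as None.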
+ assert response is None + + +def test_cancel_custom_job_from_dict(): + test_cancel_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.CancelCustomJobRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelCustomJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_custom_job_async_from_dict(): + await test_cancel_custom_job_async(request_type=dict) + + +def test_cancel_custom_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelCustomJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), "__call__" + ) as call: + call.return_value = None + + client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_custom_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelCustomJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_cancel_custom_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.cancel_custom_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_cancel_custom_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_custom_job( + job_service.CancelCustomJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_cancel_custom_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.cancel_custom_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_cancel_custom_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.cancel_custom_job( + job_service.CancelCustomJobRequest(), name="name_value", + ) + + +def test_create_data_labeling_job( + transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + + response = client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateDataLabelingJobRequest() + + # Establish that the response is the type that we expect. 
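+ # Each field staged on the mock should round-trip onto the response.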
+ + assert isinstance(response, gca_data_labeling_job.DataLabelingJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.datasets == ["datasets_value"] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == "instruction_uri_value" + + assert response.inputs_schema_uri == "inputs_schema_uri_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ["specialist_pools_value"] + + +def test_create_data_labeling_job_from_dict(): + test_create_data_labeling_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateDataLabelingJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + ) + + response = await client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_data_labeling_job.DataLabelingJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.datasets == ["datasets_value"] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == "instruction_uri_value" + + assert response.inputs_schema_uri == "inputs_schema_uri_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ["specialist_pools_value"] + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_async_from_dict(): + await test_create_data_labeling_job_async(request_type=dict) + + +def test_create_data_labeling_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateDataLabelingJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + call.return_value = gca_data_labeling_job.DataLabelingJob() + + client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateDataLabelingJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob() + ) + + await client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_data_labeling_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_data_labeling_job.DataLabelingJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_data_labeling_job( + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( + name="name_value" + ) + + +def test_create_data_labeling_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_data_labeling_job( + job_service.CreateDataLabelingJobRequest(), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_data_labeling_job.DataLabelingJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ response = await client.create_data_labeling_job( + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( + name="name_value" + ) + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_data_labeling_job( + job_service.CreateDataLabelingJobRequest(), + parent="parent_value", + data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), + ) + + +def test_get_data_labeling_job( + transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + + response = client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, data_labeling_job.DataLabelingJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.datasets == ["datasets_value"] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == "instruction_uri_value" + + assert response.inputs_schema_uri == "inputs_schema_uri_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ["specialist_pools_value"] + + +def test_get_data_labeling_job_from_dict(): + test_get_data_labeling_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + ) + + response = await client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, data_labeling_job.DataLabelingJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.datasets == ["datasets_value"] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == "instruction_uri_value" + + assert response.inputs_schema_uri == "inputs_schema_uri_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ["specialist_pools_value"] + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_async_from_dict(): + await test_get_data_labeling_job_async(request_type=dict) + + +def test_get_data_labeling_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + call.return_value = data_labeling_job.DataLabelingJob() + + client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob() + ) + + await client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_data_labeling_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = data_labeling_job.DataLabelingJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_data_labeling_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_data_labeling_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_data_labeling_job( + job_service.GetDataLabelingJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = data_labeling_job.DataLabelingJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_data_labeling_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_data_labeling_job( + job_service.GetDataLabelingJobRequest(), name="name_value", + ) + + +def test_list_data_labeling_jobs( + transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListDataLabelingJobsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListDataLabelingJobsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListDataLabelingJobsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_data_labeling_jobs_from_dict(): + test_list_data_labeling_jobs(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_data_labeling_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListDataLabelingJobsRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListDataLabelingJobsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDataLabelingJobsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_data_labeling_jobs_async_from_dict(): + await test_list_data_labeling_jobs_async(request_type=dict) + + +def test_list_data_labeling_jobs_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListDataLabelingJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + call.return_value = job_service.ListDataLabelingJobsResponse() + + client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_data_labeling_jobs_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListDataLabelingJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
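+    # Note: patching __call__ on the multicallable's type intercepts the stub
+    # invocation itself, so no real channel traffic ever occurs.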
+    with mock.patch.object(
+        type(client.transport.list_data_labeling_jobs), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            job_service.ListDataLabelingJobsResponse()
+        )
+
+        await client.list_data_labeling_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_data_labeling_jobs_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_data_labeling_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListDataLabelingJobsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_data_labeling_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_data_labeling_jobs_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_data_labeling_jobs(
+            job_service.ListDataLabelingJobsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_data_labeling_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            job_service.ListDataLabelingJobsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_data_labeling_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_data_labeling_jobs(
+            job_service.ListDataLabelingJobsRequest(), parent="parent_value",
+        )
+
+
+def test_list_data_labeling_jobs_pager():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_data_labeling_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
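+        # Note: each side_effect element is consumed by one successive call;
+        # the trailing RuntimeError fails the test if the pager over-fetches.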
+        call.side_effect = (
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[], next_page_token="def",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        # Expected routing metadata for the (empty) parent field.
+        metadata = (gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),)
+        pager = client.list_data_labeling_jobs(request={})
+
+        assert pager._metadata == metadata
+
+        results = list(pager)
+        assert len(results) == 6
+        assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results)
+
+
+def test_list_data_labeling_jobs_pages():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_data_labeling_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[], next_page_token="def",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_data_labeling_jobs(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_async_pager():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_data_labeling_jobs),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
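+        # Note: new_callable=mock.AsyncMock (above) makes each mocked page an
+        # awaitable, matching the async transport's calling convention.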
+        call.side_effect = (
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[], next_page_token="def",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_data_labeling_jobs(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_async_pages():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_data_labeling_jobs),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[], next_page_token="def",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListDataLabelingJobsResponse(
+                data_labeling_jobs=[
+                    data_labeling_job.DataLabelingJob(),
+                    data_labeling_job.DataLabelingJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_data_labeling_jobs(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_delete_data_labeling_job(
+    transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest
+):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_data_labeling_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.delete_data_labeling_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.DeleteDataLabelingJobRequest()
+
+    # Establish that the response is the type that we expect.
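+    # Note: delete is a long-running operation, so the Operation proto is
+    # surfaced as a future.Future-based LRO rather than a plain message.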
+ assert isinstance(response, future.Future) + + +def test_delete_data_labeling_job_from_dict(): + test_delete_data_labeling_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteDataLabelingJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_data_labeling_job_async_from_dict(): + await test_delete_data_labeling_job_async(request_type=dict) + + +def test_delete_data_labeling_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_delete_data_labeling_job_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_data_labeling_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.delete_data_labeling_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_delete_data_labeling_job_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.delete_data_labeling_job(
+            job_service.DeleteDataLabelingJobRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_data_labeling_job_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_data_labeling_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_data_labeling_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_delete_data_labeling_job_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_data_labeling_job(
+            job_service.DeleteDataLabelingJobRequest(), name="name_value",
+        )
+
+
+def test_cancel_data_labeling_job(
+    transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest
+):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.cancel_data_labeling_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = None
+
+        response = client.cancel_data_labeling_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
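+        # Note: cancel RPCs return google.protobuf.Empty, which the generated
+        # client surfaces as None (asserted further below).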
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_data_labeling_job_from_dict(): + test_cancel_data_labeling_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelDataLabelingJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_data_labeling_job_async_from_dict(): + await test_cancel_data_labeling_job_async(request_type=dict) + + +def test_cancel_data_labeling_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: + call.return_value = None + + client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_cancel_data_labeling_job_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.cancel_data_labeling_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = None
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.cancel_data_labeling_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_cancel_data_labeling_job_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.cancel_data_labeling_job(
+            job_service.CancelDataLabelingJobRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_cancel_data_labeling_job_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.cancel_data_labeling_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.cancel_data_labeling_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_cancel_data_labeling_job_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.cancel_data_labeling_job(
+            job_service.CancelDataLabelingJobRequest(), name="name_value",
+        )
+
+
+def test_create_hyperparameter_tuning_job(
+    transport: str = "grpc",
+    request_type=job_service.CreateHyperparameterTuningJobRequest,
+):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.create_hyperparameter_tuning_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
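+        # Note: gca_hyperparameter_tuning_job appears to be the same types
+        # module imported under a second alias for create/update responses.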
+ call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_hyperparameter_tuning_job_from_dict(): + test_create_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateHyperparameterTuningJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_async_from_dict(): + await test_create_hyperparameter_tuning_job_async(request_type=dict) + + +def test_create_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateHyperparameterTuningJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object( + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() + + client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateHyperparameterTuningJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob() + ) + + await client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_hyperparameter_tuning_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_hyperparameter_tuning_job( + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[ + 0 + ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ) + + +def test_create_hyperparameter_tuning_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_hyperparameter_tuning_job( + job_service.CreateHyperparameterTuningJobRequest(), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), + ) + + +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(
+        type(client.transport.create_hyperparameter_tuning_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_hyperparameter_tuning_job.HyperparameterTuningJob()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.create_hyperparameter_tuning_job(
+            parent="parent_value",
+            hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(
+                name="name_value"
+            ),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+        assert args[
+            0
+        ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(
+            name="name_value"
+        )
+
+
+@pytest.mark.asyncio
+async def test_create_hyperparameter_tuning_job_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.create_hyperparameter_tuning_job(
+            job_service.CreateHyperparameterTuningJobRequest(),
+            parent="parent_value",
+            hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(
+                name="name_value"
+            ),
+        )
+
+
+def test_get_hyperparameter_tuning_job(
+    transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest
+):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_hyperparameter_tuning_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob(
+            name="name_value",
+            display_name="display_name_value",
+            max_trial_count=1609,
+            parallel_trial_count=2128,
+            max_failed_trial_count=2317,
+            state=job_state.JobState.JOB_STATE_QUEUED,
+        )
+
+        response = client.get_hyperparameter_tuning_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.GetHyperparameterTuningJobRequest()
+
+    # Establish that the response is the type that we expect.
+ + assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_hyperparameter_tuning_job_from_dict(): + test_get_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.GetHyperparameterTuningJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_get_hyperparameter_tuning_job_async_from_dict(): + await test_get_hyperparameter_tuning_job_async(request_type=dict) + + +def test_get_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetHyperparameterTuningJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() + + client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_get_hyperparameter_tuning_job_field_headers_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = job_service.GetHyperparameterTuningJobRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_hyperparameter_tuning_job), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            hyperparameter_tuning_job.HyperparameterTuningJob()
+        )
+
+        await client.get_hyperparameter_tuning_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_get_hyperparameter_tuning_job_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_hyperparameter_tuning_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.get_hyperparameter_tuning_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_get_hyperparameter_tuning_job_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_hyperparameter_tuning_job(
+            job_service.GetHyperparameterTuningJobRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_hyperparameter_tuning_job_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_hyperparameter_tuning_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            hyperparameter_tuning_job.HyperparameterTuningJob()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_hyperparameter_tuning_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_hyperparameter_tuning_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_hyperparameter_tuning_job( + job_service.GetHyperparameterTuningJobRequest(), name="name_value", + ) + + +def test_list_hyperparameter_tuning_jobs( + transport: str = "grpc", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListHyperparameterTuningJobsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListHyperparameterTuningJobsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_hyperparameter_tuning_jobs_from_dict(): + test_list_hyperparameter_tuning_jobs(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_hyperparameter_tuning_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListHyperparameterTuningJobsRequest() + + # Establish that the response is the type that we expect. 
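+    # Note: the async variant returns an AsyncPager, consumed with
+    # `async for` instead of a plain loop.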
+ assert isinstance(response, pagers.ListHyperparameterTuningJobsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_hyperparameter_tuning_jobs_async_from_dict(): + await test_list_hyperparameter_tuning_jobs_async(request_type=dict) + + +def test_list_hyperparameter_tuning_jobs_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListHyperparameterTuningJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + call.return_value = job_service.ListHyperparameterTuningJobsResponse() + + client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_hyperparameter_tuning_jobs_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListHyperparameterTuningJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse() + ) + + await client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_hyperparameter_tuning_jobs_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListHyperparameterTuningJobsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_hyperparameter_tuning_jobs(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_hyperparameter_tuning_jobs_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.list_hyperparameter_tuning_jobs(
+            job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_hyperparameter_tuning_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            job_service.ListHyperparameterTuningJobsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_hyperparameter_tuning_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_hyperparameter_tuning_jobs(
+            job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value",
+        )
+
+
+def test_list_hyperparameter_tuning_jobs_pager():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_hyperparameter_tuning_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[], next_page_token="def",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+                next_page_token="ghi",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        # Expected routing metadata for the (empty) parent field.
+        metadata = (gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),)
+        pager = client.list_hyperparameter_tuning_jobs(request={})
+
+        assert pager._metadata == metadata
+
+        results = list(pager)
+        assert len(results) == 6
+        assert all(
+            isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob)
+            for i in results
+        )
+
+
+def test_list_hyperparameter_tuning_jobs_pages():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_hyperparameter_tuning_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[], next_page_token="def",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+                next_page_token="ghi",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_async_pager():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_hyperparameter_tuning_jobs),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[], next_page_token="def",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+                next_page_token="ghi",
+            ),
+            job_service.ListHyperparameterTuningJobsResponse(
+                hyperparameter_tuning_jobs=[
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                    hyperparameter_tuning_job.HyperparameterTuningJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_hyperparameter_tuning_jobs(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(
+            isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob)
+            for i in responses
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_async_pages():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_hyperparameter_tuning_jobs),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+ call.side_effect = ( + job_service.ListHyperparameterTuningJobsResponse( + hyperparameter_tuning_jobs=[ + hyperparameter_tuning_job.HyperparameterTuningJob(), + hyperparameter_tuning_job.HyperparameterTuningJob(), + hyperparameter_tuning_job.HyperparameterTuningJob(), + ], + next_page_token="abc", + ), + job_service.ListHyperparameterTuningJobsResponse( + hyperparameter_tuning_jobs=[], next_page_token="def", + ), + job_service.ListHyperparameterTuningJobsResponse( + hyperparameter_tuning_jobs=[ + hyperparameter_tuning_job.HyperparameterTuningJob(), + ], + next_page_token="ghi", + ), + job_service.ListHyperparameterTuningJobsResponse( + hyperparameter_tuning_jobs=[ + hyperparameter_tuning_job.HyperparameterTuningJob(), + hyperparameter_tuning_job.HyperparameterTuningJob(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in ( + await client.list_hyperparameter_tuning_jobs(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.DeleteHyperparameterTuningJobRequest, +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_hyperparameter_tuning_job_from_dict(): + test_delete_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteHyperparameterTuningJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_async_from_dict(): + await test_delete_hyperparameter_tuning_job_async(request_type=dict) + + +def test_delete_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteHyperparameterTuningJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteHyperparameterTuningJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_hyperparameter_tuning_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_hyperparameter_tuning_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_hyperparameter_tuning_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
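The flattened_error tests exercise a guard that every generated method shares: flattened keyword arguments and a full request object are mutually exclusive, because merging the two would be ambiguous. A simplified sketch of that guard; delete_job is illustrative, and the real generated code additionally copies flattened values into the request:

def delete_job(request=None, *, name=None):
    # Accepting both would make the merge semantics ambiguous, so the
    # generated surface rejects the combination outright.
    if request is not None and name is not None:
        raise ValueError(
            "If the `request` argument is set, then none of "
            "the individual field arguments should be set."
        )
    return request if request is not None else {"name": name}

assert delete_job(name="name_value") == {"name": "name_value"}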
+ with pytest.raises(ValueError): + client.delete_hyperparameter_tuning_job( + job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_hyperparameter_tuning_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_hyperparameter_tuning_job( + job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", + ) + + +def test_cancel_hyperparameter_tuning_job( + transport: str = "grpc", + request_type=job_service.CancelHyperparameterTuningJobRequest, +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_hyperparameter_tuning_job_from_dict(): + test_cancel_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelHyperparameterTuningJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
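The async variants wrap every stubbed value in grpc_helpers_async.FakeUnaryUnaryCall because the async client awaits the transport callable; a bare return value would fail with "object is not awaitable". A toy stand-in showing the shape such a fake needs (this is not the api_core implementation):

import asyncio
from unittest import mock


class FakeCall:
    """Minimal awaitable wrapper, mimicking a grpc.aio unary-unary call."""

    def __init__(self, response):
        self._response = response

    def __await__(self):
        async def _resolve():
            return self._response

        return _resolve().__await__()


async def main():
    stub = mock.Mock()
    stub.return_value = FakeCall(None)  # like FakeUnaryUnaryCall(None) above
    assert await stub("request") is None


asyncio.run(main())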
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_async_from_dict(): + await test_cancel_hyperparameter_tuning_job_async(request_type=dict) + + +def test_cancel_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelHyperparameterTuningJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = None + + client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelHyperparameterTuningJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_cancel_hyperparameter_tuning_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.cancel_hyperparameter_tuning_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_cancel_hyperparameter_tuning_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.cancel_hyperparameter_tuning_job( + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.cancel_hyperparameter_tuning_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.cancel_hyperparameter_tuning_job( + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", + ) + + +def test_create_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. 
+ + assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.model == "model_value" + + assert response.generate_explanation is True + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_batch_prediction_job_from_dict(): + test_create_batch_prediction_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateBatchPredictionJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.model == "model_value" + + assert response.generate_explanation is True + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_create_batch_prediction_job_async_from_dict(): + await test_create_batch_prediction_job_async(request_type=dict) + + +def test_create_batch_prediction_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateBatchPredictionJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + call.return_value = gca_batch_prediction_job.BatchPredictionJob() + + client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
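The field-header tests that follow pin down the x-goog-request-params metadata the client derives from URI-bound request fields. With the api_core version this library pins, the helper the generated clients use builds that header as below; slashes are left unescaped, which is why the tests can assert name=name/value verbatim:

from google.api_core import gapic_v1

key, value = gapic_v1.routing_header.to_grpc_metadata((("name", "name/value"),))
assert key == "x-goog-request-params"
assert value == "name=name/value"  # "/" is kept literal in routing params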
+ request = job_service.CreateBatchPredictionJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) + + await client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_batch_prediction_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_batch_prediction_job.BatchPredictionJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_batch_prediction_job( + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) + + +def test_create_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_batch_prediction_job( + job_service.CreateBatchPredictionJobRequest(), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), + ) + + +@pytest.mark.asyncio +async def test_create_batch_prediction_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_batch_prediction_job.BatchPredictionJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_batch_prediction_job( + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
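Two proto-plus behaviors do quiet work in these tests: messages compare by field values, which is why a captured request field can be asserted against a freshly constructed BatchPredictionJob, and message constructors accept a mapping, which is what the *_from_dict variants rely on. Both in miniature, with a toy message (assumes the proto-plus package this library already depends on):

import proto


class Job(proto.Message):
    name = proto.Field(proto.STRING, number=1)


assert Job(name="name_value") == Job(name="name_value")  # structural equality
assert Job({"name": "name_value"}) == Job(name="name_value")  # mapping ctor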
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) + + +@pytest.mark.asyncio +async def test_create_batch_prediction_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_batch_prediction_job( + job_service.CreateBatchPredictionJobRequest(), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), + ) + + +def test_get_batch_prediction_job( + transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, batch_prediction_job.BatchPredictionJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.model == "model_value" + + assert response.generate_explanation is True + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_batch_prediction_job_from_dict(): + test_get_batch_prediction_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.GetBatchPredictionJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + generate_explanation=True, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, batch_prediction_job.BatchPredictionJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.model == "model_value" + + assert response.generate_explanation is True + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_get_batch_prediction_job_async_from_dict(): + await test_get_batch_prediction_job_async(request_type=dict) + + +def test_get_batch_prediction_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + call.return_value = batch_prediction_job.BatchPredictionJob() + + client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) + + await client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_batch_prediction_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = batch_prediction_job.BatchPredictionJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_batch_prediction_job( + job_service.GetBatchPredictionJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_batch_prediction_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = batch_prediction_job.BatchPredictionJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_batch_prediction_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_batch_prediction_job( + job_service.GetBatchPredictionJobRequest(), name="name_value", + ) + + +def test_list_batch_prediction_jobs( + transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListBatchPredictionJobsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListBatchPredictionJobsRequest() + + # Establish that the response is the type that we expect. 
+ + assert isinstance(response, pagers.ListBatchPredictionJobsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_batch_prediction_jobs_from_dict(): + test_list_batch_prediction_jobs(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_batch_prediction_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListBatchPredictionJobsRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListBatchPredictionJobsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListBatchPredictionJobsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_batch_prediction_jobs_async_from_dict(): + await test_list_batch_prediction_jobs_async(request_type=dict) + + +def test_list_batch_prediction_jobs_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListBatchPredictionJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + call.return_value = job_service.ListBatchPredictionJobsResponse() + + client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_batch_prediction_jobs_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListBatchPredictionJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse() + ) + + await client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+        # Establish that the field header was sent.
+        _, _, kw = call.mock_calls[0]
+        assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_batch_prediction_jobs_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListBatchPredictionJobsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_batch_prediction_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_batch_prediction_jobs_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_batch_prediction_jobs(
+            job_service.ListBatchPredictionJobsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListBatchPredictionJobsResponse()
+
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            job_service.ListBatchPredictionJobsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_batch_prediction_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_batch_prediction_jobs(
+            job_service.ListBatchPredictionJobsRequest(), parent="parent_value",
+        )
+
+
+def test_list_batch_prediction_jobs_pager():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[], next_page_token="def",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_batch_prediction_jobs(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(
+            isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results
+        )
+
+
+def test_list_batch_prediction_jobs_pages():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[], next_page_token="def",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_batch_prediction_jobs(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_async_pager():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
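The pager tests keep exercising one split: the pages property yields whole responses, while iterating the pager itself flattens items across pages. A heavily simplified toy model of that behavior from the generated pagers module, with dict pages standing in for the proto responses:

class ToyPager:
    def __init__(self, fetch, first_response):
        self._fetch = fetch
        self._response = first_response

    @property
    def pages(self):
        # Yield whole responses, fetching the next page while a token remains.
        while True:
            yield self._response
            if not self._response["next_page_token"]:
                return
            self._response = self._fetch(self._response["next_page_token"])

    def __iter__(self):
        # Iterating the pager flattens items across all pages.
        for page in self.pages:
            yield from page["batch_prediction_jobs"]


responses = iter(
    [
        {"batch_prediction_jobs": ["job1", "job2"], "next_page_token": "abc"},
        {"batch_prediction_jobs": ["job3"], "next_page_token": ""},
    ]
)
pager = ToyPager(lambda token: next(responses), next(responses))
assert list(pager) == ["job1", "job2", "job3"]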
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[], next_page_token="def",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_batch_prediction_jobs(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(
+            isinstance(i, batch_prediction_job.BatchPredictionJob) for i in responses
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_async_pages():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[], next_page_token="def",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_batch_prediction_jobs(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_delete_batch_prediction_job(
+    transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest
+):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_batch_prediction_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.delete_batch_prediction_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.DeleteBatchPredictionJobRequest()
+
+        # Establish that the response is the type that we expect.
+ assert isinstance(response, future.Future) + + +def test_delete_batch_prediction_job_from_dict(): + test_delete_batch_prediction_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteBatchPredictionJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_async_from_dict(): + await test_delete_batch_prediction_job_async(request_type=dict) + + +def test_delete_batch_prediction_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
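The future.Future assertions reflect that delete_* RPCs are long-running operations: the transport returns a raw operations_pb2.Operation and the client wraps it in an operation future whose result() blocks until the server finishes. The calling pattern this enables, sketched with a plain concurrent.futures future as a stand-in for the operation future:

import concurrent.futures

with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
    # Stand-in for e.g. client.delete_batch_prediction_job(name=...),
    # which returns immediately while the server keeps working.
    lro = executor.submit(lambda: "operations/spam: done")
    assert lro.result(timeout=10) == "operations/spam: done"  # blocks until done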
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_batch_prediction_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_batch_prediction_job( + job_service.DeleteBatchPredictionJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_batch_prediction_job( + job_service.DeleteBatchPredictionJobRequest(), name="name_value", + ) + + +def test_cancel_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_batch_prediction_job_from_dict(): + test_cancel_batch_prediction_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelBatchPredictionJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_async_from_dict(): + await test_cancel_batch_prediction_job_async(request_type=dict) + + +def test_cancel_batch_prediction_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + call.return_value = None + + client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
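The "assert response is None" checks are the visible end of a mapping rule: Cancel* RPCs return google.protobuf.Empty, and the generated surface discards that message rather than returning it. The rule in miniature; cancel_job is an illustrative stand-in:

from google.protobuf import empty_pb2


def cancel_job(stub_response):
    # The generated client swallows Empty responses and returns None.
    assert isinstance(stub_response, empty_pb2.Empty)
    return None


assert cancel_job(empty_pb2.Empty()) is None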
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_cancel_batch_prediction_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.cancel_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_cancel_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_batch_prediction_job( + job_service.CancelBatchPredictionJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.cancel_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.cancel_batch_prediction_job( + job_service.CancelBatchPredictionJobRequest(), name="name_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = JobServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. 
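test_credentials_transport_error pins down a constructor rule: a ready-made transport instance already carries its own credentials and scopes, so credentials, a credentials_file, or scopes passed alongside one are rejected. A simplified sketch of that guard; init_client is illustrative and the generated __init__ does more:

def init_client(credentials=None, transport=None, client_options=None):
    opts = client_options or {}
    if transport is not None:
        if credentials is not None or opts.get("credentials_file"):
            raise ValueError(
                "When providing a transport instance, "
                "provide its credentials directly."
            )
        if opts.get("scopes"):
            raise ValueError(
                "When providing a transport instance, "
                "provide its scopes directly."
            )
    return transport


try:
    init_client(credentials=object(), transport=object())
except ValueError:
    pass  # expected, mirroring the pytest.raises blocks above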
+ transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = JobServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = JobServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.JobServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.JobServiceGrpcTransport,) + + +def test_job_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.JobServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_job_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.JobServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
+ methods = ( + "create_custom_job", + "get_custom_job", + "list_custom_jobs", + "delete_custom_job", + "cancel_custom_job", + "create_data_labeling_job", + "get_data_labeling_job", + "list_data_labeling_jobs", + "delete_data_labeling_job", + "cancel_data_labeling_job", + "create_hyperparameter_tuning_job", + "get_hyperparameter_tuning_job", + "list_hyperparameter_tuning_jobs", + "delete_hyperparameter_tuning_job", + "cancel_hyperparameter_tuning_job", + "create_batch_prediction_job", + "get_batch_prediction_job", + "list_batch_prediction_jobs", + "delete_batch_prediction_job", + "cancel_batch_prediction_job", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_job_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.JobServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_job_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.JobServiceTransport() + adc.assert_called_once() + + +def test_job_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + JobServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_job_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
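+    # google.auth.default() performs the Application Default Credentials
+    # lookup; patching it lets the test verify the fallback (and that the
+    # cloud-platform scope is forwarded) without needing a real environment,
+    # e.g. a GOOGLE_APPLICATION_CREDENTIALS key file.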
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transports.JobServiceGrpcTransport(
+            host="squid.clam.whelk", quota_project_id="octopus"
+        )
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            quota_project_id="octopus",
+        )
+
+
+def test_job_service_host_no_port():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_job_service_host_with_port():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_job_service_grpc_transport_channel():
+    channel = grpc.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.JobServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_job_service_grpc_asyncio_transport_channel():
+    channel = aio.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.JobServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport],
+)
+def test_job_service_transport_channel_mtls_with_client_cert_source(transport_class):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport],
+)
+def test_job_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+
        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_job_service_grpc_lro_client():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_job_service_grpc_lro_async_client():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_batch_prediction_job_path():
+    project = "squid"
+    location = "clam"
+    batch_prediction_job = "whelk"
+
+    expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(
+        project=project, location=location, batch_prediction_job=batch_prediction_job,
+    )
+    actual = JobServiceClient.batch_prediction_job_path(
+        project, location, batch_prediction_job
+    )
+    assert expected == actual
+
+
+def test_parse_batch_prediction_job_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "batch_prediction_job": "nudibranch",
+    }
+    path = JobServiceClient.batch_prediction_job_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = JobServiceClient.parse_batch_prediction_job_path(path)
+    assert expected == actual
+
+
+def test_custom_job_path():
+    project = "cuttlefish"
+    location = "mussel"
+    custom_job = "winkle"
+
+    expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(
+        project=project, location=location, custom_job=custom_job,
+    )
+    actual = JobServiceClient.custom_job_path(project, location, custom_job)
+    assert expected == actual
+
+
+def test_parse_custom_job_path():
+    expected = {
+        "project": "nautilus",
+        "location": "scallop",
+        "custom_job": "abalone",
+    }
+    path = JobServiceClient.custom_job_path(**expected)
+
+    # Check that the path construction is reversible.
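+    # For these values the helper yields the literal path below (format taken
+    # from test_custom_job_path above), which the parse call should invert.
+    assert path == "projects/nautilus/locations/scallop/customJobs/abalone"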
+ actual = JobServiceClient.parse_custom_job_path(path) + assert expected == actual + + +def test_data_labeling_job_path(): + project = "squid" + location = "clam" + data_labeling_job = "whelk" + + expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, location=location, data_labeling_job=data_labeling_job, + ) + actual = JobServiceClient.data_labeling_job_path( + project, location, data_labeling_job + ) + assert expected == actual + + +def test_parse_data_labeling_job_path(): + expected = { + "project": "octopus", + "location": "oyster", + "data_labeling_job": "nudibranch", + } + path = JobServiceClient.data_labeling_job_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_data_labeling_job_path(path) + assert expected == actual + + +def test_dataset_path(): + project = "cuttlefish" + location = "mussel" + dataset = "winkle" + + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) + actual = JobServiceClient.dataset_path(project, location, dataset) + assert expected == actual + + +def test_parse_dataset_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", + } + path = JobServiceClient.dataset_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_dataset_path(path) + assert expected == actual + + +def test_hyperparameter_tuning_job_path(): + project = "squid" + location = "clam" + hyperparameter_tuning_job = "whelk" + + expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) + actual = JobServiceClient.hyperparameter_tuning_job_path( + project, location, hyperparameter_tuning_job + ) + assert expected == actual + + +def test_parse_hyperparameter_tuning_job_path(): + expected = { + "project": "octopus", + "location": "oyster", + "hyperparameter_tuning_job": "nudibranch", + } + path = JobServiceClient.hyperparameter_tuning_job_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_hyperparameter_tuning_job_path(path) + assert expected == actual + + +def test_model_path(): + project = "cuttlefish" + location = "mussel" + model = "winkle" + + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) + actual = JobServiceClient.model_path(project, location, model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } + path = JobServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_model_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "squid" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = JobServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + } + path = JobServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. 
+ actual = JobServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "whelk" + + expected = "folders/{folder}".format(folder=folder,) + actual = JobServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + } + path = JobServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "oyster" + + expected = "organizations/{organization}".format(organization=organization,) + actual = JobServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + } + path = JobServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project,) + actual = JobServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + } + path = JobServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = JobServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + } + path = JobServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.JobServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.JobServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = JobServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py new file mode 100644 index 0000000000..85e6a2d362 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -0,0 +1,1721 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.migration_service import ( + MigrationServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.migration_service import ( + MigrationServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.migration_service import pagers +from google.cloud.aiplatform_v1beta1.services.migration_service import transports +from google.cloud.aiplatform_v1beta1.types import migratable_resource +from google.cloud.aiplatform_v1beta1.types import migration_service +from google.longrunning import operations_pb2 +from google.oauth2 import service_account + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
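+# For example, "aiplatform.googleapis.com" is returned unchanged, while a
+# localhost default would be swapped for the stand-in "foo.googleapis.com" so
+# that the derived mTLS endpoint is distinct from the regular endpoint.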
+def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert MigrationServiceClient._get_default_mtls_endpoint(None) is None + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient] +) +def test_migration_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_migration_service_client_get_transport_class(): + transport = MigrationServiceClient.get_transport_class() + assert transport == transports.MigrationServiceGrpcTransport + + transport = MigrationServiceClient.get_transport_class("grpc") + assert transport == transports.MigrationServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) +def test_migration_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
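+    # An explicit api_endpoint in ClientOptions takes precedence over the
+    # default; the placeholder host "squid.clam.whelk" is only used to confirm
+    # the value reaches the transport unchanged.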
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
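+    # The supported values are "true" and "false" (exercised by the
+    # parametrized mTLS tests below); any other string should make client
+    # construction fail with ValueError.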
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "true", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "false", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_migration_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
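+    # In this branch the certificate comes from ADC instead of an explicit
+    # client_cert_source: SslCredentials.is_mtls reports whether ADC carries
+    # a client certificate, and its ssl_credentials property supplies the
+    # channel credentials when it does.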
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_migration_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = MigrationServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_search_migratable_resources( + transport: str = "grpc", + request_type=migration_service.SearchMigratableResourcesRequest, +): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = migration_service.SearchMigratableResourcesResponse( + next_page_token="next_page_token_value", + ) + + response = client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.SearchMigratableResourcesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.SearchMigratableResourcesPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_search_migratable_resources_from_dict(): + test_search_migratable_resources(request_type=dict) + + +@pytest.mark.asyncio +async def test_search_migratable_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.SearchMigratableResourcesRequest, +): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.SearchMigratableResourcesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_search_migratable_resources_async_from_dict(): + await test_search_migratable_resources_async(request_type=dict) + + +def test_search_migratable_resources_field_headers(): + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = migration_service.SearchMigratableResourcesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + call.return_value = migration_service.SearchMigratableResourcesResponse() + + client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_search_migratable_resources_field_headers_async(): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = migration_service.SearchMigratableResourcesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) + + await client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_search_migratable_resources_flattened(): + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = migration_service.SearchMigratableResourcesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.search_migratable_resources(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
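+        # The flattened keyword is coerced into a full request proto, so the
+        # checks below inspect args[0] (a SearchMigratableResourcesRequest)
+        # rather than the keyword arguments themselves.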
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_search_migratable_resources_flattened_error():
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.search_migratable_resources(
+            migration_service.SearchMigratableResourcesRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_flattened_async():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = migration_service.SearchMigratableResourcesResponse()
+
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            migration_service.SearchMigratableResourcesResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.search_migratable_resources(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_flattened_error_async():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.search_migratable_resources(
+            migration_service.SearchMigratableResourcesRequest(), parent="parent_value",
+        )
+
+
+def test_search_migratable_resources_pager():
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
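+        # The pager keeps fetching while next_page_token is non-empty, so the
+        # pages of 3, 0, 1 and 2 resources below should flatten into a single
+        # stream of 6 MigratableResource items.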
+        call.side_effect = (
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token="abc",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[], next_page_token="def",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.search_migratable_resources(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(
+            isinstance(i, migratable_resource.MigratableResource) for i in results
+        )
+
+
+def test_search_migratable_resources_pages():
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token="abc",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[], next_page_token="def",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.search_migratable_resources(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_async_pager():
+    client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.search_migratable_resources),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token="abc",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[], next_page_token="def",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.search_migratable_resources(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(
+            isinstance(i, migratable_resource.MigratableResource) for i in responses
+        )
+
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_async_pages():
+    client = MigrationServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.search_migratable_resources),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token="abc",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[], next_page_token="def",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.search_migratable_resources(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_batch_migrate_resources(
+    transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest
+):
+    client = MigrationServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.batch_migrate_resources), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.batch_migrate_resources(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == migration_service.BatchMigrateResourcesRequest()
+
+    # Establish that the response is the type that we expect.
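+    # BatchMigrateResources is a long-running operation: the transport returns
+    # a raw operations_pb2.Operation, and the client wraps it in an api-core
+    # future that callers can poll or block on (e.g. response.result()).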
+ assert isinstance(response, future.Future) + + +def test_batch_migrate_resources_from_dict(): + test_batch_migrate_resources(request_type=dict) + + +@pytest.mark.asyncio +async def test_batch_migrate_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.BatchMigrateResourcesRequest, +): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.BatchMigrateResourcesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_batch_migrate_resources_async_from_dict(): + await test_batch_migrate_resources_async(request_type=dict) + + +def test_batch_migrate_resources_field_headers(): + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = migration_service.BatchMigrateResourcesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_batch_migrate_resources_field_headers_async(): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = migration_service.BatchMigrateResourcesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
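+    # The routing header mirrors the request's URI fields; the client must
+    # attach ("x-goog-request-params", "parent=parent/value") to the call
+    # metadata so the backend can route on the parent resource.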
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_batch_migrate_resources_flattened(): + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.batch_migrate_resources( + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] + + +def test_batch_migrate_resources_flattened_error(): + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.batch_migrate_resources( + migration_service.BatchMigrateResourcesRequest(), + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], + ) + + +@pytest.mark.asyncio +async def test_batch_migrate_resources_flattened_async(): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.batch_migrate_resources( + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] + + +@pytest.mark.asyncio +async def test_batch_migrate_resources_flattened_error_async(): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.batch_migrate_resources( + migration_service.BatchMigrateResourcesRequest(), + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = MigrationServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = MigrationServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = MigrationServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.MigrationServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. 
+ client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.MigrationServiceGrpcTransport,) + + +def test_migration_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.MigrationServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_migration_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.MigrationServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "search_migratable_resources", + "batch_migrate_resources", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_migration_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.MigrationServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_migration_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.MigrationServiceTransport() + adc.assert_called_once() + + +def test_migration_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + MigrationServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_migration_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transports.MigrationServiceGrpcTransport(
+            host="squid.clam.whelk", quota_project_id="octopus"
+        )
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            quota_project_id="octopus",
+        )
+
+
+def test_migration_service_host_no_port():
+    client = MigrationServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_migration_service_host_with_port():
+    client = MigrationServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_migration_service_grpc_transport_channel():
+    channel = grpc.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.MigrationServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_migration_service_grpc_asyncio_transport_channel():
+    channel = aio.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.MigrationServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.MigrationServiceGrpcTransport,
+        transports.MigrationServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_migration_service_transport_channel_mtls_with_client_cert_source(
+    transport_class,
+):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.MigrationServiceGrpcTransport,
+        transports.MigrationServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_migration_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
"google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_migration_service_grpc_lro_client(): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_migration_service_grpc_lro_async_client(): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_annotated_dataset_path(): + project = "squid" + dataset = "clam" + annotated_dataset = "whelk" + + expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format( + project=project, dataset=dataset, annotated_dataset=annotated_dataset, + ) + actual = MigrationServiceClient.annotated_dataset_path( + project, dataset, annotated_dataset + ) + assert expected == actual + + +def test_parse_annotated_dataset_path(): + expected = { + "project": "octopus", + "dataset": "oyster", + "annotated_dataset": "nudibranch", + } + path = MigrationServiceClient.annotated_dataset_path(**expected) + + # Check that the path construction is reversible. + actual = MigrationServiceClient.parse_annotated_dataset_path(path) + assert expected == actual + + +def test_dataset_path(): + project = "cuttlefish" + location = "mussel" + dataset = "winkle" + + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) + actual = MigrationServiceClient.dataset_path(project, location, dataset) + assert expected == actual + + +def test_parse_dataset_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "dataset": "abalone", + } + path = MigrationServiceClient.dataset_path(**expected) + + # Check that the path construction is reversible. 
+    actual = MigrationServiceClient.parse_dataset_path(path)
+    assert expected == actual
+
+
+# NOTE: this definition (and test_parse_dataset_path below) shadows the pair
+# above; pytest only collects the last function bound to a given name. The
+# duplication mirrors the duplicated dataset_path helpers emitted on the
+# generated client, where the last definition likewise wins.
+def test_dataset_path():
+    project = "squid"
+    dataset = "clam"
+
+    expected = "projects/{project}/datasets/{dataset}".format(
+        project=project, dataset=dataset,
+    )
+    actual = MigrationServiceClient.dataset_path(project, dataset)
+    assert expected == actual
+
+
+def test_parse_dataset_path():
+    expected = {
+        "project": "whelk",
+        "dataset": "octopus",
+    }
+    path = MigrationServiceClient.dataset_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_dataset_path(path)
+    assert expected == actual
+
+
+# NOTE: shadows the pair above for the same reason; only this final
+# dataset_path pair is actually collected and run.
+def test_dataset_path():
+    project = "oyster"
+    location = "nudibranch"
+    dataset = "cuttlefish"
+
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
+        project=project, location=location, dataset=dataset,
+    )
+    actual = MigrationServiceClient.dataset_path(project, location, dataset)
+    assert expected == actual
+
+
+def test_parse_dataset_path():
+    expected = {
+        "project": "mussel",
+        "location": "winkle",
+        "dataset": "nautilus",
+    }
+    path = MigrationServiceClient.dataset_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_dataset_path(path)
+    assert expected == actual
+
+
+def test_model_path():
+    project = "scallop"
+    location = "abalone"
+    model = "squid"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
+    actual = MigrationServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "clam",
+        "location": "whelk",
+        "model": "octopus",
+    }
+    path = MigrationServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_model_path(path)
+    assert expected == actual
+
+
+# NOTE: shadows the model_path pair above; only these definitions run.
+def test_model_path():
+    project = "oyster"
+    location = "nudibranch"
+    model = "cuttlefish"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
+    actual = MigrationServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "mussel",
+        "location": "winkle",
+        "model": "nautilus",
+    }
+    path = MigrationServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_model_path(path)
+    assert expected == actual
+
+
+def test_version_path():
+    project = "scallop"
+    model = "abalone"
+    version = "squid"
+
+    expected = "projects/{project}/models/{model}/versions/{version}".format(
+        project=project, model=model, version=version,
+    )
+    actual = MigrationServiceClient.version_path(project, model, version)
+    assert expected == actual
+
+
+def test_parse_version_path():
+    expected = {
+        "project": "clam",
+        "model": "whelk",
+        "version": "octopus",
+    }
+    path = MigrationServiceClient.version_path(**expected)
+
+    # Check that the path construction is reversible.
+ actual = MigrationServiceClient.parse_version_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "oyster" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = MigrationServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nudibranch", + } + path = MigrationServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = MigrationServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "cuttlefish" + + expected = "folders/{folder}".format(folder=folder,) + actual = MigrationServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "mussel", + } + path = MigrationServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = MigrationServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "winkle" + + expected = "organizations/{organization}".format(organization=organization,) + actual = MigrationServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nautilus", + } + path = MigrationServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = MigrationServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "scallop" + + expected = "projects/{project}".format(project=project,) + actual = MigrationServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "abalone", + } + path = MigrationServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = MigrationServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "squid" + location = "clam" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = MigrationServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "whelk", + "location": "octopus", + } + path = MigrationServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
+    actual = MigrationServiceClient.parse_common_location_path(path)
+    assert expected == actual
+
+
+def test_client_with_default_client_info():
+    client_info = gapic_v1.client_info.ClientInfo()
+
+    with mock.patch.object(
+        transports.MigrationServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
+        client = MigrationServiceClient(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
+
+    with mock.patch.object(
+        transports.MigrationServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
+        transport_class = MigrationServiceClient.get_transport_class()
+        transport = transport_class(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
new file mode 100644
index 0000000000..d05698a46a
--- /dev/null
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
@@ -0,0 +1,3676 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import mock
+
+import grpc
+from grpc.experimental import aio
+import math
+import pytest
+from proto.marshal.rules.dates import DurationRule, TimestampRule
+
+from google import auth
+from google.api_core import client_options
+from google.api_core import exceptions
+from google.api_core import future
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers
+from google.api_core import grpc_helpers_async
+from google.api_core import operation_async  # type: ignore
+from google.api_core import operations_v1
+from google.auth import credentials
+from google.auth.exceptions import MutualTLSChannelError
+from google.cloud.aiplatform_v1beta1.services.model_service import (
+    ModelServiceAsyncClient,
+)
+from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient
+from google.cloud.aiplatform_v1beta1.services.model_service import pagers
+from google.cloud.aiplatform_v1beta1.services.model_service import transports
+from google.cloud.aiplatform_v1beta1.types import deployed_model_ref
+from google.cloud.aiplatform_v1beta1.types import env_var
+from google.cloud.aiplatform_v1beta1.types import explanation
+from google.cloud.aiplatform_v1beta1.types import explanation_metadata
+from google.cloud.aiplatform_v1beta1.types import io
+from google.cloud.aiplatform_v1beta1.types import model
+from google.cloud.aiplatform_v1beta1.types import model as gca_model
+from google.cloud.aiplatform_v1beta1.types import model_evaluation
+from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice
+from google.cloud.aiplatform_v1beta1.types import model_service
+from google.cloud.aiplatform_v1beta1.types import operation as gca_operation
+from google.longrunning import operations_pb2
+from google.oauth2 import service_account
+from google.protobuf import field_mask_pb2 as field_mask  # type: ignore
+from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert ModelServiceClient._get_default_mtls_endpoint(None) is None + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient]) +def test_model_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_model_service_client_get_transport_class(): + transport = ModelServiceClient.get_transport_class() + assert transport == transports.ModelServiceGrpcTransport + + transport = ModelServiceClient.get_transport_class("grpc") + assert transport == transports.ModelServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) +def test_model_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. 
+ with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
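+    # (Only "true" and "false" are valid values for this variable; anything
+    # else should raise ValueError when the client is constructed.)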
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_model_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
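+    # (Under GOOGLE_API_USE_MTLS_ENDPOINT="auto", the client should only
+    # switch to the mTLS endpoint when a client certificate is available.)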
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_model_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_model_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_model_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_upload_model( + transport: str = "grpc", request_type=model_service.UploadModelRequest +): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.UploadModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_upload_model_from_dict(): + test_upload_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_upload_model_async( + transport: str = "grpc_asyncio", request_type=model_service.UploadModelRequest +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.UploadModelRequest() + + # Establish that the response is the type that we expect. 
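+    # (upload_model is a long-running operation, so the surfaced response is
+    # a future wrapping the operation rather than the uploaded Model itself.)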
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_upload_model_async_from_dict(): + await test_upload_model_async(request_type=dict) + + +def test_upload_model_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.UploadModelRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_upload_model_field_headers_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.UploadModelRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_upload_model_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.upload_model( + parent="parent_value", model=gca_model.Model(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].model == gca_model.Model(name="name_value") + + +def test_upload_model_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.upload_model( + model_service.UploadModelRequest(), + parent="parent_value", + model=gca_model.Model(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_upload_model_flattened_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(type(client.transport.upload_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.upload_model(
+            parent="parent_value", model=gca_model.Model(name="name_value"),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+        assert args[0].model == gca_model.Model(name="name_value")
+
+
+@pytest.mark.asyncio
+async def test_upload_model_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.upload_model(
+            model_service.UploadModelRequest(),
+            parent="parent_value",
+            model=gca_model.Model(name="name_value"),
+        )
+
+
+def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model.Model(
+            name="name_value",
+            display_name="display_name_value",
+            description="description_value",
+            metadata_schema_uri="metadata_schema_uri_value",
+            training_pipeline="training_pipeline_value",
+            artifact_uri="artifact_uri_value",
+            supported_deployment_resources_types=[
+                model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
+            ],
+            supported_input_storage_formats=["supported_input_storage_formats_value"],
+            supported_output_storage_formats=["supported_output_storage_formats_value"],
+            etag="etag_value",
+        )
+
+        response = client.get_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.GetModelRequest()
+
+    # Establish that the response is the type that we expect.
+ + assert isinstance(response, model.Model) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.training_pipeline == "training_pipeline_value" + + assert response.artifact_uri == "artifact_uri_value" + + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] + + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] + + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] + + assert response.etag == "etag_value" + + +def test_get_model_from_dict(): + test_get_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_model_async( + transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=[ + "supported_input_storage_formats_value" + ], + supported_output_storage_formats=[ + "supported_output_storage_formats_value" + ], + etag="etag_value", + ) + ) + + response = await client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, model.Model) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.training_pipeline == "training_pipeline_value" + + assert response.artifact_uri == "artifact_uri_value" + + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] + + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] + + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_get_model_async_from_dict(): + await test_get_model_async(request_type=dict) + + +def test_get_model_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
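+    # (The gapic layer mirrors such request fields into the
+    # x-goog-request-params metadata so the backend can route the call.)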
+    request = model_service.GetModelRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_model), "__call__") as call:
+        call.return_value = model.Model()
+
+        client.get_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_get_model_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.GetModelRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_model), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model())
+
+        await client.get_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_get_model_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model.Model()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.get_model(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_get_model_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_model(
+            model_service.GetModelRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_model_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_model(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_model_flattened_error_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_model( + model_service.GetModelRequest(), name="name_value", + ) + + +def test_list_models( + transport: str = "grpc", request_type=model_service.ListModelsRequest +): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListModelsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_models_from_dict(): + test_list_models(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_models_async( + transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse(next_page_token="next_page_token_value",) + ) + + response = await client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_models_async_from_dict(): + await test_list_models_async(request_type=dict) + + +def test_list_models_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
+        call.return_value = model_service.ListModelsResponse()
+
+        client.list_models(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_list_models_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.ListModelsRequest()
+    request.parent = "parent/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelsResponse()
+        )
+
+        await client.list_models(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_models_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_models(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_models_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_models(
+            model_service.ListModelsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_models_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_models(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_models_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_models(
+            model_service.ListModelsRequest(), parent="parent_value",
+        )
+
+
+def test_list_models_pager():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
+            ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
+            model_service.ListModelsResponse(
+                models=[model.Model(),], next_page_token="ghi",
+            ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_models(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, model.Model) for i in results)
+
+
+def test_list_models_pages():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
+            ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
+            model_service.ListModelsResponse(
+                models=[model.Model(),], next_page_token="ghi",
+            ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
+            RuntimeError,
+        )
+        pages = list(client.list_models(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_models_async_pager():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
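+        # (The trailing RuntimeError below makes the test fail loudly if the
+        # pager ever requests more pages than the fixture provides.)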
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
+            ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
+            model_service.ListModelsResponse(
+                models=[model.Model(),], next_page_token="ghi",
+            ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
+            RuntimeError,
+        )
+        async_pager = await client.list_models(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, model.Model) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_models_async_pages():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
+            ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
+            model_service.ListModelsResponse(
+                models=[model.Model(),], next_page_token="ghi",
+            ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_models(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_update_model(
+    transport: str = "grpc", request_type=model_service.UpdateModelRequest
+):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gca_model.Model(
+            name="name_value",
+            display_name="display_name_value",
+            description="description_value",
+            metadata_schema_uri="metadata_schema_uri_value",
+            training_pipeline="training_pipeline_value",
+            artifact_uri="artifact_uri_value",
+            supported_deployment_resources_types=[
+                gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
+            ],
+            supported_input_storage_formats=["supported_input_storage_formats_value"],
+            supported_output_storage_formats=["supported_output_storage_formats_value"],
+            etag="etag_value",
+        )
+
+        response = client.update_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.UpdateModelRequest()
+
+    # Establish that the response is the type that we expect.
+ + assert isinstance(response, gca_model.Model) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.training_pipeline == "training_pipeline_value" + + assert response.artifact_uri == "artifact_uri_value" + + assert response.supported_deployment_resources_types == [ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] + + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] + + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] + + assert response.etag == "etag_value" + + +def test_update_model_from_dict(): + test_update_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_model_async( + transport: str = "grpc_asyncio", request_type=model_service.UpdateModelRequest +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=[ + "supported_input_storage_formats_value" + ], + supported_output_storage_formats=[ + "supported_output_storage_formats_value" + ], + etag="etag_value", + ) + ) + + response = await client.update_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.UpdateModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_model.Model) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.training_pipeline == "training_pipeline_value" + + assert response.artifact_uri == "artifact_uri_value" + + assert response.supported_deployment_resources_types == [ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] + + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] + + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_update_model_async_from_dict(): + await test_update_model_async(request_type=dict) + + +def test_update_model_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
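+    # The client is expected to echo the routing field into the
+    # "x-goog-request-params" metadata entry asserted at the end of this test.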
+    request = model_service.UpdateModelRequest()
+    request.model.name = "model.name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
+        call.return_value = gca_model.Model()
+
+        client.update_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_update_model_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.UpdateModelRequest()
+    request.model.name = "model.name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model())
+
+        await client.update_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"]
+
+
+def test_update_model_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gca_model.Model()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.update_model(
+            model=gca_model.Model(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].model == gca_model.Model(name="name_value")
+
+        assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"])
+
+
+def test_update_model_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.update_model(
+            model_service.UpdateModelRequest(),
+            model=gca_model.Model(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
+
+
+@pytest.mark.asyncio
+async def test_update_model_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
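+        # (Truthy values keep each flattened field distinguishable from its
+        # proto3 default in the assertions below.)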
+ response = await client.update_model( + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].model == gca_model.Model(name="name_value") + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_model_flattened_error_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_model( + model_service.UpdateModelRequest(), + model=gca_model.Model(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_delete_model( + transport: str = "grpc", request_type=model_service.DeleteModelRequest +): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.DeleteModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_model_from_dict(): + test_delete_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_model_async( + transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.DeleteModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_model_async_from_dict(): + await test_delete_model_async(request_type=dict) + + +def test_delete_model_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+    request = model_service.DeleteModelRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        client.delete_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_delete_model_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.DeleteModelRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/op")
+        )
+
+        await client.delete_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_delete_model_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.delete_model(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_delete_model_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.delete_model(
+            model_service.DeleteModelRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_model_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_model(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_model_flattened_error_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_model( + model_service.DeleteModelRequest(), name="name_value", + ) + + +def test_export_model( + transport: str = "grpc", request_type=model_service.ExportModelRequest +): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.export_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ExportModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_export_model_from_dict(): + test_export_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_export_model_async( + transport: str = "grpc_asyncio", request_type=model_service.ExportModelRequest +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.export_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ExportModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_export_model_async_from_dict(): + await test_export_model_async(request_type=dict) + + +def test_export_model_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ExportModelRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.export_model(request) + + # Establish that the underlying gRPC stub method was called. 
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_export_model_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.ExportModelRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/op")
+        )
+
+        await client.export_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_export_model_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.export_model(
+            name="name_value",
+            output_config=model_service.ExportModelRequest.OutputConfig(
+                export_format_id="export_format_id_value"
+            ),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+        assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(
+            export_format_id="export_format_id_value"
+        )
+
+
+def test_export_model_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.export_model(
+            model_service.ExportModelRequest(),
+            name="name_value",
+            output_config=model_service.ExportModelRequest.OutputConfig(
+                export_format_id="export_format_id_value"
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_export_model_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+ response = await client.export_model( + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ) + + +@pytest.mark.asyncio +async def test_export_model_flattened_error_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.export_model( + model_service.ExportModelRequest(), + name="name_value", + output_config=model_service.ExportModelRequest.OutputConfig( + export_format_id="export_format_id_value" + ), + ) + + +def test_get_model_evaluation( + transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest +): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation.ModelEvaluation( + name="name_value", + metrics_schema_uri="metrics_schema_uri_value", + slice_dimensions=["slice_dimensions_value"], + ) + + response = client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelEvaluationRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, model_evaluation.ModelEvaluation) + + assert response.name == "name_value" + + assert response.metrics_schema_uri == "metrics_schema_uri_value" + + assert response.slice_dimensions == ["slice_dimensions_value"] + + +def test_get_model_evaluation_from_dict(): + test_get_model_evaluation(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_model_evaluation_async( + transport: str = "grpc_asyncio", + request_type=model_service.GetModelEvaluationRequest, +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation.ModelEvaluation( + name="name_value", + metrics_schema_uri="metrics_schema_uri_value", + slice_dimensions=["slice_dimensions_value"], + ) + ) + + response = await client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. 
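+        # (The async variants assert only that the stub was invoked at
+        # least once.)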
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelEvaluationRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, model_evaluation.ModelEvaluation) + + assert response.name == "name_value" + + assert response.metrics_schema_uri == "metrics_schema_uri_value" + + assert response.slice_dimensions == ["slice_dimensions_value"] + + +@pytest.mark.asyncio +async def test_get_model_evaluation_async_from_dict(): + await test_get_model_evaluation_async(request_type=dict) + + +def test_get_model_evaluation_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelEvaluationRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), "__call__" + ) as call: + call.return_value = model_evaluation.ModelEvaluation() + + client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_model_evaluation_field_headers_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelEvaluationRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation.ModelEvaluation() + ) + + await client.get_model_evaluation(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_model_evaluation_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation.ModelEvaluation() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_model_evaluation(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_model_evaluation_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
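+    # (The client raises ValueError rather than guessing how to merge the
+    # request object with the individual field values.)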
+    with pytest.raises(ValueError):
+        client.get_model_evaluation(
+            model_service.GetModelEvaluationRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation.ModelEvaluation()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_model_evaluation(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_model_evaluation(
+            model_service.GetModelEvaluationRequest(), name="name_value",
+        )
+
+
+def test_list_model_evaluations(
+    transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest
+):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelEvaluationsResponse(
+            next_page_token="next_page_token_value",
+        )
+
+        response = client.list_model_evaluations(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.ListModelEvaluationsRequest()
+
+    # Establish that the response is the type that we expect.
+
+    assert isinstance(response, pagers.ListModelEvaluationsPager)
+
+    assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_model_evaluations_from_dict():
+    test_list_model_evaluations(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_async(
+    transport: str = "grpc_asyncio",
+    request_type=model_service.ListModelEvaluationsRequest,
+):
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
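+        # FakeUnaryUnaryCall wraps the response in an awaitable that mimics
+        # a real gRPC async call object.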
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelEvaluationsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelEvaluationsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_model_evaluations_async_from_dict(): + await test_list_model_evaluations_async(request_type=dict) + + +def test_list_model_evaluations_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelEvaluationsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluations), "__call__" + ) as call: + call.return_value = model_service.ListModelEvaluationsResponse() + + client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_model_evaluations_field_headers_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelEvaluationsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluations), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationsResponse() + ) + + await client.list_model_evaluations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_model_evaluations_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluations), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelEvaluationsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_model_evaluations(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_model_evaluations_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_model_evaluations(
+            model_service.ListModelEvaluationsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelEvaluationsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_model_evaluations(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_model_evaluations(
+            model_service.ListModelEvaluationsRequest(), parent="parent_value",
+        )
+
+
+def test_list_model_evaluations_pager():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[model_evaluation.ModelEvaluation(),],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_model_evaluations(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results)
+
+
+def test_list_model_evaluations_pages():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
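+        # (Unlike the pager test above, this exercises the .pages iterator,
+        # which yields whole responses rather than individual resources.)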
+        call.side_effect = (
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[model_evaluation.ModelEvaluation(),],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_model_evaluations(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_async_pager():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[model_evaluation.ModelEvaluation(),],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_model_evaluations(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_async_pages():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+ call.side_effect = ( + model_service.ListModelEvaluationsResponse( + model_evaluations=[ + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + ], + next_page_token="abc", + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[], next_page_token="def", + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[model_evaluation.ModelEvaluation(),], + next_page_token="ghi", + ), + model_service.ListModelEvaluationsResponse( + model_evaluations=[ + model_evaluation.ModelEvaluation(), + model_evaluation.ModelEvaluation(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_model_evaluations(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_get_model_evaluation_slice( + transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest +): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation_slice.ModelEvaluationSlice( + name="name_value", metrics_schema_uri="metrics_schema_uri_value", + ) + + response = client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelEvaluationSliceRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) + + assert response.name == "name_value" + + assert response.metrics_schema_uri == "metrics_schema_uri_value" + + +def test_get_model_evaluation_slice_from_dict(): + test_get_model_evaluation_slice(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_model_evaluation_slice_async( + transport: str = "grpc_asyncio", + request_type=model_service.GetModelEvaluationSliceRequest, +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice( + name="name_value", metrics_schema_uri="metrics_schema_uri_value", + ) + ) + + response = await client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelEvaluationSliceRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) + + assert response.name == "name_value" + + assert response.metrics_schema_uri == "metrics_schema_uri_value" + + +@pytest.mark.asyncio +async def test_get_model_evaluation_slice_async_from_dict(): + await test_get_model_evaluation_slice_async(request_type=dict) + + +def test_get_model_evaluation_slice_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelEvaluationSliceRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: + call.return_value = model_evaluation_slice.ModelEvaluationSlice() + + client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_model_evaluation_slice_field_headers_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelEvaluationSliceRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_evaluation_slice.ModelEvaluationSlice() + ) + + await client.get_model_evaluation_slice(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_model_evaluation_slice_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model_evaluation_slice), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_evaluation_slice.ModelEvaluationSlice() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_model_evaluation_slice(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_model_evaluation_slice_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.get_model_evaluation_slice(
+            model_service.GetModelEvaluationSliceRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_slice_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation_slice), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation_slice.ModelEvaluationSlice()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_model_evaluation_slice(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_slice_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_model_evaluation_slice(
+            model_service.GetModelEvaluationSliceRequest(), name="name_value",
+        )
+
+
+def test_list_model_evaluation_slices(
+    transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest
+):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelEvaluationSlicesResponse(
+            next_page_token="next_page_token_value",
+        )
+
+        response = client.list_model_evaluation_slices(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.ListModelEvaluationSlicesRequest()
+
+    # Establish that the response is the type that we expect.
+
+    assert isinstance(response, pagers.ListModelEvaluationSlicesPager)
+
+    assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_model_evaluation_slices_from_dict():
+    test_list_model_evaluation_slices(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_async(
+    transport: str = "grpc_asyncio",
+    request_type=model_service.ListModelEvaluationSlicesRequest,
+):
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
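+    # Patching "__call__" on the type of the cached stub intercepts the
+    # request before it reaches the transport, so no channel traffic occurs.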
+ with mock.patch.object( + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationSlicesResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelEvaluationSlicesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelEvaluationSlicesAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_model_evaluation_slices_async_from_dict(): + await test_list_model_evaluation_slices_async(request_type=dict) + + +def test_list_model_evaluation_slices_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelEvaluationSlicesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: + call.return_value = model_service.ListModelEvaluationSlicesResponse() + + client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_model_evaluation_slices_field_headers_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelEvaluationSlicesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelEvaluationSlicesResponse() + ) + + await client.list_model_evaluation_slices(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_model_evaluation_slices_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_model_evaluation_slices), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelEvaluationSlicesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+        client.list_model_evaluation_slices(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_model_evaluation_slices_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_model_evaluation_slices(
+            model_service.ListModelEvaluationSlicesRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelEvaluationSlicesResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_model_evaluation_slices(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_model_evaluation_slices(
+            model_service.ListModelEvaluationSlicesRequest(), parent="parent_value",
+        )
+
+
+def test_list_model_evaluation_slices_pager():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_model_evaluation_slices(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(
+            isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results
+        )
+
+
+def test_list_model_evaluation_slices_pages():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_model_evaluation_slices(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_async_pager():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_model_evaluation_slices(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(
+            isinstance(i, model_evaluation_slice.ModelEvaluationSlice)
+            for i in responses
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_async_pages():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (
+            await client.list_model_evaluation_slices(request={})
+        ).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_credentials_transport_error():
+    # It is an error to provide credentials and a transport instance.
+    transport = transports.ModelServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = ModelServiceClient(
+            credentials=credentials.AnonymousCredentials(), transport=transport,
+        )
+
+    # It is an error to provide a credentials file and a transport instance.
+    transport = transports.ModelServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = ModelServiceClient(
+            client_options={"credentials_file": "credentials.json"},
+            transport=transport,
+        )
+
+    # It is an error to provide scopes and a transport instance.
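+    # (Presumably for the same reason as the cases above: a fully-constructed
+    # transport already carries its own credentials and scopes, so client-level
+    # scopes cannot be applied to it and the combination is rejected.)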
+ transport = transports.ModelServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = ModelServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.ModelServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.ModelServiceGrpcTransport,) + + +def test_model_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.ModelServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_model_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.ModelServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
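+    # (The base transport is effectively abstract: concrete transports such as
+    # the gRPC ones override each of the method stubs enumerated below.)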
+ methods = ( + "upload_model", + "get_model", + "list_models", + "update_model", + "delete_model", + "export_model", + "get_model_evaluation", + "list_model_evaluations", + "get_model_evaluation_slice", + "list_model_evaluation_slices", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_model_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.ModelServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_model_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.ModelServiceTransport() + adc.assert_called_once() + + +def test_model_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + ModelServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_model_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.ModelServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_model_service_host_no_port(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_model_service_host_with_port(): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:8000" + + +def test_model_service_grpc_transport_channel(): + channel = grpc.insecure_channel("http://localhost/") + + # Check that channel is used if provided. 
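+    # Assumption being exercised: when a pre-built channel is handed in, the
+    # transport should adopt it as-is instead of creating its own, record the
+    # host with the default ":443" port appended, and leave the SSL channel
+    # credentials unset.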
+    transport = transports.ModelServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_model_service_grpc_asyncio_transport_channel():
+    channel = aio.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.ModelServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
+)
+def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
+)
+def test_model_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_model_service_grpc_lro_client():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
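+    # (`operations_client` is the helper the transport exposes for polling
+    # long-running operations; the identity check below suggests it is created
+    # once and cached on the transport.)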
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_model_service_grpc_lro_async_client():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_endpoint_path():
+    project = "squid"
+    location = "clam"
+    endpoint = "whelk"
+
+    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+        project=project, location=location, endpoint=endpoint,
+    )
+    actual = ModelServiceClient.endpoint_path(project, location, endpoint)
+    assert expected == actual
+
+
+def test_parse_endpoint_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "endpoint": "nudibranch",
+    }
+    path = ModelServiceClient.endpoint_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_endpoint_path(path)
+    assert expected == actual
+
+
+def test_model_path():
+    project = "cuttlefish"
+    location = "mussel"
+    model = "winkle"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
+    actual = ModelServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "nautilus",
+        "location": "scallop",
+        "model": "abalone",
+    }
+    path = ModelServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_model_path(path)
+    assert expected == actual
+
+
+def test_model_evaluation_path():
+    project = "squid"
+    location = "clam"
+    model = "whelk"
+    evaluation = "octopus"
+
+    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(
+        project=project, location=location, model=model, evaluation=evaluation,
+    )
+    actual = ModelServiceClient.model_evaluation_path(
+        project, location, model, evaluation
+    )
+    assert expected == actual
+
+
+def test_parse_model_evaluation_path():
+    expected = {
+        "project": "oyster",
+        "location": "nudibranch",
+        "model": "cuttlefish",
+        "evaluation": "mussel",
+    }
+    path = ModelServiceClient.model_evaluation_path(**expected)
+
+    # Check that the path construction is reversible.
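+    # For these values the composed path should be, per the format string above:
+    #   "projects/oyster/locations/nudibranch/models/cuttlefish/evaluations/mussel"
+    # and parsing it back should recover the same dict.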
+ actual = ModelServiceClient.parse_model_evaluation_path(path) + assert expected == actual + + +def test_model_evaluation_slice_path(): + project = "winkle" + location = "nautilus" + model = "scallop" + evaluation = "abalone" + slice = "squid" + + expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( + project=project, + location=location, + model=model, + evaluation=evaluation, + slice=slice, + ) + actual = ModelServiceClient.model_evaluation_slice_path( + project, location, model, evaluation, slice + ) + assert expected == actual + + +def test_parse_model_evaluation_slice_path(): + expected = { + "project": "clam", + "location": "whelk", + "model": "octopus", + "evaluation": "oyster", + "slice": "nudibranch", + } + path = ModelServiceClient.model_evaluation_slice_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_model_evaluation_slice_path(path) + assert expected == actual + + +def test_training_pipeline_path(): + project = "cuttlefish" + location = "mussel" + training_pipeline = "winkle" + + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) + actual = ModelServiceClient.training_pipeline_path( + project, location, training_pipeline + ) + assert expected == actual + + +def test_parse_training_pipeline_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "training_pipeline": "abalone", + } + path = ModelServiceClient.training_pipeline_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_training_pipeline_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "squid" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = ModelServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + } + path = ModelServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "whelk" + + expected = "folders/{folder}".format(folder=folder,) + actual = ModelServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + } + path = ModelServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "oyster" + + expected = "organizations/{organization}".format(organization=organization,) + actual = ModelServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + } + path = ModelServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. 
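+    # Here the round trip is "organizations/nudibranch" -> {"organization": "nudibranch"}.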
+ actual = ModelServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project,) + actual = ModelServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + } + path = ModelServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = ModelServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + } + path = ModelServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.ModelServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.ModelServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = ModelServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py new file mode 100644 index 0000000000..ada82b91c0 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py @@ -0,0 +1,2237 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( + PipelineServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( + PipelineServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers +from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports +from google.cloud.aiplatform_v1beta1.types import deployed_model_ref +from google.cloud.aiplatform_v1beta1.types import env_var +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import explanation_metadata +from google.cloud.aiplatform_v1beta1.types import io +from google.cloud.aiplatform_v1beta1.types import model +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import pipeline_service +from google.cloud.aiplatform_v1beta1.types import pipeline_state +from google.cloud.aiplatform_v1beta1.types import training_pipeline +from google.cloud.aiplatform_v1beta1.types import ( + training_pipeline as gca_training_pipeline, +) +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import any_pb2 as gp_any # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
+def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert PipelineServiceClient._get_default_mtls_endpoint(None) is None + assert ( + PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient] +) +def test_pipeline_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_pipeline_service_client_get_transport_class(): + transport = PipelineServiceClient.get_transport_class() + assert transport == transports.PipelineServiceGrpcTransport + + transport = PipelineServiceClient.get_transport_class("grpc") + assert transport == transports.PipelineServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + PipelineServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceClient), +) +@mock.patch.object( + PipelineServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceAsyncClient), +) +def test_pipeline_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
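+    # Expected behavior (per the patched transport call below): an explicit
+    # api_endpoint bypasses the mTLS endpoint negotiation and is forwarded to
+    # the transport verbatim as `host`.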
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
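+    # (Judging by the parametrization above, only "true" and "false" are
+    # recognized values; anything else is expected to fail fast with a
+    # ValueError at client construction time.)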
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + PipelineServiceClient, + transports.PipelineServiceGrpcTransport, + "grpc", + "true", + ), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + PipelineServiceClient, + transports.PipelineServiceGrpcTransport, + "grpc", + "false", + ), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + PipelineServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceClient), +) +@mock.patch.object( + PipelineServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_pipeline_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
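+    # Sketch of the fallback chain being simulated here: with no explicit
+    # client_cert_source, the client consults the ADC mTLS helper
+    # (google.auth.transport.grpc.SslCredentials); its `is_mtls` and
+    # `ssl_credentials` properties are mocked below to steer both branches.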
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_pipeline_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_pipeline_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
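+    # Expectation: the client does not load the file itself; it passes
+    # credentials=None and hands the path through to the transport, which
+    # loads it when constructed.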
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_pipeline_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = PipelineServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.CreateTrainingPipelineRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + + response = client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CreateTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_training_pipeline.TrainingPipeline) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.training_task_definition == "training_task_definition_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +def test_create_training_pipeline_from_dict(): + test_create_training_pipeline(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.CreateTrainingPipelineRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
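+        # (grpc_helpers_async.FakeUnaryUnaryCall wraps the message in an
+        # awaitable, so the mocked stub can be awaited like a real gRPC
+        # asyncio call and yields the TrainingPipeline designated here.)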
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + ) + + response = await client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CreateTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_training_pipeline.TrainingPipeline) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.training_task_definition == "training_task_definition_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_create_training_pipeline_async_from_dict(): + await test_create_training_pipeline_async(request_type=dict) + + +def test_create_training_pipeline_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CreateTrainingPipelineRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + call.return_value = gca_training_pipeline.TrainingPipeline() + + client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_training_pipeline_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CreateTrainingPipelineRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline() + ) + + await client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_training_pipeline_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+        call.return_value = gca_training_pipeline.TrainingPipeline()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.create_training_pipeline(
+            parent="parent_value",
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+        assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(
+            name="name_value"
+        )
+
+
+def test_create_training_pipeline_flattened_error():
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.create_training_pipeline(
+            pipeline_service.CreateTrainingPipelineRequest(),
+            parent="parent_value",
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"),
+        )
+
+
+@pytest.mark.asyncio
+async def test_create_training_pipeline_flattened_async():
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.create_training_pipeline), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_training_pipeline.TrainingPipeline()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.create_training_pipeline(
+            parent="parent_value",
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+        assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline(
+            name="name_value"
+        )
+
+
+@pytest.mark.asyncio
+async def test_create_training_pipeline_flattened_error_async():
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.create_training_pipeline(
+            pipeline_service.CreateTrainingPipelineRequest(),
+            parent="parent_value",
+            training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"),
+        )
+
+
+def test_get_training_pipeline(
+    transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest
+):
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_training_pipeline), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+ call.return_value = training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + + response = client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.GetTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, training_pipeline.TrainingPipeline) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.training_task_definition == "training_task_definition_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +def test_get_training_pipeline_from_dict(): + test_get_training_pipeline(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.GetTrainingPipelineRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + ) + + response = await client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.GetTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, training_pipeline.TrainingPipeline) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.training_task_definition == "training_task_definition_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_get_training_pipeline_async_from_dict(): + await test_get_training_pipeline_async(request_type=dict) + + +def test_get_training_pipeline_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.GetTrainingPipelineRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), "__call__" + ) as call: + call.return_value = training_pipeline.TrainingPipeline() + + client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
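+        # (Each entry in call.mock_calls unpacks to (name, args, kwargs);
+        # args[0] is the request object the client handed to the stub.)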
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+        # Establish that the field header was sent.
+        _, _, kw = call.mock_calls[0]
+        assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_get_training_pipeline_field_headers_async():
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = pipeline_service.GetTrainingPipelineRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_training_pipeline), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            training_pipeline.TrainingPipeline()
+        )
+
+        await client.get_training_pipeline(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+        # Establish that the field header was sent.
+        _, _, kw = call.mock_calls[0]
+        assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_get_training_pipeline_flattened():
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_training_pipeline), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = training_pipeline.TrainingPipeline()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.get_training_pipeline(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_get_training_pipeline_flattened_error():
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_training_pipeline(
+            pipeline_service.GetTrainingPipelineRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_training_pipeline_flattened_async():
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_training_pipeline), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            training_pipeline.TrainingPipeline()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_training_pipeline(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_training_pipeline_flattened_error_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_training_pipeline( + pipeline_service.GetTrainingPipelineRequest(), name="name_value", + ) + + +def test_list_training_pipelines( + transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_service.ListTrainingPipelinesResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.ListTrainingPipelinesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListTrainingPipelinesPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_training_pipelines_from_dict(): + test_list_training_pipelines(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_training_pipelines_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.ListTrainingPipelinesRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.ListTrainingPipelinesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTrainingPipelinesAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_training_pipelines_async_from_dict(): + await test_list_training_pipelines_async(request_type=dict) + + +def test_list_training_pipelines_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
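+    # Concretely: setting request.parent below is expected to surface as the
+    # routing metadata entry ("x-goog-request-params", "parent=parent/value")
+    # asserted at the end of this test.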
+ request = pipeline_service.ListTrainingPipelinesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + call.return_value = pipeline_service.ListTrainingPipelinesResponse() + + client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_training_pipelines_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.ListTrainingPipelinesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse() + ) + + await client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_training_pipelines_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_service.ListTrainingPipelinesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_training_pipelines(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_training_pipelines_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_training_pipelines( + pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_training_pipelines_flattened_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = pipeline_service.ListTrainingPipelinesResponse()
+
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ pipeline_service.ListTrainingPipelinesResponse()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_training_pipelines(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_training_pipelines_flattened_error_async():
+ client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_training_pipelines(
+ pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value",
+ )
+
+
+def test_list_training_pipelines_pager():
+ client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_training_pipelines), "__call__"
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[
+ training_pipeline.TrainingPipeline(),
+ training_pipeline.TrainingPipeline(),
+ training_pipeline.TrainingPipeline(),
+ ],
+ next_page_token="abc",
+ ),
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[], next_page_token="def",
+ ),
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[training_pipeline.TrainingPipeline(),],
+ next_page_token="ghi",
+ ),
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[
+ training_pipeline.TrainingPipeline(),
+ training_pipeline.TrainingPipeline(),
+ ],
+ ),
+ RuntimeError,
+ )
+
+ metadata = ()
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+ )
+ pager = client.list_training_pipelines(request={})
+
+ assert pager._metadata == metadata
+
+ results = [i for i in pager]
+ assert len(results) == 6
+ assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results)
+
+
+def test_list_training_pipelines_pages():
+ client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_training_pipelines), "__call__"
+ ) as call:
+ # Set the response to a series of pages.
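+ # (Note: mock yields successive side_effect items on successive calls:
+ # pages of 3, 0, 1, and 2 pipelines, then raises RuntimeError if a fifth
+ # page were ever requested.)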
+ call.side_effect = (
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[
+ training_pipeline.TrainingPipeline(),
+ training_pipeline.TrainingPipeline(),
+ training_pipeline.TrainingPipeline(),
+ ],
+ next_page_token="abc",
+ ),
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[], next_page_token="def",
+ ),
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[training_pipeline.TrainingPipeline(),],
+ next_page_token="ghi",
+ ),
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[
+ training_pipeline.TrainingPipeline(),
+ training_pipeline.TrainingPipeline(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = list(client.list_training_pipelines(request={}).pages)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_training_pipelines_async_pager():
+ client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_training_pipelines),
+ "__call__",
+ new_callable=mock.AsyncMock,
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[
+ training_pipeline.TrainingPipeline(),
+ training_pipeline.TrainingPipeline(),
+ training_pipeline.TrainingPipeline(),
+ ],
+ next_page_token="abc",
+ ),
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[], next_page_token="def",
+ ),
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[training_pipeline.TrainingPipeline(),],
+ next_page_token="ghi",
+ ),
+ pipeline_service.ListTrainingPipelinesResponse(
+ training_pipelines=[
+ training_pipeline.TrainingPipeline(),
+ training_pipeline.TrainingPipeline(),
+ ],
+ ),
+ RuntimeError,
+ )
+ async_pager = await client.list_training_pipelines(request={},)
+ assert async_pager.next_page_token == "abc"
+ responses = []
+ async for response in async_pager:
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_training_pipelines_async_pages():
+ client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_training_pipelines),
+ "__call__",
+ new_callable=mock.AsyncMock,
+ ) as call:
+ # Set the response to a series of pages.
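+ # (Note: same four-page fixture as the sync tests above. new_callable=mock.AsyncMock
+ # makes each stubbed call awaitable for the async pager.)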
+ call.side_effect = ( + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[ + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + ], + next_page_token="abc", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[], next_page_token="def", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[ + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_training_pipelines(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_training_pipeline_from_dict(): + test_delete_training_pipeline(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.DeleteTrainingPipelineRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() + + # Establish that the response is the type that we expect. 
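+ # (Note: DeleteTrainingPipeline is a long-running operation, so the client
+ # wraps the returned operations_pb2.Operation in an api_core future.)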
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_training_pipeline_async_from_dict(): + await test_delete_training_pipeline_async(request_type=dict) + + +def test_delete_training_pipeline_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.DeleteTrainingPipelineRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_training_pipeline_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.DeleteTrainingPipelineRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_training_pipeline_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_training_pipeline(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_training_pipeline_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
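+ # (Note: the flattened keyword arguments exist only as a convenience for
+ # building the request; combining them with an explicit request object
+ # would be ambiguous, hence the ValueError.)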
+ with pytest.raises(ValueError): + client.delete_training_pipeline( + pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_training_pipeline_flattened_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_training_pipeline(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_training_pipeline_flattened_error_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_training_pipeline( + pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", + ) + + +def test_cancel_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CancelTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_training_pipeline_from_dict(): + test_cancel_training_pipeline(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.CancelTrainingPipelineRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CancelTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_training_pipeline_async_from_dict(): + await test_cancel_training_pipeline_async(request_type=dict) + + +def test_cancel_training_pipeline_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CancelTrainingPipelineRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: + call.return_value = None + + client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_training_pipeline_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CancelTrainingPipelineRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_cancel_training_pipeline_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.cancel_training_pipeline(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_cancel_training_pipeline_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.cancel_training_pipeline( + pipeline_service.CancelTrainingPipelineRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_cancel_training_pipeline_flattened_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.cancel_training_pipeline(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_cancel_training_pipeline_flattened_error_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.cancel_training_pipeline( + pipeline_service.CancelTrainingPipelineRequest(), name="name_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.PipelineServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.PipelineServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PipelineServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.PipelineServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PipelineServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.PipelineServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = PipelineServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.PipelineServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.PipelineServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. 
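+ # (Note: google.auth.default is patched so the test never performs a real
+ # Application Default Credentials lookup.)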
+ with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.PipelineServiceGrpcTransport,) + + +def test_pipeline_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.PipelineServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_pipeline_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.PipelineServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_training_pipeline", + "get_training_pipeline", + "list_training_pipelines", + "delete_training_pipeline", + "cancel_training_pipeline", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_pipeline_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.PipelineServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_pipeline_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.PipelineServiceTransport() + adc.assert_called_once() + + +def test_pipeline_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + PipelineServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_pipeline_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
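+ # (Note: quota_project_id is forwarded to google.auth.default so API quota
+ # is attributed to the named project.)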
+ with mock.patch.object(auth, "default") as adc:
+ adc.return_value = (credentials.AnonymousCredentials(), None)
+ transports.PipelineServiceGrpcTransport(
+ host="squid.clam.whelk", quota_project_id="octopus"
+ )
+ adc.assert_called_once_with(
+ scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ quota_project_id="octopus",
+ )
+
+
+def test_pipeline_service_host_no_port():
+ client = PipelineServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ client_options=client_options.ClientOptions(
+ api_endpoint="aiplatform.googleapis.com"
+ ),
+ )
+ assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_pipeline_service_host_with_port():
+ client = PipelineServiceClient(
+ credentials=credentials.AnonymousCredentials(),
+ client_options=client_options.ClientOptions(
+ api_endpoint="aiplatform.googleapis.com:8000"
+ ),
+ )
+ assert client.transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_pipeline_service_grpc_transport_channel():
+ channel = grpc.insecure_channel("http://localhost/")
+
+ # Check that channel is used if provided.
+ transport = transports.PipelineServiceGrpcTransport(
+ host="squid.clam.whelk", channel=channel,
+ )
+ assert transport.grpc_channel == channel
+ assert transport._host == "squid.clam.whelk:443"
+ assert transport._ssl_channel_credentials is None
+
+
+def test_pipeline_service_grpc_asyncio_transport_channel():
+ channel = aio.insecure_channel("http://localhost/")
+
+ # Check that channel is used if provided.
+ transport = transports.PipelineServiceGrpcAsyncIOTransport(
+ host="squid.clam.whelk", channel=channel,
+ )
+ assert transport.grpc_channel == channel
+ assert transport._host == "squid.clam.whelk:443"
+ assert transport._ssl_channel_credentials is None
+
+
+@pytest.mark.parametrize(
+ "transport_class",
+ [
+ transports.PipelineServiceGrpcTransport,
+ transports.PipelineServiceGrpcAsyncIOTransport,
+ ],
+)
+def test_pipeline_service_transport_channel_mtls_with_client_cert_source(
+ transport_class,
+):
+ with mock.patch(
+ "grpc.ssl_channel_credentials", autospec=True
+ ) as grpc_ssl_channel_cred:
+ with mock.patch.object(
+ transport_class, "create_channel", autospec=True
+ ) as grpc_create_channel:
+ mock_ssl_cred = mock.Mock()
+ grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+ mock_grpc_channel = mock.Mock()
+ grpc_create_channel.return_value = mock_grpc_channel
+
+ cred = credentials.AnonymousCredentials()
+ with pytest.warns(DeprecationWarning):
+ with mock.patch.object(auth, "default") as adc:
+ adc.return_value = (cred, None)
+ transport = transport_class(
+ host="squid.clam.whelk",
+ api_mtls_endpoint="mtls.squid.clam.whelk",
+ client_cert_source=client_cert_source_callback,
+ )
+ adc.assert_called_once()
+
+ grpc_ssl_channel_cred.assert_called_once_with(
+ certificate_chain=b"cert bytes", private_key=b"key bytes"
+ )
+ grpc_create_channel.assert_called_once_with(
+ "mtls.squid.clam.whelk:443",
+ credentials=cred,
+ credentials_file=None,
+ scopes=("https://www.googleapis.com/auth/cloud-platform",),
+ ssl_credentials=mock_ssl_cred,
+ quota_project_id=None,
+ )
+ assert transport.grpc_channel == mock_grpc_channel
+ assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+@pytest.mark.parametrize(
+ "transport_class",
+ [
+ transports.PipelineServiceGrpcTransport,
+ transports.PipelineServiceGrpcAsyncIOTransport,
+ ],
+)
+def test_pipeline_service_transport_channel_mtls_with_adc(transport_class):
+ mock_ssl_cred = mock.Mock()
+ with mock.patch.multiple(
+ 
"google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel", autospec=True + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_pipeline_service_grpc_lro_client(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_pipeline_service_grpc_lro_async_client(): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_endpoint_path(): + project = "squid" + location = "clam" + endpoint = "whelk" + + expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) + actual = PipelineServiceClient.endpoint_path(project, location, endpoint) + assert expected == actual + + +def test_parse_endpoint_path(): + expected = { + "project": "octopus", + "location": "oyster", + "endpoint": "nudibranch", + } + path = PipelineServiceClient.endpoint_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_endpoint_path(path) + assert expected == actual + + +def test_model_path(): + project = "cuttlefish" + location = "mussel" + model = "winkle" + + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) + actual = PipelineServiceClient.model_path(project, location, model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } + path = PipelineServiceClient.model_path(**expected) + + # Check that the path construction is reversible. 
+ actual = PipelineServiceClient.parse_model_path(path) + assert expected == actual + + +def test_training_pipeline_path(): + project = "squid" + location = "clam" + training_pipeline = "whelk" + + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) + actual = PipelineServiceClient.training_pipeline_path( + project, location, training_pipeline + ) + assert expected == actual + + +def test_parse_training_pipeline_path(): + expected = { + "project": "octopus", + "location": "oyster", + "training_pipeline": "nudibranch", + } + path = PipelineServiceClient.training_pipeline_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_training_pipeline_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "cuttlefish" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = PipelineServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "mussel", + } + path = PipelineServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "winkle" + + expected = "folders/{folder}".format(folder=folder,) + actual = PipelineServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + } + path = PipelineServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "scallop" + + expected = "organizations/{organization}".format(organization=organization,) + actual = PipelineServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + } + path = PipelineServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "squid" + + expected = "projects/{project}".format(project=project,) + actual = PipelineServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + } + path = PipelineServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. 
+ actual = PipelineServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "whelk" + location = "octopus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = PipelineServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + } + path = PipelineServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = PipelineServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py new file mode 100644 index 0000000000..e47e0f62c5 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py @@ -0,0 +1,1366 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + PredictionServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.prediction_service import ( + PredictionServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.prediction_service import transports +from google.cloud.aiplatform_v1beta1.types import explanation +from google.cloud.aiplatform_v1beta1.types import prediction_service +from google.oauth2 import service_account +from google.protobuf import struct_pb2 as struct # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. 
+# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert PredictionServiceClient._get_default_mtls_endpoint(None) is None + assert ( + PredictionServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class", [PredictionServiceClient, PredictionServiceAsyncClient] +) +def test_prediction_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_prediction_service_client_get_transport_class(): + transport = PredictionServiceClient.get_transport_class() + assert transport == transports.PredictionServiceGrpcTransport + + transport = PredictionServiceClient.get_transport_class("grpc") + assert transport == transports.PredictionServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + PredictionServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceClient), +) +@mock.patch.object( + PredictionServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceAsyncClient), +) +def test_prediction_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(PredictionServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(PredictionServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
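+ # (Note: an explicit api_endpoint in ClientOptions appears to take
+ # precedence over both DEFAULT_ENDPOINT and the mTLS endpoint logic.)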
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
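+ # (Note: only "true" and "false" are accepted for this variable; any other
+ # value raises ValueError, as asserted below.)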
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + PredictionServiceClient, + transports.PredictionServiceGrpcTransport, + "grpc", + "true", + ), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + PredictionServiceClient, + transports.PredictionServiceGrpcTransport, + "grpc", + "false", + ), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + PredictionServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceClient), +) +@mock.patch.object( + PredictionServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_prediction_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
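+ # (Note: here the certificate comes from google.auth's SslCredentials,
+ # i.e. the ADC-provided client certificate, rather than from a
+ # client_cert_source callback in ClientOptions.)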
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_prediction_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_prediction_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_prediction_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = PredictionServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_predict( + transport: str = "grpc", request_type=prediction_service.PredictRequest +): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.PredictResponse( + deployed_model_id="deployed_model_id_value", + ) + + response = client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == prediction_service.PredictRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, prediction_service.PredictResponse) + + assert response.deployed_model_id == "deployed_model_id_value" + + +def test_predict_from_dict(): + test_predict(request_type=dict) + + +@pytest.mark.asyncio +async def test_predict_async( + transport: str = "grpc_asyncio", request_type=prediction_service.PredictRequest +): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.PredictResponse( + deployed_model_id="deployed_model_id_value", + ) + ) + + response = await client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == prediction_service.PredictRequest() + + # Establish that the response is the type that we expect. 
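+ # (Note: awaiting the FakeUnaryUnaryCall unwraps it, so the assertions
+ # below operate on the plain PredictResponse.)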
+ assert isinstance(response, prediction_service.PredictResponse) + + assert response.deployed_model_id == "deployed_model_id_value" + + +@pytest.mark.asyncio +async def test_predict_async_from_dict(): + await test_predict_async(request_type=dict) + + +def test_predict_field_headers(): + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.PredictRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + call.return_value = prediction_service.PredictResponse() + + client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_predict_field_headers_async(): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.PredictRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.PredictResponse() + ) + + await client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +def test_predict_flattened(): + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.PredictResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.predict( + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].instances == [ + struct.Value(null_value=struct.NullValue.NULL_VALUE) + ] + + # https://github.com/googleapis/gapic-generator-python/issues/414 + # assert args[0].parameters == struct.Value( + # null_value=struct.NullValue.NULL_VALUE + # ) + + +def test_predict_flattened_error(): + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
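+    # (Sketch of the contract under test: mixing the two calling styles, e.g.
+    #     client.predict(prediction_service.PredictRequest(), endpoint="e")
+    # must raise ValueError, which is exactly what the block below asserts.)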
+ with pytest.raises(ValueError): + client.predict( + prediction_service.PredictRequest(), + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + ) + + +@pytest.mark.asyncio +async def test_predict_flattened_async(): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.PredictResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.PredictResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.predict( + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].instances == [ + struct.Value(null_value=struct.NullValue.NULL_VALUE) + ] + + # https://github.com/googleapis/gapic-generator-python/issues/414 + # assert args[0].parameters == struct.Value( + # null_value=struct.NullValue.NULL_VALUE + # ) + + +@pytest.mark.asyncio +async def test_predict_flattened_error_async(): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.predict( + prediction_service.PredictRequest(), + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + ) + + +def test_explain( + transport: str = "grpc", request_type=prediction_service.ExplainRequest +): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.explain), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.ExplainResponse( + deployed_model_id="deployed_model_id_value", + ) + + response = client.explain(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == prediction_service.ExplainRequest() + + # Establish that the response is the type that we expect. 
+ + assert isinstance(response, prediction_service.ExplainResponse) + + assert response.deployed_model_id == "deployed_model_id_value" + + +def test_explain_from_dict(): + test_explain(request_type=dict) + + +@pytest.mark.asyncio +async def test_explain_async( + transport: str = "grpc_asyncio", request_type=prediction_service.ExplainRequest +): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.explain), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.ExplainResponse( + deployed_model_id="deployed_model_id_value", + ) + ) + + response = await client.explain(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == prediction_service.ExplainRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, prediction_service.ExplainResponse) + + assert response.deployed_model_id == "deployed_model_id_value" + + +@pytest.mark.asyncio +async def test_explain_async_from_dict(): + await test_explain_async(request_type=dict) + + +def test_explain_field_headers(): + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.ExplainRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.explain), "__call__") as call: + call.return_value = prediction_service.ExplainResponse() + + client.explain(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_explain_field_headers_async(): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.ExplainRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.explain), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.ExplainResponse() + ) + + await client.explain(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
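+    # (Background: "x-goog-request-params" is the standard GAPIC routing
+    # header; the client is expected to render the routed field into it,
+    # here as "endpoint=endpoint/value".)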
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +def test_explain_flattened(): + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.explain), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.ExplainResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.explain( + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + deployed_model_id="deployed_model_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].instances == [ + struct.Value(null_value=struct.NullValue.NULL_VALUE) + ] + + # https://github.com/googleapis/gapic-generator-python/issues/414 + # assert args[0].parameters == struct.Value( + # null_value=struct.NullValue.NULL_VALUE + # ) + + assert args[0].deployed_model_id == "deployed_model_id_value" + + +def test_explain_flattened_error(): + client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.explain( + prediction_service.ExplainRequest(), + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + deployed_model_id="deployed_model_id_value", + ) + + +@pytest.mark.asyncio +async def test_explain_flattened_async(): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.explain), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.ExplainResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.ExplainResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.explain( + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + deployed_model_id="deployed_model_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. 
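+        # (Note on the pattern: the flattened kwargs are merged into a single
+        # ExplainRequest before the RPC goes out, so args[0] below is the
+        # request object the mocked stub actually received.)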
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].instances == [ + struct.Value(null_value=struct.NullValue.NULL_VALUE) + ] + + # https://github.com/googleapis/gapic-generator-python/issues/414 + # assert args[0].parameters == struct.Value( + # null_value=struct.NullValue.NULL_VALUE + # ) + + assert args[0].deployed_model_id == "deployed_model_id_value" + + +@pytest.mark.asyncio +async def test_explain_flattened_error_async(): + client = PredictionServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.explain( + prediction_service.ExplainRequest(), + endpoint="endpoint_value", + instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], + parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), + deployed_model_id="deployed_model_id_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PredictionServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PredictionServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = PredictionServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.PredictionServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. 
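+    # (Equivalently, a minimal sketch using helpers already imported here,
+    # assuming "grpc" stays the default transport name:
+    #     cls = PredictionServiceClient.get_transport_class()
+    #     assert cls == transports.PredictionServiceGrpcTransport
+    # which the instantiation below checks end to end.)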
+ client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.PredictionServiceGrpcTransport,) + + +def test_prediction_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.PredictionServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_prediction_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.PredictionServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "predict", + "explain", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + +def test_prediction_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.PredictionServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_prediction_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.PredictionServiceTransport() + adc.assert_called_once() + + +def test_prediction_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + PredictionServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_prediction_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
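+    # ("ADC" here means google.auth.default(); a sketch of the lookup the
+    # transport is expected to perform, under the google-auth API:
+    #     creds, _ = auth.default(scopes=(...,), quota_project_id="octopus")
+    # The mock below intercepts exactly that call.)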
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transports.PredictionServiceGrpcTransport(
+            host="squid.clam.whelk", quota_project_id="octopus"
+        )
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            quota_project_id="octopus",
+        )
+
+
+def test_prediction_service_host_no_port():
+    client = PredictionServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_prediction_service_host_with_port():
+    client = PredictionServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_prediction_service_grpc_transport_channel():
+    channel = grpc.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.PredictionServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_prediction_service_grpc_asyncio_transport_channel():
+    channel = aio.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.PredictionServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.PredictionServiceGrpcTransport,
+        transports.PredictionServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_prediction_service_transport_channel_mtls_with_client_cert_source(
+    transport_class,
+):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.PredictionServiceGrpcTransport,
+        transports.PredictionServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_prediction_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_endpoint_path():
+    project = "squid"
+    location = "clam"
+    endpoint = "whelk"
+
+    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+        project=project, location=location, endpoint=endpoint,
+    )
+    actual = PredictionServiceClient.endpoint_path(project, location, endpoint)
+    assert expected == actual
+
+
+def test_parse_endpoint_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "endpoint": "nudibranch",
+    }
+    path = PredictionServiceClient.endpoint_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = PredictionServiceClient.parse_endpoint_path(path)
+    assert expected == actual
+
+
+def test_common_billing_account_path():
+    billing_account = "cuttlefish"
+
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
+    actual = PredictionServiceClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "mussel",
+    }
+    path = PredictionServiceClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = PredictionServiceClient.parse_common_billing_account_path(path)
+    assert expected == actual
+
+
+def test_common_folder_path():
+    folder = "winkle"
+
+    expected = "folders/{folder}".format(folder=folder,)
+    actual = PredictionServiceClient.common_folder_path(folder)
+    assert expected == actual
+
+
+def test_parse_common_folder_path():
+    expected = {
+        "folder": "nautilus",
+    }
+    path = PredictionServiceClient.common_folder_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = PredictionServiceClient.parse_common_folder_path(path)
+    assert expected == actual
+
+
+def test_common_organization_path():
+    organization = "scallop"
+
+    expected = "organizations/{organization}".format(organization=organization,)
+    actual = PredictionServiceClient.common_organization_path(organization)
+    assert expected == actual
+
+
+def test_parse_common_organization_path():
+    expected = {
+        "organization": "abalone",
+    }
+    path = PredictionServiceClient.common_organization_path(**expected)
+
+    # Check that the path construction is reversible.
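+    # (Hedged note: the parse_* helpers are assumed to invert the template
+    # with a regex, so the round trip
+    #     parse_common_organization_path("organizations/abalone")
+    # should yield {"organization": "abalone"}, as asserted next.)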
+ actual = PredictionServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "squid" + + expected = "projects/{project}".format(project=project,) + actual = PredictionServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + } + path = PredictionServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "whelk" + location = "octopus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = PredictionServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + } + path = PredictionServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.PredictionServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = PredictionServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.PredictionServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = PredictionServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py new file mode 100644 index 0000000000..6c1061d588 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py @@ -0,0 +1,2267 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + SpecialistPoolServiceAsyncClient, +) +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( + SpecialistPoolServiceClient, +) +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers +from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports +from google.cloud.aiplatform_v1beta1.types import operation as gca_operation +from google.cloud.aiplatform_v1beta1.types import specialist_pool +from google.cloud.aiplatform_v1beta1.types import specialist_pool as gca_specialist_pool +from google.cloud.aiplatform_v1beta1.types import specialist_pool_service +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
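+# (For example, a localhost default such as "localhost:7469" would be remapped
+# to "foo.googleapis.com", while a production default like
+# "aiplatform.googleapis.com" passes through unchanged; an illustrative reading
+# of the helper below.)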
+def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(None) is None + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient] +) +def test_specialist_pool_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_specialist_pool_service_client_get_transport_class(): + transport = SpecialistPoolServiceClient.get_transport_class() + assert transport == transports.SpecialistPoolServiceGrpcTransport + + transport = SpecialistPoolServiceClient.get_transport_class("grpc") + assert transport == transports.SpecialistPoolServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) +def test_specialist_pool_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. 
+ options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
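+    # (Only "true" and "false" are meaningful for this variable; any other
+    # value is expected to raise ValueError, as the block below asserts.)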
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "true", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "false", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_specialist_pool_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + ssl_channel_creds = mock.Mock() + with mock.patch( + "grpc.ssl_channel_credentials", return_value=ssl_channel_creds + ): + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_ssl_channel_creds = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_ssl_channel_creds = ssl_channel_creds + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
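+    # (Here the "ADC client cert" comes from
+    # google.auth.transport.grpc.SslCredentials: its is_mtls property reports
+    # whether a default client certificate is available, and the mocks below
+    # toggle that flag per parametrization.)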
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.ssl_credentials", + new_callable=mock.PropertyMock, + ) as ssl_credentials_mock: + if use_client_cert_env == "false": + is_mtls_mock.return_value = False + ssl_credentials_mock.return_value = None + expected_host = client.DEFAULT_ENDPOINT + expected_ssl_channel_creds = None + else: + is_mtls_mock.return_value = True + ssl_credentials_mock.return_value = mock.Mock() + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_ssl_channel_creds = ( + ssl_credentials_mock.return_value + ) + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + ssl_channel_credentials=expected_ssl_channel_creds, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + ): + with mock.patch( + "google.auth.transport.grpc.SslCredentials.is_mtls", + new_callable=mock.PropertyMock, + ) as is_mtls_mock: + is_mtls_mock.return_value = False + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_specialist_pool_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = SpecialistPoolServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + ssl_channel_credentials=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.CreateSpecialistPoolRequest, +): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_specialist_pool_from_dict(): + test_create_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.CreateSpecialistPoolRequest, +): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() + + # Establish that the response is the type that we expect. 
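+    # (CreateSpecialistPool is a long-running operation: the GAPIC layer wraps
+    # the raw operations_pb2.Operation in an api_core operation future, hence
+    # the future.Future check below.)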
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_specialist_pool_async_from_dict(): + await test_create_specialist_pool_async(request_type=dict) + + +def test_create_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.CreateSpecialistPoolRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.CreateSpecialistPoolRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_specialist_pool_flattened(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_specialist_pool( + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) + + +def test_create_specialist_pool_flattened_error(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_specialist_pool( + specialist_pool_service.CreateSpecialistPoolRequest(), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_specialist_pool_flattened_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_specialist_pool( + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) + + +@pytest.mark.asyncio +async def test_create_specialist_pool_flattened_error_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_specialist_pool( + specialist_pool_service.CreateSpecialistPoolRequest(), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + ) + + +def test_get_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.GetSpecialistPoolRequest, +): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool.SpecialistPool( + name="name_value", + display_name="display_name_value", + specialist_managers_count=2662, + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + ) + + response = client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() + + # Establish that the response is the type that we expect. 
+ + assert isinstance(response, specialist_pool.SpecialistPool) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.specialist_managers_count == 2662 + + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] + + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] + + +def test_get_specialist_pool_from_dict(): + test_get_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.GetSpecialistPoolRequest, +): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool( + name="name_value", + display_name="display_name_value", + specialist_managers_count=2662, + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + ) + ) + + response = await client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, specialist_pool.SpecialistPool) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.specialist_managers_count == 2662 + + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] + + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] + + +@pytest.mark.asyncio +async def test_get_specialist_pool_async_from_dict(): + await test_get_specialist_pool_async(request_type=dict) + + +def test_get_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.GetSpecialistPoolRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), "__call__" + ) as call: + call.return_value = specialist_pool.SpecialistPool() + + client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. 
+    # Set these to a non-empty value.
+    request = specialist_pool_service.GetSpecialistPoolRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_specialist_pool), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            specialist_pool.SpecialistPool()
+        )
+
+        await client.get_specialist_pool(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_get_specialist_pool_flattened():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_specialist_pool), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = specialist_pool.SpecialistPool()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.get_specialist_pool(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_get_specialist_pool_flattened_error():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_specialist_pool(
+            specialist_pool_service.GetSpecialistPoolRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_specialist_pool_flattened_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_specialist_pool), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = specialist_pool.SpecialistPool()
+
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            specialist_pool.SpecialistPool()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_specialist_pool(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_specialist_pool_flattened_error_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+ with pytest.raises(ValueError): + await client.get_specialist_pool( + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", + ) + + +def test_list_specialist_pools( + transport: str = "grpc", + request_type=specialist_pool_service.ListSpecialistPoolsRequest, +): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListSpecialistPoolsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_specialist_pools_from_dict(): + test_list_specialist_pools(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_specialist_pools_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.ListSpecialistPoolsRequest, +): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListSpecialistPoolsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_specialist_pools_async_from_dict(): + await test_list_specialist_pools_async(request_type=dict) + + +def test_list_specialist_pools_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.ListSpecialistPoolsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
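+    # (Background on what this test pins down: the client is expected to
+    # mirror URI-bound request fields into the "x-goog-request-params"
+    # metadata entry, e.g. ("x-goog-request-params", "parent=parent/value"),
+    # so routing layers can dispatch the call without parsing the payload.)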
+    with mock.patch.object(
+        type(client.transport.list_specialist_pools), "__call__"
+    ) as call:
+        call.return_value = specialist_pool_service.ListSpecialistPoolsResponse()
+
+        client.list_specialist_pools(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_list_specialist_pools_field_headers_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = specialist_pool_service.ListSpecialistPoolsRequest()
+    request.parent = "parent/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_specialist_pools), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            specialist_pool_service.ListSpecialistPoolsResponse()
+        )
+
+        await client.list_specialist_pools(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_specialist_pools_flattened():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_specialist_pools), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = specialist_pool_service.ListSpecialistPoolsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_specialist_pools(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_specialist_pools_flattened_error():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_specialist_pools(
+            specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_specialist_pools_flattened_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_specialist_pools), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            specialist_pool_service.ListSpecialistPoolsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
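+        # (The FakeUnaryUnaryCall wrapper above, a testing helper that
+        # appears to ship with api-core's grpc_helpers_async, makes the
+        # mocked stub awaitable like a real grpc.aio call.)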
+        response = await client.list_specialist_pools(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_specialist_pools_flattened_error_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_specialist_pools(
+            specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value",
+        )
+
+
+def test_list_specialist_pools_pager():
+    client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_specialist_pools), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token="abc",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[], next_page_token="def",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[specialist_pool.SpecialistPool(),],
+                next_page_token="ghi",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_specialist_pools(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results)
+
+
+def test_list_specialist_pools_pages():
+    client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_specialist_pools), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token="abc",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[], next_page_token="def",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[specialist_pool.SpecialistPool(),],
+                next_page_token="ghi",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_specialist_pools(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_specialist_pools_async_pager():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
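+    # Unlike the sync pager test above, new_callable=mock.AsyncMock is
+    # required here: the async pager awaits the transport call on every
+    # page fetch, and awaiting an AsyncMock yields the plain responses set
+    # as side effects, with no FakeUnaryUnaryCall wrapping needed.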
+    with mock.patch.object(
+        type(client.transport.list_specialist_pools),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token="abc",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[], next_page_token="def",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[specialist_pool.SpecialistPool(),],
+                next_page_token="ghi",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_specialist_pools(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, specialist_pool.SpecialistPool) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_specialist_pools_async_pages():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_specialist_pools),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+                next_page_token="abc",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[], next_page_token="def",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[specialist_pool.SpecialistPool(),],
+                next_page_token="ghi",
+            ),
+            specialist_pool_service.ListSpecialistPoolsResponse(
+                specialist_pools=[
+                    specialist_pool.SpecialistPool(),
+                    specialist_pool.SpecialistPool(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_specialist_pools(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_delete_specialist_pool(
+    transport: str = "grpc",
+    request_type=specialist_pool_service.DeleteSpecialistPoolRequest,
+):
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_specialist_pool), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.delete_specialist_pool(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest()
+
+    # Establish that the response is the type that we expect.
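+    # DeleteSpecialistPool is a long-running operation: the stub returns a
+    # raw operations_pb2.Operation and the client wraps it in an api-core
+    # future, on which a real caller would typically block via, e.g.,
+    # response.result(timeout=...) (illustrative; not exercised here).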
+ assert isinstance(response, future.Future) + + +def test_delete_specialist_pool_from_dict(): + test_delete_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.DeleteSpecialistPoolRequest, +): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_specialist_pool_async_from_dict(): + await test_delete_specialist_pool_async(request_type=dict) + + +def test_delete_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.DeleteSpecialistPoolRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.DeleteSpecialistPoolRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
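+    # Each mock_calls entry unpacks as (name, args, kwargs); the routing
+    # header travels in the "metadata" keyword argument as a tuple of
+    # (key, value) pairs.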
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_delete_specialist_pool_flattened():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_specialist_pool), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.delete_specialist_pool(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_delete_specialist_pool_flattened_error():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.delete_specialist_pool(
+            specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_specialist_pool_flattened_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_specialist_pool), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_specialist_pool(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_delete_specialist_pool_flattened_error_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_specialist_pool(
+            specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value",
+        )
+
+
+def test_update_specialist_pool(
+    transport: str = "grpc",
+    request_type=specialist_pool_service.UpdateSpecialistPoolRequest,
+):
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.update_specialist_pool), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_update_specialist_pool_from_dict(): + test_update_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.UpdateSpecialistPoolRequest, +): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_update_specialist_pool_async_from_dict(): + await test_update_specialist_pool_async(request_type=dict) + + +def test_update_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.UpdateSpecialistPoolRequest() + request.specialist_pool.name = "specialist_pool.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "specialist_pool.name=specialist_pool.name/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.UpdateSpecialistPoolRequest() + request.specialist_pool.name = "specialist_pool.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
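+    # (For nested request fields the routing key is the dotted field path,
+    # hence "specialist_pool.name=..." in the assertion below rather than
+    # a bare "name=...".)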
+    with mock.patch.object(
+        type(client.transport.update_specialist_pool), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/op")
+        )
+
+        await client.update_specialist_pool(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        "x-goog-request-params",
+        "specialist_pool.name=specialist_pool.name/value",
+    ) in kw["metadata"]
+
+
+def test_update_specialist_pool_flattened():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.update_specialist_pool), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.update_specialist_pool(
+            specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool(
+            name="name_value"
+        )
+
+        assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"])
+
+
+def test_update_specialist_pool_flattened_error():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.update_specialist_pool(
+            specialist_pool_service.UpdateSpecialistPoolRequest(),
+            specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
+
+
+@pytest.mark.asyncio
+async def test_update_specialist_pool_flattened_async():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.update_specialist_pool), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.update_specialist_pool(
+            specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
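+        # update_mask is a protobuf FieldMask naming which SpecialistPool
+        # fields the server should overwrite; the assertions below check
+        # that both flattened kwargs landed on the request unchanged.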
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_specialist_pool_flattened_error_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_specialist_pool( + specialist_pool_service.UpdateSpecialistPoolRequest(), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = SpecialistPoolServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. 
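+    # (The transport argument also accepts a transport instance or, as far
+    # as the registered names go, the strings "grpc" and "grpc_asyncio";
+    # omitting it should resolve to the sync gRPC transport asserted below.)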
+ client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance(client.transport, transports.SpecialistPoolServiceGrpcTransport,) + + +def test_specialist_pool_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.SpecialistPoolServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_specialist_pool_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.SpecialistPoolServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_specialist_pool", + "get_specialist_pool", + "list_specialist_pools", + "delete_specialist_pool", + "update_specialist_pool", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_specialist_pool_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.SpecialistPoolServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_specialist_pool_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.SpecialistPoolServiceTransport() + adc.assert_called_once() + + +def test_specialist_pool_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + SpecialistPoolServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_specialist_pool_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
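+    # (google.auth.default() implements Application Default Credentials,
+    # consulting GOOGLE_APPLICATION_CREDENTIALS, then gcloud user
+    # credentials, then the GCE/GKE metadata server; mocking it keeps the
+    # test hermetic.)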
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transports.SpecialistPoolServiceGrpcTransport(
+            host="squid.clam.whelk", quota_project_id="octopus"
+        )
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            quota_project_id="octopus",
+        )
+
+
+def test_specialist_pool_service_host_no_port():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_specialist_pool_service_host_with_port():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_specialist_pool_service_grpc_transport_channel():
+    channel = grpc.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.SpecialistPoolServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_specialist_pool_service_grpc_asyncio_transport_channel():
+    channel = aio.insecure_channel("http://localhost/")
+
+    # Check that channel is used if provided.
+    transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.SpecialistPoolServiceGrpcTransport,
+        transports.SpecialistPoolServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source(
+    transport_class,
+):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.SpecialistPoolServiceGrpcTransport,
+        transports.SpecialistPoolServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel", autospec=True
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_specialist_pool_service_grpc_lro_client():
+    client = SpecialistPoolServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property return the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_specialist_pool_service_grpc_lro_async_client():
+    client = SpecialistPoolServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+    # Ensure that subsequent calls to the property return the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_specialist_pool_path():
+    project = "squid"
+    location = "clam"
+    specialist_pool = "whelk"
+
+    expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(
+        project=project, location=location, specialist_pool=specialist_pool,
+    )
+    actual = SpecialistPoolServiceClient.specialist_pool_path(
+        project, location, specialist_pool
+    )
+    assert expected == actual
+
+
+def test_parse_specialist_pool_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "specialist_pool": "nudibranch",
+    }
+    path = SpecialistPoolServiceClient.specialist_pool_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = SpecialistPoolServiceClient.parse_specialist_pool_path(path)
+    assert expected == actual
+
+
+def test_common_billing_account_path():
+    billing_account = "cuttlefish"
+
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
+    actual = SpecialistPoolServiceClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "mussel",
+    }
+    path = SpecialistPoolServiceClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
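+    # Every *_path helper has a parse_* inverse that recovers the template
+    # variables from a formatted resource name, e.g. (illustrative):
+    #   parse_common_billing_account_path("billingAccounts/mussel")
+    #   -> {"billing_account": "mussel"}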
+ actual = SpecialistPoolServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "winkle" + + expected = "folders/{folder}".format(folder=folder,) + actual = SpecialistPoolServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + } + path = SpecialistPoolServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = SpecialistPoolServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "scallop" + + expected = "organizations/{organization}".format(organization=organization,) + actual = SpecialistPoolServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + } + path = SpecialistPoolServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = SpecialistPoolServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "squid" + + expected = "projects/{project}".format(project=project,) + actual = SpecialistPoolServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + } + path = SpecialistPoolServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = SpecialistPoolServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "whelk" + location = "octopus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = SpecialistPoolServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + } + path = SpecialistPoolServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = SpecialistPoolServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = SpecialistPoolServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/test_dataset_service.py b/tests/unit/gapic/test_dataset_service.py deleted file mode 100644 index bdf4410884..0000000000 --- a/tests/unit/gapic/test_dataset_service.py +++ /dev/null @@ -1,1140 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.api_core import future -from google.api_core import operations_v1 -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.dataset_service import ( - DatasetServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.dataset_service import pagers -from google.cloud.aiplatform_v1beta1.services.dataset_service import transports -from google.cloud.aiplatform_v1beta1.types import annotation -from google.cloud.aiplatform_v1beta1.types import annotation_spec -from google.cloud.aiplatform_v1beta1.types import data_item -from google.cloud.aiplatform_v1beta1.types import dataset -from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset -from google.cloud.aiplatform_v1beta1.types import dataset_service -from google.cloud.aiplatform_v1beta1.types import io -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - - -def test_dataset_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = DatasetServiceClient.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = DatasetServiceClient.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_dataset_service_client_client_options(): - # Check the default options have their expected values. - assert ( - DatasetServiceClient.DEFAULT_OPTIONS.api_endpoint == "aiplatform.googleapis.com" - ) - - # Check that options can be customized. 
- options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.DatasetServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = DatasetServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_dataset_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.dataset_service.DatasetServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = DatasetServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_create_dataset(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.CreateDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.create_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_create_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_dataset( - parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].dataset == gca_dataset.Dataset(name="name_value") - - -def test_create_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_dataset( - dataset_service.CreateDatasetRequest(), - parent="parent_value", - dataset=gca_dataset.Dataset(name="name_value"), - ) - - -def test_get_dataset(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.GetDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - - response = client.get_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, dataset.Dataset) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == "etag_value" - - -def test_get_dataset_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.GetDatasetRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: - call.return_value = dataset.Dataset() - client.get_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset.Dataset() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_dataset(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_dataset( - dataset_service.GetDatasetRequest(), name="name_value", - ) - - -def test_update_dataset(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.UpdateDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_dataset), "__call__") as call: - # Designate an appropriate return value for the call. 
- call.return_value = gca_dataset.Dataset( - name="name_value", - display_name="display_name_value", - metadata_schema_uri="metadata_schema_uri_value", - etag="etag_value", - ) - - response = client.update_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_dataset.Dataset) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.etag == "etag_value" - - -def test_update_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_dataset.Dataset() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.update_dataset( - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].dataset == gca_dataset.Dataset(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -def test_update_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.update_dataset( - dataset_service.UpdateDatasetRequest(), - dataset=gca_dataset.Dataset(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -def test_list_datasets(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ListDatasetsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDatasetsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_datasets(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDatasetsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_datasets_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ListDatasetsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - call.return_value = dataset_service.ListDatasetsResponse() - client.list_datasets(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_datasets_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDatasetsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_datasets(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_datasets_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_datasets( - dataset_service.ListDatasetsRequest(), parent="parent_value", - ) - - -def test_list_datasets_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", - ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", - ), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], - ), - RuntimeError, - ) - results = [i for i in client.list_datasets(request={},)] - assert len(results) == 6 - assert all(isinstance(i, dataset.Dataset) for i in results) - - -def test_list_datasets_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_datasets), "__call__") as call: - # Set the response to a series of pages. 
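The pager tests rely on `side_effect`: each call to the mocked stub pops the next canned page, and the trailing `RuntimeError` is a sentinel against over-fetching. A stripped-down sketch of those mechanics, with plain lists standing in for `ListDatasetsResponse` pages:

from unittest import mock

import pytest

call = mock.Mock()
call.side_effect = (["a", "b", "c"], [], ["d"], ["e", "f"], RuntimeError)

# Drain the pages exactly as a pager would: one stub call per page.
pages = [call(), call(), call(), call()]
assert [item for page in pages for item in page] == ["a", "b", "c", "d", "e", "f"]

# A fifth fetch would be a client bug; the sentinel surfaces it immediately.
with pytest.raises(RuntimeError):
    call()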
- call.side_effect = ( - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], - next_page_token="abc", - ), - dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(),], next_page_token="ghi", - ), - dataset_service.ListDatasetsResponse( - datasets=[dataset.Dataset(), dataset.Dataset(),], - ), - RuntimeError, - ) - pages = list(client.list_datasets(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_delete_dataset(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.DeleteDatasetRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_dataset(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_dataset_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_dataset), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_dataset(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_dataset_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_dataset( - dataset_service.DeleteDatasetRequest(), name="name_value", - ) - - -def test_import_data(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ImportDataRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.import_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.import_data(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_import_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.import_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.import_data( - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - assert args[0].import_configs == [ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ] - - -def test_import_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.import_data( - dataset_service.ImportDataRequest(), - name="name_value", - import_configs=[ - dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) - ], - ) - - -def test_export_data(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ExportDataRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.export_data(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_export_data_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_data), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.export_data( - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. 
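`delete_dataset`, `import_data`, and `export_data` are long-running operations, so their stubs are faked with a bare `Operation` proto and the client is expected to wrap it in an `api_core` future. Constructing that fake is ordinary protobuf work; a minimal sketch (only `name` is set, every other field keeps its proto3 default):

from google.longrunning import operations_pb2

op = operations_pb2.Operation(name="operations/spam")
assert op.name == "operations/spam"
assert not op.done  # proto3 default: the operation reads as still running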
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - assert args[0].export_config == dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ) - - -def test_export_data_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.export_data( - dataset_service.ExportDataRequest(), - name="name_value", - export_config=dataset.ExportDataConfig( - gcs_destination=io.GcsDestination( - output_uri_prefix="output_uri_prefix_value" - ) - ), - ) - - -def test_list_data_items(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ListDataItemsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDataItemsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_data_items(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDataItemsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_data_items_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.ListDataItemsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - call.return_value = dataset_service.ListDataItemsResponse() - client.list_data_items(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_data_items_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListDataItemsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_data_items(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_data_items_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_data_items( - dataset_service.ListDataItemsRequest(), parent="parent_value", - ) - - -def test_list_data_items_pager(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], - ), - RuntimeError, - ) - results = [i for i in client.list_data_items(request={},)] - assert len(results) == 6 - assert all(isinstance(i, data_item.DataItem) for i in results) - - -def test_list_data_items_pages(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_data_items), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - dataset_service.ListDataItemsResponse( - data_items=[ - data_item.DataItem(), - data_item.DataItem(), - data_item.DataItem(), - ], - next_page_token="abc", - ), - dataset_service.ListDataItemsResponse( - data_items=[], next_page_token="def", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(),], next_page_token="ghi", - ), - dataset_service.ListDataItemsResponse( - data_items=[data_item.DataItem(), data_item.DataItem(),], - ), - RuntimeError, - ) - pages = list(client.list_data_items(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_get_annotation_spec(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.GetAnnotationSpecRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_annotation_spec), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = annotation_spec.AnnotationSpec( - name="name_value", display_name="display_name_value", etag="etag_value", - ) - - response = client.get_annotation_spec(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, annotation_spec.AnnotationSpec) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.etag == "etag_value" - - -def test_get_annotation_spec_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = dataset_service.GetAnnotationSpecRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_annotation_spec), "__call__" - ) as call: - call.return_value = annotation_spec.AnnotationSpec() - client.get_annotation_spec(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_annotation_spec_flattened(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_annotation_spec), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = annotation_spec.AnnotationSpec() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_annotation_spec(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_annotation_spec_flattened_error(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_annotation_spec( - dataset_service.GetAnnotationSpecRequest(), name="name_value", - ) - - -def test_list_annotations(transport: str = "grpc"): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = dataset_service.ListAnnotationsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_annotations), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = dataset_service.ListAnnotationsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_annotations(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListAnnotationsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_annotations_field_headers(): - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. 
-    # Set these to a non-empty value.
-    request = dataset_service.ListAnnotationsRequest(parent="parent/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_annotations), "__call__"
-    ) as call:
-        call.return_value = dataset_service.ListAnnotationsResponse()
-        client.list_annotations(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-def test_list_annotations_flattened():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_annotations), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = dataset_service.ListAnnotationsResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.list_annotations(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-
-
-def test_list_annotations_flattened_error():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_annotations(
-            dataset_service.ListAnnotationsRequest(), parent="parent_value",
-        )
-
-
-def test_list_annotations_pager():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_annotations), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            dataset_service.ListAnnotationsResponse(
-                annotations=[
-                    annotation.Annotation(),
-                    annotation.Annotation(),
-                    annotation.Annotation(),
-                ],
-                next_page_token="abc",
-            ),
-            dataset_service.ListAnnotationsResponse(
-                annotations=[], next_page_token="def",
-            ),
-            dataset_service.ListAnnotationsResponse(
-                annotations=[annotation.Annotation(),], next_page_token="ghi",
-            ),
-            dataset_service.ListAnnotationsResponse(
-                annotations=[annotation.Annotation(), annotation.Annotation(),],
-            ),
-            RuntimeError,
-        )
-        results = [i for i in client.list_annotations(request={},)]
-        assert len(results) == 6
-        assert all(isinstance(i, annotation.Annotation) for i in results)
-
-
-def test_list_annotations_pages():
-    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_annotations), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
- call.side_effect = ( - dataset_service.ListAnnotationsResponse( - annotations=[ - annotation.Annotation(), - annotation.Annotation(), - annotation.Annotation(), - ], - next_page_token="abc", - ), - dataset_service.ListAnnotationsResponse( - annotations=[], next_page_token="def", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(),], next_page_token="ghi", - ), - dataset_service.ListAnnotationsResponse( - annotations=[annotation.Annotation(), annotation.Annotation(),], - ), - RuntimeError, - ) - pages = list(client.list_annotations(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.DatasetServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.DatasetServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = DatasetServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.DatasetServiceGrpcTransport,) - - -def test_dataset_service_base_transport(): - # Instantiate the base transport. - transport = transports.DatasetServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_dataset", - "get_dataset", - "update_dataset", - "list_datasets", - "delete_dataset", - "import_data", - "export_data", - "list_data_items", - "get_annotation_spec", - "list_annotations", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_dataset_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. 
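The base-transport test above is essentially a reflection loop: every RPC named on the abstract transport must raise `NotImplementedError` until a concrete transport overrides it. A distilled version of that loop, over a hypothetical two-method transport rather than the generated one:

import pytest


class BaseTransport:
    def create_dataset(self, request):
        raise NotImplementedError()

    def get_dataset(self, request):
        raise NotImplementedError()


for method in ("create_dataset", "get_dataset"):
    with pytest.raises(NotImplementedError):
        getattr(BaseTransport(), method)(request=object())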
- with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - DatasetServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",) - ) - - -def test_dataset_service_host_no_port(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_dataset_service_host_with_port(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:8000" - - -def test_dataset_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - transport = transports.DatasetServiceGrpcTransport(channel=channel,) - assert transport.grpc_channel is channel - - -def test_dataset_service_grpc_lro_client(): - client = DatasetServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", - ) - transport = client._transport - - # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) - - # Ensure that subsequent calls to the property send the exact same object. - assert transport.operations_client is transport.operations_client - - -def test_dataset_path(): - project = "squid" - location = "clam" - dataset = "whelk" - - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( - project=project, location=location, dataset=dataset, - ) - actual = DatasetServiceClient.dataset_path(project, location, dataset) - assert expected == actual diff --git a/tests/unit/gapic/test_endpoint_service.py b/tests/unit/gapic/test_endpoint_service.py deleted file mode 100644 index 4059cdb819..0000000000 --- a/tests/unit/gapic/test_endpoint_service.py +++ /dev/null @@ -1,771 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.api_core import future -from google.api_core import operations_v1 -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.endpoint_service import ( - EndpointServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers -from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports -from google.cloud.aiplatform_v1beta1.types import accelerator_type -from google.cloud.aiplatform_v1beta1.types import endpoint -from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint -from google.cloud.aiplatform_v1beta1.types import endpoint_service -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import explanation_metadata -from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - - -def test_endpoint_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = EndpointServiceClient.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = EndpointServiceClient.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_endpoint_service_client_client_options(): - # Check the default options have their expected values. - assert ( - EndpointServiceClient.DEFAULT_OPTIONS.api_endpoint - == "aiplatform.googleapis.com" - ) - - # Check that options can be customized. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.EndpointServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = EndpointServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_endpoint_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.endpoint_service.EndpointServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = EndpointServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_create_endpoint(transport: str = "grpc"): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.CreateEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client._transport.create_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.create_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_create_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.create_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_endpoint( - parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - - -def test_create_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_endpoint( - endpoint_service.CreateEndpointRequest(), - parent="parent_value", - endpoint=gca_endpoint.Endpoint(name="name_value"), - ) - - -def test_get_endpoint(transport: str = "grpc"): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.GetEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - - response = client.get_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, endpoint.Endpoint) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" - assert response.etag == "etag_value" - - -def test_get_endpoint_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.GetEndpointRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call: - call.return_value = endpoint.Endpoint() - client.get_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = endpoint.Endpoint() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_endpoint(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_endpoint( - endpoint_service.GetEndpointRequest(), name="name_value", - ) - - -def test_list_endpoints(transport: str = "grpc"): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.ListEndpointsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = endpoint_service.ListEndpointsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_endpoints(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListEndpointsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_endpoints_field_headers(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = endpoint_service.ListEndpointsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: - call.return_value = endpoint_service.ListEndpointsResponse() - client.list_endpoints(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. 
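The `_field_headers` tests encode a gRPC routing convention: any request field that also appears in the URL path is mirrored into the `x-goog-request-params` metadata entry so the frontend can route without parsing the request body. A sketch of how such an entry could be built and checked; the `routing_metadata` helper is illustrative, not the generated client's actual plumbing:

def routing_metadata(**fields):
    # {"name": "name/value"} -> ("x-goog-request-params", "name=name/value")
    params = "&".join("{}={}".format(k, v) for k, v in sorted(fields.items()))
    return ("x-goog-request-params", params)


metadata = [routing_metadata(name="name/value")]
assert ("x-goog-request-params", "name=name/value") in metadata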
- _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_endpoints_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = endpoint_service.ListEndpointsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_endpoints(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_endpoints_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_endpoints( - endpoint_service.ListEndpointsRequest(), parent="parent_value", - ) - - -def test_list_endpoints_pager(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - endpoint.Endpoint(), - ], - next_page_token="abc", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], - ), - RuntimeError, - ) - results = [i for i in client.list_endpoints(request={},)] - assert len(results) == 6 - assert all(isinstance(i, endpoint.Endpoint) for i in results) - - -def test_list_endpoints_pages(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_endpoints), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - endpoint_service.ListEndpointsResponse( - endpoints=[ - endpoint.Endpoint(), - endpoint.Endpoint(), - endpoint.Endpoint(), - ], - next_page_token="abc", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[], next_page_token="def", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(),], next_page_token="ghi", - ), - endpoint_service.ListEndpointsResponse( - endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),], - ), - RuntimeError, - ) - pages = list(client.list_endpoints(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_update_endpoint(transport: str = "grpc"): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
- request = endpoint_service.UpdateEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_endpoint.Endpoint( - name="name_value", - display_name="display_name_value", - description="description_value", - etag="etag_value", - ) - - response = client.update_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_endpoint.Endpoint) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" - assert response.etag == "etag_value" - - -def test_update_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_endpoint.Endpoint() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.update_endpoint( - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -def test_update_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.update_endpoint( - endpoint_service.UpdateEndpointRequest(), - endpoint=gca_endpoint.Endpoint(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -def test_delete_endpoint(transport: str = "grpc"): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.DeleteEndpointRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_endpoint(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_endpoint_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object(type(client._transport.delete_endpoint), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_endpoint(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_endpoint_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_endpoint( - endpoint_service.DeleteEndpointRequest(), name="name_value", - ) - - -def test_deploy_model(transport: str = "grpc"): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.DeployModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.deploy_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.deploy_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_deploy_model_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.deploy_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.deploy_model( - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model == gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ) - assert args[0].traffic_split == {"key_value": 541} - - -def test_deploy_model_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
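The `_flattened` / `_flattened_error` pairs throughout these files, including the case continuing just below, enforce a single signature rule: a method accepts either a prebuilt request object or flattened keyword fields, never both. A toy function showing the rule these tests pin down (names are illustrative):

import pytest


def deploy_model(request=None, *, endpoint=None):
    if request is not None and endpoint is not None:
        raise ValueError("pass a request object or flattened fields, not both")
    return request if request is not None else {"endpoint": endpoint}


assert deploy_model(endpoint="endpoint_value") == {"endpoint": "endpoint_value"}

with pytest.raises(ValueError):
    deploy_model(request={}, endpoint="endpoint_value")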
- with pytest.raises(ValueError): - client.deploy_model( - endpoint_service.DeployModelRequest(), - endpoint="endpoint_value", - deployed_model=gca_endpoint.DeployedModel( - dedicated_resources=machine_resources.DedicatedResources( - machine_spec=machine_resources.MachineSpec( - machine_type="machine_type_value" - ) - ) - ), - traffic_split={"key_value": 541}, - ) - - -def test_undeploy_model(transport: str = "grpc"): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = endpoint_service.UndeployModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.undeploy_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.undeploy_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_undeploy_model_flattened(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.undeploy_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.undeploy_model( - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].deployed_model_id == "deployed_model_id_value" - assert args[0].traffic_split == {"key_value": 541} - - -def test_undeploy_model_flattened_error(): - client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.undeploy_model( - endpoint_service.UndeployModelRequest(), - endpoint="endpoint_value", - deployed_model_id="deployed_model_id_value", - traffic_split={"key_value": 541}, - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.EndpointServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.EndpointServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = EndpointServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. 
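`test_credentials_transport_error` and `test_transport_instance` above pin down a constructor contract: a ready-made transport already carries its own credentials, so supplying both is rejected, while a bare transport is adopted as-is. The same contract on a toy client (a stand-in, not the generated class):

import pytest


class Client:
    def __init__(self, credentials=None, transport=None):
        if transport is not None and credentials is not None:
            raise ValueError("provide credentials or a transport, not both")
        self._transport = transport


transport = object()
assert Client(transport=transport)._transport is transport

with pytest.raises(ValueError):
    Client(credentials=object(), transport=transport)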
- client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.EndpointServiceGrpcTransport,) - - -def test_endpoint_service_base_transport(): - # Instantiate the base transport. - transport = transports.EndpointServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_endpoint", - "get_endpoint", - "list_endpoints", - "update_endpoint", - "delete_endpoint", - "deploy_model", - "undeploy_model", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_endpoint_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - EndpointServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",) - ) - - -def test_endpoint_service_host_no_port(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_endpoint_service_host_with_port(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:8000" - - -def test_endpoint_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - transport = transports.EndpointServiceGrpcTransport(channel=channel,) - assert transport.grpc_channel is channel - - -def test_endpoint_service_grpc_lro_client(): - client = EndpointServiceClient( - credentials=credentials.AnonymousCredentials(), transport="grpc", - ) - transport = client._transport - - # Ensure that we have a api-core operations client. - assert isinstance(transport.operations_client, operations_v1.OperationsClient,) - - # Ensure that subsequent calls to the property send the exact same object. - assert transport.operations_client is transport.operations_client - - -def test_endpoint_path(): - project = "squid" - location = "clam" - endpoint = "whelk" - - expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format( - project=project, location=location, endpoint=endpoint, - ) - actual = EndpointServiceClient.endpoint_path(project, location, endpoint) - assert expected == actual diff --git a/tests/unit/gapic/test_job_service.py b/tests/unit/gapic/test_job_service.py deleted file mode 100644 index 92ed0d37e3..0000000000 --- a/tests/unit/gapic/test_job_service.py +++ /dev/null @@ -1,2118 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.api_core import future -from google.api_core import operations_v1 -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient -from google.cloud.aiplatform_v1beta1.services.job_service import pagers -from google.cloud.aiplatform_v1beta1.services.job_service import transports -from google.cloud.aiplatform_v1beta1.types import accelerator_type -from google.cloud.aiplatform_v1beta1.types import ( - accelerator_type as gca_accelerator_type, -) -from google.cloud.aiplatform_v1beta1.types import batch_prediction_job -from google.cloud.aiplatform_v1beta1.types import ( - batch_prediction_job as gca_batch_prediction_job, -) -from google.cloud.aiplatform_v1beta1.types import completion_stats -from google.cloud.aiplatform_v1beta1.types import ( - completion_stats as gca_completion_stats, -) -from google.cloud.aiplatform_v1beta1.types import custom_job -from google.cloud.aiplatform_v1beta1.types import custom_job as gca_custom_job -from google.cloud.aiplatform_v1beta1.types import data_labeling_job -from google.cloud.aiplatform_v1beta1.types import ( - data_labeling_job as gca_data_labeling_job, -) -from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job -from google.cloud.aiplatform_v1beta1.types import ( - hyperparameter_tuning_job as gca_hyperparameter_tuning_job, -) -from google.cloud.aiplatform_v1beta1.types import io -from google.cloud.aiplatform_v1beta1.types import job_service -from google.cloud.aiplatform_v1beta1.types import job_state -from google.cloud.aiplatform_v1beta1.types import machine_resources -from google.cloud.aiplatform_v1beta1.types import manual_batch_tuning_parameters -from google.cloud.aiplatform_v1beta1.types import ( - manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters, -) -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.cloud.aiplatform_v1beta1.types import study -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import any_pb2 as any # type: ignore -from google.protobuf import duration_pb2 as duration # type: ignore -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from google.rpc import status_pb2 as status # type: ignore -from google.type import money_pb2 as money # type: ignore - - -def test_job_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = JobServiceClient.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = JobServiceClient.from_service_account_json("dummy/file/path.json") - assert 
client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_job_service_client_client_options(): - # Check the default options have their expected values. - assert JobServiceClient.DEFAULT_OPTIONS.api_endpoint == "aiplatform.googleapis.com" - - # Check that options can be customized. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.JobServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = JobServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_job_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.job_service.JobServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_create_custom_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateCustomJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.create_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_custom_job.CustomJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_create_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_custom_job.CustomJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_custom_job( - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value") - - -def test_create_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
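The deleted tests above all lean on one mocking trick: gRPC stub methods are callable *objects* (multicallables), not plain functions, so patching `__call__` on the object's type intercepts every invocation. A minimal, self-contained sketch of that pattern, with invented names (`FakeUnaryUnaryCall` stands in for a gRPC multicallable):

```python
from unittest import mock


class FakeUnaryUnaryCall:
    """Invented stand-in for a gRPC multicallable: a callable object."""

    def __call__(self, request, metadata=None):
        raise RuntimeError("would hit the network")


stub_method = FakeUnaryUnaryCall()

# Patching __call__ on the *type* is what intercepts instance calls; this is
# why the tests patch type(client._transport.create_custom_job).
with mock.patch.object(type(stub_method), "__call__") as call:
    call.return_value = {"name": "name_value"}
    response = stub_method({"parent": "parent_value"})

assert response == {"name": "name_value"}

# mock_calls entries unpack as (name, args, kwargs), which is how the tests
# recover the request object that reached the stub.
_, args, _ = call.mock_calls[0]
assert args[0] == {"parent": "parent_value"}
```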
- with pytest.raises(ValueError): - client.create_custom_job( - job_service.CreateCustomJobRequest(), - parent="parent_value", - custom_job=gca_custom_job.CustomJob(name="name_value"), - ) - - -def test_get_custom_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetCustomJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = custom_job.CustomJob( - name="name_value", - display_name="display_name_value", - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.get_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, custom_job.CustomJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_get_custom_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetCustomJobRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: - call.return_value = custom_job.CustomJob() - client.get_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_custom_job), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = custom_job.CustomJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
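The `*_field_headers` tests assert that any resource name appearing in the request URI is mirrored into an `x-goog-request-params` gRPC metadata entry, which the backend uses to route the request. A toy sketch of the metadata a client is expected to attach (`make_routing_metadata` is an invented helper, not library API):

```python
def make_routing_metadata(**fields):
    """Build the routing header a GAPIC client would attach (toy version)."""
    params = "&".join(f"{key}={value}" for key, value in fields.items())
    return [("x-goog-request-params", params)]


metadata = make_routing_metadata(name="name/value")
assert ("x-goog-request-params", "name=name/value") in metadata
```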
- with pytest.raises(ValueError): - client.get_custom_job( - job_service.GetCustomJobRequest(), name="name_value", - ) - - -def test_list_custom_jobs(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.ListCustomJobsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListCustomJobsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_custom_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListCustomJobsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_custom_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListCustomJobsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - call.return_value = job_service.ListCustomJobsResponse() - client.list_custom_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_custom_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListCustomJobsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_custom_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_custom_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_custom_jobs( - job_service.ListCustomJobsRequest(), parent="parent_value", - ) - - -def test_list_custom_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - custom_job.CustomJob(), - ], - next_page_token="abc", - ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", - ), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], - ), - RuntimeError, - ) - results = [i for i in client.list_custom_jobs(request={},)] - assert len(results) == 6 - assert all(isinstance(i, custom_job.CustomJob) for i in results) - - -def test_list_custom_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_custom_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListCustomJobsResponse( - custom_jobs=[ - custom_job.CustomJob(), - custom_job.CustomJob(), - custom_job.CustomJob(), - ], - next_page_token="abc", - ), - job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", - ), - job_service.ListCustomJobsResponse( - custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], - ), - RuntimeError, - ) - pages = list(client.list_custom_jobs(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_delete_custom_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteCustomJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
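The pager tests queue four pages via `side_effect` and deliberately append `RuntimeError` as a sentinel: if the pager ever requests one page more than the empty `next_page_token` allows, the test fails loudly instead of silently fetching forever. A self-contained sketch of the same shape (all names invented):

```python
from unittest import mock


class Page:
    """Toy page: a batch of items plus the token for the next page."""

    def __init__(self, items, next_page_token=""):
        self.items = items
        self.next_page_token = next_page_token


def iterate(fetch_page):
    token = None
    while True:
        page = fetch_page(token)
        yield from page.items
        token = page.next_page_token
        if not token:
            return


fetch = mock.Mock(
    side_effect=(
        Page(["a", "b", "c"], next_page_token="abc"),
        Page([], next_page_token="def"),
        Page(["d"], next_page_token="ghi"),
        Page(["e", "f"]),  # empty token ends pagination
        RuntimeError,      # sentinel: one fetch too many fails fast
    )
)

assert list(iterate(fetch)) == ["a", "b", "c", "d", "e", "f"]
assert fetch.call_count == 4  # the sentinel was never reached
```

The 3 + 0 + 1 + 2 page sizes mirror the six-item assertion in the deleted tests.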
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_custom_job( - job_service.DeleteCustomJobRequest(), name="name_value", - ) - - -def test_cancel_custom_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelCustomJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_custom_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_custom_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_custom_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.cancel_custom_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_cancel_custom_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.cancel_custom_job( - job_service.CancelCustomJobRequest(), name="name_value", - ) - - -def test_create_data_labeling_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateDataLabelingJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
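The `cancel_*` tests above pin down a quieter contract: RPCs that return `google.protobuf.Empty` surface as `None` from the client, so the test asserts `response is None` rather than checking a message type. A hedged toy of that normalization (`cancel_job` and its signature are invented):

```python
from unittest import mock


def cancel_job(stub_call, request):
    """Toy client method: fire the RPC, normalize an Empty response to None."""
    stub_call(request)
    return None


call = mock.Mock(return_value=None)
assert cancel_job(call, {"name": "name_value"}) is None
call.assert_called_once_with({"name": "name_value"})
```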
- call.return_value = gca_data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - - response = client.create_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_data_labeling_job.DataLabelingJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.datasets == ["datasets_value"] - assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == "inputs_schema_uri_value" - assert response.state == job_state.JobState.JOB_STATE_QUEUED - assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] - - -def test_create_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_data_labeling_job.DataLabelingJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_data_labeling_job( - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob( - name="name_value" - ) - - -def test_create_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_data_labeling_job( - job_service.CreateDataLabelingJobRequest(), - parent="parent_value", - data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"), - ) - - -def test_get_data_labeling_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetDataLabelingJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = data_labeling_job.DataLabelingJob( - name="name_value", - display_name="display_name_value", - datasets=["datasets_value"], - labeler_count=1375, - instruction_uri="instruction_uri_value", - inputs_schema_uri="inputs_schema_uri_value", - state=job_state.JobState.JOB_STATE_QUEUED, - labeling_progress=1810, - specialist_pools=["specialist_pools_value"], - ) - - response = client.get_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, data_labeling_job.DataLabelingJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.datasets == ["datasets_value"] - assert response.labeler_count == 1375 - assert response.instruction_uri == "instruction_uri_value" - assert response.inputs_schema_uri == "inputs_schema_uri_value" - assert response.state == job_state.JobState.JOB_STATE_QUEUED - assert response.labeling_progress == 1810 - assert response.specialist_pools == ["specialist_pools_value"] - - -def test_get_data_labeling_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetDataLabelingJobRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_data_labeling_job), "__call__" - ) as call: - call.return_value = data_labeling_job.DataLabelingJob() - client.get_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = data_labeling_job.DataLabelingJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_data_labeling_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
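The comment above introduces the `*_flattened_error` contract: a method accepts either a fully-formed request object or flattened keyword fields, never both, and mixing them raises `ValueError` before any RPC is attempted. A minimal sketch of that guard (the function and its dict request are invented stand-ins):

```python
import pytest


def get_data_labeling_job(request=None, *, name=None):
    """Toy flattened-vs-request guard, mirroring the deleted tests."""
    if request is not None and name is not None:
        raise ValueError("Pass either `request` or flattened fields, not both.")
    return request if request is not None else {"name": name}


with pytest.raises(ValueError):
    get_data_labeling_job({"name": "x"}, name="name_value")

assert get_data_labeling_job(name="name_value") == {"name": "name_value"}
```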
- with pytest.raises(ValueError): - client.get_data_labeling_job( - job_service.GetDataLabelingJobRequest(), name="name_value", - ) - - -def test_list_data_labeling_jobs(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.ListDataLabelingJobsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_data_labeling_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListDataLabelingJobsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_data_labeling_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListDataLabelingJobsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_data_labeling_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListDataLabelingJobsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_data_labeling_jobs), "__call__" - ) as call: - call.return_value = job_service.ListDataLabelingJobsResponse() - client.list_data_labeling_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_data_labeling_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_data_labeling_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListDataLabelingJobsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_data_labeling_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_data_labeling_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_data_labeling_jobs( - job_service.ListDataLabelingJobsRequest(), parent="parent_value", - ) - - -def test_list_data_labeling_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. 
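Note that the pager tests call methods with `request={}`: GAPIC methods accept a plain mapping and coerce it into the request message type before sending. A toy coercion under that assumption (the dataclass stands in for the proto-plus message):

```python
from dataclasses import dataclass


@dataclass
class ListDataLabelingJobsRequest:
    """Invented stand-in for the proto-plus request message."""

    parent: str = ""
    page_token: str = ""


def coerce(request):
    if isinstance(request, dict):
        return ListDataLabelingJobsRequest(**request)
    return request


assert coerce({}) == ListDataLabelingJobsRequest()
assert coerce({"parent": "parent_value"}).parent == "parent_value"
```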
- with mock.patch.object( - type(client._transport.list_data_labeling_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - ], - next_page_token="abc", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - ], - ), - RuntimeError, - ) - results = [i for i in client.list_data_labeling_jobs(request={},)] - assert len(results) == 6 - assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results) - - -def test_list_data_labeling_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_data_labeling_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - ], - next_page_token="abc", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[], next_page_token="def", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], - next_page_token="ghi", - ), - job_service.ListDataLabelingJobsResponse( - data_labeling_jobs=[ - data_labeling_job.DataLabelingJob(), - data_labeling_job.DataLabelingJob(), - ], - ), - RuntimeError, - ) - pages = list(client.list_data_labeling_jobs(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_delete_data_labeling_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteDataLabelingJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
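The `delete_*` tests stub the RPC to return a raw `operations_pb2.Operation` and then assert only `isinstance(response, future.Future)`: the client wraps long-running operations in an api-core operation object that implements the future interface. A toy stand-in for that wrapping, using only the standard library (`wrap_operation` is invented and resolves immediately):

```python
import concurrent.futures


def wrap_operation(operation_name):
    """Toy LRO wrapper: a Future that pretends the deletion already finished."""
    lro = concurrent.futures.Future()
    lro.operation_name = operation_name  # illustrative attribute only
    lro.set_result(None)
    return lro


response = wrap_operation("operations/spam")
assert isinstance(response, concurrent.futures.Future)
assert response.result() is None
```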
- call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_data_labeling_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_data_labeling_job( - job_service.DeleteDataLabelingJobRequest(), name="name_value", - ) - - -def test_cancel_data_labeling_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelDataLabelingJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_data_labeling_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_data_labeling_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_data_labeling_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.cancel_data_labeling_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_cancel_data_labeling_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.cancel_data_labeling_job( - job_service.CancelDataLabelingJobRequest(), name="name_value", - ) - - -def test_create_hyperparameter_tuning_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateHyperparameterTuningJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.create_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.max_trial_count == 1609 - assert response.parallel_trial_count == 2128 - assert response.max_failed_trial_count == 2317 - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_create_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_hyperparameter_tuning_job( - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[ - 0 - ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ) - - -def test_create_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_hyperparameter_tuning_job( - job_service.CreateHyperparameterTuningJobRequest(), - parent="parent_value", - hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value" - ), - ) - - -def test_get_hyperparameter_tuning_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.GetHyperparameterTuningJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( - name="name_value", - display_name="display_name_value", - max_trial_count=1609, - parallel_trial_count=2128, - max_failed_trial_count=2317, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.get_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.max_trial_count == 1609 - assert response.parallel_trial_count == 2128 - assert response.max_failed_trial_count == 2317 - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_get_hyperparameter_tuning_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetHyperparameterTuningJobRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_hyperparameter_tuning_job), "__call__" - ) as call: - call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() - client.get_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_hyperparameter_tuning_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_hyperparameter_tuning_job( - job_service.GetHyperparameterTuningJobRequest(), name="name_value", - ) - - -def test_list_hyperparameter_tuning_jobs(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.ListHyperparameterTuningJobsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = job_service.ListHyperparameterTuningJobsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_hyperparameter_tuning_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_hyperparameter_tuning_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListHyperparameterTuningJobsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - call.return_value = job_service.ListHyperparameterTuningJobsResponse() - client.list_hyperparameter_tuning_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_hyperparameter_tuning_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListHyperparameterTuningJobsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_hyperparameter_tuning_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_hyperparameter_tuning_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_hyperparameter_tuning_jobs( - job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value", - ) - - -def test_list_hyperparameter_tuning_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - next_page_token="abc", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - next_page_token="ghi", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - ), - RuntimeError, - ) - results = [i for i in client.list_hyperparameter_tuning_jobs(request={},)] - assert len(results) == 6 - assert all( - isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob) - for i in results - ) - - -def test_list_hyperparameter_tuning_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_hyperparameter_tuning_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - next_page_token="abc", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[], next_page_token="def", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - next_page_token="ghi", - ), - job_service.ListHyperparameterTuningJobsResponse( - hyperparameter_tuning_jobs=[ - hyperparameter_tuning_job.HyperparameterTuningJob(), - hyperparameter_tuning_job.HyperparameterTuningJob(), - ], - ), - RuntimeError, - ) - pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_delete_hyperparameter_tuning_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteHyperparameterTuningJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
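The paired `*_pager`/`*_pages` tests pin down two views of the same response stream: iterating the pager yields items, while `.pages` yields the raw page responses so callers can read `next_page_token` themselves. A compact sketch of a pager exposing both (all names invented):

```python
class Page:
    def __init__(self, items, next_page_token=""):
        self.items = items
        self.next_page_token = next_page_token


class Pager:
    """Toy pager: iteration yields items; .pages yields raw pages."""

    def __init__(self, raw_pages):
        self._raw_pages = raw_pages

    @property
    def pages(self):
        yield from self._raw_pages

    def __iter__(self):
        for page in self._raw_pages:
            yield from page.items


pager = Pager([Page(["a"], "abc"), Page(["b", "c"], "")])
assert list(pager) == ["a", "b", "c"]
for page, token in zip(pager.pages, ["abc", ""]):
    assert page.next_page_token == token  # the final page carries ""
```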
- assert isinstance(response, future.Future) - - -def test_delete_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_hyperparameter_tuning_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_hyperparameter_tuning_job( - job_service.DeleteHyperparameterTuningJobRequest(), name="name_value", - ) - - -def test_cancel_hyperparameter_tuning_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelHyperparameterTuningJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_hyperparameter_tuning_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_hyperparameter_tuning_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_hyperparameter_tuning_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.cancel_hyperparameter_tuning_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_cancel_hyperparameter_tuning_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
- with pytest.raises(ValueError): - client.cancel_hyperparameter_tuning_job( - job_service.CancelHyperparameterTuningJobRequest(), name="name_value", - ) - - -def test_create_batch_prediction_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CreateBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.create_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.model == "model_value" - - assert response.generate_explanation is True - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_create_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_batch_prediction_job.BatchPredictionJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_batch_prediction_job( - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[ - 0 - ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ) - - -def test_create_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_batch_prediction_job( - job_service.CreateBatchPredictionJobRequest(), - parent="parent_value", - batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( - name="name_value" - ), - ) - - -def test_get_batch_prediction_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
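The recurring "everything is optional in proto3" comment is worth unpacking: at the runtime level every proto3 field has a default, so an empty request message is always constructible and unset fields read back as type defaults, which is why these tests can send bare requests at a mocked stub. A toy illustration (the dataclass is an invented stand-in for the proto message):

```python
from dataclasses import dataclass


@dataclass
class GetBatchPredictionJobRequest:
    """Invented stand-in: proto3 fields default rather than being required."""

    name: str = ""


request = GetBatchPredictionJobRequest()  # valid despite no arguments
assert request.name == ""
```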
- request = job_service.GetBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = batch_prediction_job.BatchPredictionJob( - name="name_value", - display_name="display_name_value", - model="model_value", - generate_explanation=True, - state=job_state.JobState.JOB_STATE_QUEUED, - ) - - response = client.get_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, batch_prediction_job.BatchPredictionJob) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.model == "model_value" - - assert response.generate_explanation is True - assert response.state == job_state.JobState.JOB_STATE_QUEUED - - -def test_get_batch_prediction_job_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.GetBatchPredictionJobRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_batch_prediction_job), "__call__" - ) as call: - call.return_value = batch_prediction_job.BatchPredictionJob() - client.get_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = batch_prediction_job.BatchPredictionJob() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_batch_prediction_job( - job_service.GetBatchPredictionJobRequest(), name="name_value", - ) - - -def test_list_batch_prediction_jobs(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
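A small but deliberate detail in the batch-prediction assertions above: boolean proto fields are checked with `is True` rather than `== True`, so a truthy non-bool regression (for example an `int` or a wrapper object leaking through) still fails the test. For instance:

```python
generate_explanation = True

assert generate_explanation is True  # identity: must be the bool singleton
assert (1 == True) and (1 is not True)  # equality alone would let 1 pass
```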
- request = job_service.ListBatchPredictionJobsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListBatchPredictionJobsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_batch_prediction_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListBatchPredictionJobsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_batch_prediction_jobs_field_headers(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = job_service.ListBatchPredictionJobsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - call.return_value = job_service.ListBatchPredictionJobsResponse() - client.list_batch_prediction_jobs(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_batch_prediction_jobs_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = job_service.ListBatchPredictionJobsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_batch_prediction_jobs(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_batch_prediction_jobs_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_batch_prediction_jobs( - job_service.ListBatchPredictionJobsRequest(), parent="parent_value", - ) - - -def test_list_batch_prediction_jobs_pager(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token="abc", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - ), - RuntimeError, - ) - results = [i for i in client.list_batch_prediction_jobs(request={},)] - assert len(results) == 6 - assert all( - isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results - ) - - -def test_list_batch_prediction_jobs_pages(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_batch_prediction_jobs), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - next_page_token="abc", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[], next_page_token="def", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),], - next_page_token="ghi", - ), - job_service.ListBatchPredictionJobsResponse( - batch_prediction_jobs=[ - batch_prediction_job.BatchPredictionJob(), - batch_prediction_job.BatchPredictionJob(), - ], - ), - RuntimeError, - ) - pages = list(client.list_batch_prediction_jobs(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_delete_batch_prediction_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.DeleteBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_batch_prediction_job( - job_service.DeleteBatchPredictionJobRequest(), name="name_value", - ) - - -def test_cancel_batch_prediction_job(transport: str = "grpc"): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = job_service.CancelBatchPredictionJobRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_batch_prediction_job(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_batch_prediction_job_flattened(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_batch_prediction_job), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.cancel_batch_prediction_job(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_cancel_batch_prediction_job_flattened_error(): - client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.cancel_batch_prediction_job( - job_service.CancelBatchPredictionJobRequest(), name="name_value", - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.JobServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = JobServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. 
-    transport = transports.JobServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    client = JobServiceClient(transport=transport)
-    assert client._transport is transport
-
-
-def test_transport_grpc_default():
-    # A client should use the gRPC transport by default.
-    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
-    assert isinstance(client._transport, transports.JobServiceGrpcTransport,)
-
-
-def test_job_service_base_transport():
-    # Instantiate the base transport.
-    transport = transports.JobServiceTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Every method on the transport should just blindly
-    # raise NotImplementedError.
-    methods = (
-        "create_custom_job",
-        "get_custom_job",
-        "list_custom_jobs",
-        "delete_custom_job",
-        "cancel_custom_job",
-        "create_data_labeling_job",
-        "get_data_labeling_job",
-        "list_data_labeling_jobs",
-        "delete_data_labeling_job",
-        "cancel_data_labeling_job",
-        "create_hyperparameter_tuning_job",
-        "get_hyperparameter_tuning_job",
-        "list_hyperparameter_tuning_jobs",
-        "delete_hyperparameter_tuning_job",
-        "cancel_hyperparameter_tuning_job",
-        "create_batch_prediction_job",
-        "get_batch_prediction_job",
-        "list_batch_prediction_jobs",
-        "delete_batch_prediction_job",
-        "cancel_batch_prediction_job",
-    )
-    for method in methods:
-        with pytest.raises(NotImplementedError):
-            getattr(transport, method)(request=object())
-
-    # Additionally, the LRO client (a property) should
-    # also raise NotImplementedError.
-    with pytest.raises(NotImplementedError):
-        transport.operations_client
-
-
-def test_job_service_auth_adc():
-    # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        JobServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",)
-        )
-
-
-def test_job_service_host_no_port():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_job_service_host_with_port():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:8000"
-
-
-def test_job_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
-    transport = transports.JobServiceGrpcTransport(channel=channel,)
-    assert transport.grpc_channel is channel
-
-
-def test_job_service_grpc_lro_client():
-    client = JobServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc",
-    )
-    transport = client._transport
-
-    # Ensure that we have an api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
- assert transport.operations_client is transport.operations_client - - -def test_custom_job_path(): - project = "squid" - location = "clam" - custom_job = "whelk" - - expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format( - project=project, location=location, custom_job=custom_job, - ) - actual = JobServiceClient.custom_job_path(project, location, custom_job) - assert expected == actual - - -def test_batch_prediction_job_path(): - project = "squid" - location = "clam" - batch_prediction_job = "whelk" - - expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( - project=project, location=location, batch_prediction_job=batch_prediction_job, - ) - actual = JobServiceClient.batch_prediction_job_path( - project, location, batch_prediction_job - ) - assert expected == actual - - -def test_hyperparameter_tuning_job_path(): - project = "squid" - location = "clam" - hyperparameter_tuning_job = "whelk" - - expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( - project=project, - location=location, - hyperparameter_tuning_job=hyperparameter_tuning_job, - ) - actual = JobServiceClient.hyperparameter_tuning_job_path( - project, location, hyperparameter_tuning_job - ) - assert expected == actual - - -def test_data_labeling_job_path(): - project = "squid" - location = "clam" - data_labeling_job = "whelk" - - expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( - project=project, location=location, data_labeling_job=data_labeling_job, - ) - actual = JobServiceClient.data_labeling_job_path( - project, location, data_labeling_job - ) - assert expected == actual diff --git a/tests/unit/gapic/test_model_service.py b/tests/unit/gapic/test_model_service.py deleted file mode 100644 index 854272f8a5..0000000000 --- a/tests/unit/gapic/test_model_service.py +++ /dev/null @@ -1,1223 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.api_core import future -from google.api_core import operations_v1 -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient -from google.cloud.aiplatform_v1beta1.services.model_service import pagers -from google.cloud.aiplatform_v1beta1.services.model_service import transports -from google.cloud.aiplatform_v1beta1.types import deployed_model_ref -from google.cloud.aiplatform_v1beta1.types import env_var -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import explanation_metadata -from google.cloud.aiplatform_v1beta1.types import io -from google.cloud.aiplatform_v1beta1.types import model -from google.cloud.aiplatform_v1beta1.types import model as gca_model -from google.cloud.aiplatform_v1beta1.types import model_evaluation -from google.cloud.aiplatform_v1beta1.types import model_evaluation_slice -from google.cloud.aiplatform_v1beta1.types import model_service -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore - - -def test_model_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = ModelServiceClient.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = ModelServiceClient.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_model_service_client_client_options(): - # Check the default options have their expected values. - assert ( - ModelServiceClient.DEFAULT_OPTIONS.api_endpoint == "aiplatform.googleapis.com" - ) - - # Check that options can be customized. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.ModelServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = ModelServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_model_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.model_service.ModelServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_upload_model(transport: str = "grpc"): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
- request = model_service.UploadModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.upload_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.upload_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_upload_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.upload_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.upload_model( - parent="parent_value", model=gca_model.Model(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].model == gca_model.Model(name="name_value") - - -def test_upload_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.upload_model( - model_service.UploadModelRequest(), - parent="parent_value", - model=gca_model.Model(name="name_value"), - ) - - -def test_get_model(transport: str = "grpc"): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.GetModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=["supported_input_storage_formats_value"], - supported_output_storage_formats=["supported_output_storage_formats_value"], - etag="etag_value", - ) - - response = client.get_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, model.Model) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" - assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [ - model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] - assert response.etag == "etag_value" - - -def test_get_model_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.GetModelRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_model), "__call__") as call: - call.return_value = model.Model() - client.get_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.get_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = model.Model() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_model(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_model( - model_service.GetModelRequest(), name="name_value", - ) - - -def test_list_models(transport: str = "grpc"): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.ListModelsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_models), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_models(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_models_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.ListModelsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_models), "__call__") as call: - call.return_value = model_service.ListModelsResponse() - client.list_models(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_models_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_models), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_models(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_models_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_models( - model_service.ListModelsRequest(), parent="parent_value", - ) - - -def test_list_models_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_models), "__call__") as call: - # Set the response to a series of pages. - call.side_effect = ( - model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", - ), - model_service.ListModelsResponse(models=[], next_page_token="def",), - model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", - ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), - RuntimeError, - ) - results = [i for i in client.list_models(request={},)] - assert len(results) == 6 - assert all(isinstance(i, model.Model) for i in results) - - -def test_list_models_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.list_models), "__call__") as call: - # Set the response to a series of pages. 
- call.side_effect = ( - model_service.ListModelsResponse( - models=[model.Model(), model.Model(), model.Model(),], - next_page_token="abc", - ), - model_service.ListModelsResponse(models=[], next_page_token="def",), - model_service.ListModelsResponse( - models=[model.Model(),], next_page_token="ghi", - ), - model_service.ListModelsResponse(models=[model.Model(), model.Model(),],), - RuntimeError, - ) - pages = list(client.list_models(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_update_model(transport: str = "grpc"): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.UpdateModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_model.Model( - name="name_value", - display_name="display_name_value", - description="description_value", - metadata_schema_uri="metadata_schema_uri_value", - training_pipeline="training_pipeline_value", - artifact_uri="artifact_uri_value", - supported_deployment_resources_types=[ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ], - supported_input_storage_formats=["supported_input_storage_formats_value"], - supported_output_storage_formats=["supported_output_storage_formats_value"], - etag="etag_value", - ) - - response = client.update_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_model.Model) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.description == "description_value" - assert response.metadata_schema_uri == "metadata_schema_uri_value" - assert response.training_pipeline == "training_pipeline_value" - assert response.artifact_uri == "artifact_uri_value" - assert response.supported_deployment_resources_types == [ - gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES - ] - assert response.supported_input_storage_formats == [ - "supported_input_storage_formats_value" - ] - assert response.supported_output_storage_formats == [ - "supported_output_storage_formats_value" - ] - assert response.etag == "etag_value" - - -def test_update_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.update_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = gca_model.Model() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.update_model( - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].model == gca_model.Model(name="name_value") - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -def test_update_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.update_model( - model_service.UpdateModelRequest(), - model=gca_model.Model(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -def test_delete_model(transport: str = "grpc"): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.DeleteModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.delete_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_model(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_model( - model_service.DeleteModelRequest(), name="name_value", - ) - - -def test_export_model(transport: str = "grpc"): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.ExportModelRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.export_model(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_export_model_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.export_model), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.export_model( - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - assert args[0].output_config == model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ) - - -def test_export_model_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.export_model( - model_service.ExportModelRequest(), - name="name_value", - output_config=model_service.ExportModelRequest.OutputConfig( - export_format_id="export_format_id_value" - ), - ) - - -def test_get_model_evaluation(transport: str = "grpc"): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.GetModelEvaluationRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_model_evaluation), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_evaluation.ModelEvaluation( - name="name_value", - metrics_schema_uri="metrics_schema_uri_value", - slice_dimensions=["slice_dimensions_value"], - ) - - response = client.get_model_evaluation(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, model_evaluation.ModelEvaluation) - assert response.name == "name_value" - assert response.metrics_schema_uri == "metrics_schema_uri_value" - assert response.slice_dimensions == ["slice_dimensions_value"] - - -def test_get_model_evaluation_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.GetModelEvaluationRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client._transport.get_model_evaluation), "__call__" - ) as call: - call.return_value = model_evaluation.ModelEvaluation() - client.get_model_evaluation(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_model_evaluation_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_model_evaluation), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_evaluation.ModelEvaluation() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_model_evaluation(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_model_evaluation_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_model_evaluation( - model_service.GetModelEvaluationRequest(), name="name_value", - ) - - -def test_list_model_evaluations(transport: str = "grpc"): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.ListModelEvaluationsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluations), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelEvaluationsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_model_evaluations(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelEvaluationsPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_model_evaluations_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.ListModelEvaluationsRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluations), "__call__" - ) as call: - call.return_value = model_service.ListModelEvaluationsResponse() - client.list_model_evaluations(request) - - # Establish that the underlying gRPC stub method was called. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_model_evaluations_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluations), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelEvaluationsResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_model_evaluations(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_model_evaluations_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_model_evaluations( - model_service.ListModelEvaluationsRequest(), parent="parent_value", - ) - - -def test_list_model_evaluations_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluations), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - ), - RuntimeError, - ) - results = [i for i in client.list_model_evaluations(request={},)] - assert len(results) == 6 - assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results) - - -def test_list_model_evaluations_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluations), "__call__" - ) as call: - # Set the response to a series of pages. 
- call.side_effect = ( - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[], next_page_token="def", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[model_evaluation.ModelEvaluation(),], - next_page_token="ghi", - ), - model_service.ListModelEvaluationsResponse( - model_evaluations=[ - model_evaluation.ModelEvaluation(), - model_evaluation.ModelEvaluation(), - ], - ), - RuntimeError, - ) - pages = list(client.list_model_evaluations(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_get_model_evaluation_slice(transport: str = "grpc"): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.GetModelEvaluationSliceRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_model_evaluation_slice), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_evaluation_slice.ModelEvaluationSlice( - name="name_value", metrics_schema_uri="metrics_schema_uri_value", - ) - - response = client.get_model_evaluation_slice(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice) - assert response.name == "name_value" - assert response.metrics_schema_uri == "metrics_schema_uri_value" - - -def test_get_model_evaluation_slice_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.GetModelEvaluationSliceRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_model_evaluation_slice), "__call__" - ) as call: - call.return_value = model_evaluation_slice.ModelEvaluationSlice() - client.get_model_evaluation_slice(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_model_evaluation_slice_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_model_evaluation_slice), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_evaluation_slice.ModelEvaluationSlice() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
- response = client.get_model_evaluation_slice(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_model_evaluation_slice_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_model_evaluation_slice( - model_service.GetModelEvaluationSliceRequest(), name="name_value", - ) - - -def test_list_model_evaluation_slices(transport: str = "grpc"): - client = ModelServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = model_service.ListModelEvaluationSlicesRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluation_slices), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelEvaluationSlicesResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_model_evaluation_slices(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelEvaluationSlicesPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_model_evaluation_slices_field_headers(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.ListModelEvaluationSlicesRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluation_slices), "__call__" - ) as call: - call.return_value = model_service.ListModelEvaluationSlicesResponse() - client.list_model_evaluation_slices(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_model_evaluation_slices_flattened(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluation_slices), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelEvaluationSlicesResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.list_model_evaluation_slices(parent="parent_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - - -def test_list_model_evaluation_slices_flattened_error(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_model_evaluation_slices( - model_service.ListModelEvaluationSlicesRequest(), parent="parent_value", - ) - - -def test_list_model_evaluation_slices_pager(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluation_slices), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="ghi", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - ), - RuntimeError, - ) - results = [i for i in client.list_model_evaluation_slices(request={},)] - assert len(results) == 6 - assert all( - isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results - ) - - -def test_list_model_evaluation_slices_pages(): - client = ModelServiceClient(credentials=credentials.AnonymousCredentials,) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_model_evaluation_slices), "__call__" - ) as call: - # Set the response to a series of pages. - call.side_effect = ( - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="abc", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[], next_page_token="def", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - ], - next_page_token="ghi", - ), - model_service.ListModelEvaluationSlicesResponse( - model_evaluation_slices=[ - model_evaluation_slice.ModelEvaluationSlice(), - model_evaluation_slice.ModelEvaluationSlice(), - ], - ), - RuntimeError, - ) - pages = list(client.list_model_evaluation_slices(request={}).pages) - for page, token in zip(pages, ["abc", "def", "ghi", ""]): - assert page.raw_page.next_page_token == token - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. 
-    transport = transports.ModelServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = ModelServiceClient(
-            credentials=credentials.AnonymousCredentials(), transport=transport,
-        )
-
-
-def test_transport_instance():
-    # A client may be instantiated with a custom transport instance.
-    transport = transports.ModelServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    client = ModelServiceClient(transport=transport)
-    assert client._transport is transport
-
-
-def test_transport_grpc_default():
-    # A client should use the gRPC transport by default.
-    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
-    assert isinstance(client._transport, transports.ModelServiceGrpcTransport,)
-
-
-def test_model_service_base_transport():
-    # Instantiate the base transport.
-    transport = transports.ModelServiceTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Every method on the transport should just blindly
-    # raise NotImplementedError.
-    methods = (
-        "upload_model",
-        "get_model",
-        "list_models",
-        "update_model",
-        "delete_model",
-        "export_model",
-        "get_model_evaluation",
-        "list_model_evaluations",
-        "get_model_evaluation_slice",
-        "list_model_evaluation_slices",
-    )
-    for method in methods:
-        with pytest.raises(NotImplementedError):
-            getattr(transport, method)(request=object())
-
-    # Additionally, the LRO client (a property) should
-    # also raise NotImplementedError.
-    with pytest.raises(NotImplementedError):
-        transport.operations_client
-
-
-def test_model_service_auth_adc():
-    # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        ModelServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",)
-        )
-
-
-def test_model_service_host_no_port():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_model_service_host_with_port():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:8000"
-
-
-def test_model_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
-    transport = transports.ModelServiceGrpcTransport(channel=channel,)
-    assert transport.grpc_channel is channel
-
-
-def test_model_service_grpc_lro_client():
-    client = ModelServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc",
-    )
-    transport = client._transport
-
-    # Ensure that we have an api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
- assert transport.operations_client is transport.operations_client - - -def test_model_path(): - project = "squid" - location = "clam" - model = "whelk" - - expected = "projects/{project}/locations/{location}/models/{model}".format( - project=project, location=location, model=model, - ) - actual = ModelServiceClient.model_path(project, location, model) - assert expected == actual diff --git a/tests/unit/gapic/test_pipeline_service.py b/tests/unit/gapic/test_pipeline_service.py deleted file mode 100644 index c7c2db4449..0000000000 --- a/tests/unit/gapic/test_pipeline_service.py +++ /dev/null @@ -1,675 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.api_core import future -from google.api_core import operations_v1 -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.pipeline_service import ( - PipelineServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers -from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports -from google.cloud.aiplatform_v1beta1.types import deployed_model_ref -from google.cloud.aiplatform_v1beta1.types import env_var -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import explanation_metadata -from google.cloud.aiplatform_v1beta1.types import io -from google.cloud.aiplatform_v1beta1.types import model -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.cloud.aiplatform_v1beta1.types import pipeline_service -from google.cloud.aiplatform_v1beta1.types import pipeline_state -from google.cloud.aiplatform_v1beta1.types import training_pipeline -from google.cloud.aiplatform_v1beta1.types import ( - training_pipeline as gca_training_pipeline, -) -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import any_pb2 as any # type: ignore -from google.protobuf import field_mask_pb2 as field_mask # type: ignore -from google.protobuf import struct_pb2 as struct # type: ignore -from google.protobuf import timestamp_pb2 as timestamp # type: ignore -from google.rpc import status_pb2 as status # type: ignore - - -def test_pipeline_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = PipelineServiceClient.from_service_account_file("dummy/file/path.json") - assert client._transport._credentials == creds - - client = PipelineServiceClient.from_service_account_json("dummy/file/path.json") - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def 
test_pipeline_service_client_client_options(): - # Check the default options have their expected values. - assert ( - PipelineServiceClient.DEFAULT_OPTIONS.api_endpoint - == "aiplatform.googleapis.com" - ) - - # Check that options can be customized. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.pipeline_service.PipelineServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = PipelineServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_pipeline_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.pipeline_service.PipelineServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = PipelineServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_create_training_pipeline(transport: str = "grpc"): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.CreateTrainingPipelineRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - - response = client.create_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, gca_training_pipeline.TrainingPipeline) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.training_task_definition == "training_task_definition_value" - assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED - - -def test_create_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = gca_training_pipeline.TrainingPipeline() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_training_pipeline( - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( - name="name_value" - ) - - -def test_create_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_training_pipeline( - pipeline_service.CreateTrainingPipelineRequest(), - parent="parent_value", - training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), - ) - - -def test_get_training_pipeline(transport: str = "grpc"): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.GetTrainingPipelineRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = training_pipeline.TrainingPipeline( - name="name_value", - display_name="display_name_value", - training_task_definition="training_task_definition_value", - state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, - ) - - response = client.get_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, training_pipeline.TrainingPipeline) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.training_task_definition == "training_task_definition_value" - assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED - - -def test_get_training_pipeline_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = pipeline_service.GetTrainingPipelineRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_training_pipeline), "__call__" - ) as call: - call.return_value = training_pipeline.TrainingPipeline() - client.get_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = training_pipeline.TrainingPipeline() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_training_pipeline(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_training_pipeline( - pipeline_service.GetTrainingPipelineRequest(), name="name_value", - ) - - -def test_list_training_pipelines(transport: str = "grpc"): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.ListTrainingPipelinesRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_training_pipelines), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = pipeline_service.ListTrainingPipelinesResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_training_pipelines(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListTrainingPipelinesPager) - assert response.next_page_token == "next_page_token_value" - - -def test_list_training_pipelines_field_headers(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = pipeline_service.ListTrainingPipelinesRequest(parent="parent/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_training_pipelines), "__call__" - ) as call: - call.return_value = pipeline_service.ListTrainingPipelinesResponse() - client.list_training_pipelines(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] - - -def test_list_training_pipelines_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_training_pipelines), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = pipeline_service.ListTrainingPipelinesResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
-        response = client.list_training_pipelines(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-
-
-def test_list_training_pipelines_flattened_error():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_training_pipelines(
-            pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value",
-        )
-
-
-def test_list_training_pipelines_pager():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_training_pipelines), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                ],
-                next_page_token="abc",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[], next_page_token="def",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[training_pipeline.TrainingPipeline(),],
-                next_page_token="ghi",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                ],
-            ),
-            RuntimeError,
-        )
-        results = [i for i in client.list_training_pipelines(request={},)]
-        assert len(results) == 6
-        assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results)
-
-
-def test_list_training_pipelines_pages():
-    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_training_pipelines), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                ],
-                next_page_token="abc",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[], next_page_token="def",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[training_pipeline.TrainingPipeline(),],
-                next_page_token="ghi",
-            ),
-            pipeline_service.ListTrainingPipelinesResponse(
-                training_pipelines=[
-                    training_pipeline.TrainingPipeline(),
-                    training_pipeline.TrainingPipeline(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = list(client.list_training_pipelines(request={}).pages)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
-
-
-def test_delete_training_pipeline(transport: str = "grpc"):
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = pipeline_service.DeleteTrainingPipelineRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
- with mock.patch.object( - type(client._transport.delete_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.delete_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_delete_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.delete_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_training_pipeline(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_training_pipeline( - pipeline_service.DeleteTrainingPipelineRequest(), name="name_value", - ) - - -def test_cancel_training_pipeline(transport: str = "grpc"): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = pipeline_service.CancelTrainingPipelineRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - response = client.cancel_training_pipeline(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert response is None - - -def test_cancel_training_pipeline_flattened(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.cancel_training_pipeline), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = None - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.cancel_training_pipeline(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_cancel_training_pipeline_flattened_error(): - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.cancel_training_pipeline( - pipeline_service.CancelTrainingPipelineRequest(), name="name_value", - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.PipelineServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = PipelineServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.PipelineServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = PipelineServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.PipelineServiceGrpcTransport,) - - -def test_pipeline_service_base_transport(): - # Instantiate the base transport. - transport = transports.PipelineServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "create_training_pipeline", - "get_training_pipeline", - "list_training_pipelines", - "delete_training_pipeline", - "cancel_training_pipeline", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - # Additionally, the LRO client (a property) should - # also raise NotImplementedError - with pytest.raises(NotImplementedError): - transport.operations_client - - -def test_pipeline_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. 
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        PipelineServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",)
-        )
-
-
-def test_pipeline_service_host_no_port():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_pipeline_service_host_with_port():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:8000"
-
-
-def test_pipeline_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
-    transport = transports.PipelineServiceGrpcTransport(channel=channel,)
-    assert transport.grpc_channel is channel
-
-
-def test_pipeline_service_grpc_lro_client():
-    client = PipelineServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc",
-    )
-    transport = client._transport
-
-    # Ensure that we have an api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
-    assert transport.operations_client is transport.operations_client
-
-
-def test_model_path():
-    project = "squid"
-    location = "clam"
-    model = "whelk"
-
-    expected = "projects/{project}/locations/{location}/models/{model}".format(
-        project=project, location=location, model=model,
-    )
-    actual = PipelineServiceClient.model_path(project, location, model)
-    assert expected == actual
-
-
-def test_training_pipeline_path():
-    project = "squid"
-    location = "clam"
-    training_pipeline = "whelk"
-
-    expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(
-        project=project, location=location, training_pipeline=training_pipeline,
-    )
-    actual = PipelineServiceClient.training_pipeline_path(
-        project, location, training_pipeline
-    )
-    assert expected == actual
diff --git a/tests/unit/gapic/test_prediction_service.py b/tests/unit/gapic/test_prediction_service.py
deleted file mode 100644
index b23ab9b529..0000000000
--- a/tests/unit/gapic/test_prediction_service.py
+++ /dev/null
@@ -1,311 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Copyright 2020 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.prediction_service import ( - PredictionServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.prediction_service import transports -from google.cloud.aiplatform_v1beta1.types import explanation -from google.cloud.aiplatform_v1beta1.types import prediction_service -from google.oauth2 import service_account -from google.protobuf import struct_pb2 as struct # type: ignore - - -def test_prediction_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = PredictionServiceClient.from_service_account_file( - "dummy/file/path.json" - ) - assert client._transport._credentials == creds - - client = PredictionServiceClient.from_service_account_json( - "dummy/file/path.json" - ) - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_prediction_service_client_client_options(): - # Check the default options have their expected values. - assert ( - PredictionServiceClient.DEFAULT_OPTIONS.api_endpoint - == "aiplatform.googleapis.com" - ) - - # Check that options can be customized. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.prediction_service.PredictionServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = PredictionServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_prediction_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.prediction_service.PredictionServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = PredictionServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_predict(transport: str = "grpc"): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = prediction_service.PredictRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.predict), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.PredictResponse( - deployed_model_id="deployed_model_id_value", - ) - - response = client.predict(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, prediction_service.PredictResponse) - assert response.deployed_model_id == "deployed_model_id_value" - - -def test_predict_flattened(): - client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.predict), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.PredictResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.predict( - endpoint="endpoint_value", - instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], - parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [ - struct.Value(null_value=struct.NullValue.NULL_VALUE) - ] - # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct.Value( - # null_value=struct.NullValue.NULL_VALUE - # ) - - -def test_predict_flattened_error(): - client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.predict( - prediction_service.PredictRequest(), - endpoint="endpoint_value", - instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], - parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - ) - - -def test_explain(transport: str = "grpc"): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = prediction_service.ExplainRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.explain), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.ExplainResponse( - deployed_model_id="deployed_model_id_value", - ) - - response = client.explain(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, prediction_service.ExplainResponse) - assert response.deployed_model_id == "deployed_model_id_value" - - -def test_explain_flattened(): - client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object(type(client._transport.explain), "__call__") as call: - # Designate an appropriate return value for the call. - call.return_value = prediction_service.ExplainResponse() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
- response = client.explain( - endpoint="endpoint_value", - instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], - parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - deployed_model_id="deployed_model_id_value", - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].endpoint == "endpoint_value" - assert args[0].instances == [ - struct.Value(null_value=struct.NullValue.NULL_VALUE) - ] - # https://github.com/googleapis/gapic-generator-python/issues/414 - # assert args[0].parameters == struct.Value( - # null_value=struct.NullValue.NULL_VALUE - # ) - assert args[0].deployed_model_id == "deployed_model_id_value" - - -def test_explain_flattened_error(): - client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.explain( - prediction_service.ExplainRequest(), - endpoint="endpoint_value", - instances=[struct.Value(null_value=struct.NullValue.NULL_VALUE)], - parameters=struct.Value(null_value=struct.NullValue.NULL_VALUE), - deployed_model_id="deployed_model_id_value", - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.PredictionServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.PredictionServiceGrpcTransport( - credentials=credentials.AnonymousCredentials(), - ) - client = PredictionServiceClient(transport=transport) - assert client._transport is transport - - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),) - assert isinstance(client._transport, transports.PredictionServiceGrpcTransport,) - - -def test_prediction_service_base_transport(): - # Instantiate the base transport. - transport = transports.PredictionServiceTransport( - credentials=credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. - methods = ( - "predict", - "explain", - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - -def test_prediction_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. 
- with mock.patch.object(auth, "default") as adc: - adc.return_value = (credentials.AnonymousCredentials(), None) - PredictionServiceClient() - adc.assert_called_once_with( - scopes=("https://www.googleapis.com/auth/cloud-platform",) - ) - - -def test_prediction_service_host_no_port(): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_prediction_service_host_with_port(): - client = PredictionServiceClient( - credentials=credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions( - api_endpoint="aiplatform.googleapis.com:8000" - ), - transport="grpc", - ) - assert client._transport._host == "aiplatform.googleapis.com:8000" - - -def test_prediction_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") - transport = transports.PredictionServiceGrpcTransport(channel=channel,) - assert transport.grpc_channel is channel diff --git a/tests/unit/gapic/test_specialist_pool_service.py b/tests/unit/gapic/test_specialist_pool_service.py deleted file mode 100644 index f8467edf33..0000000000 --- a/tests/unit/gapic/test_specialist_pool_service.py +++ /dev/null @@ -1,681 +0,0 @@ -# -*- coding: utf-8 -*- - -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from unittest import mock - -import grpc -import math -import pytest - -from google import auth -from google.api_core import client_options -from google.api_core import future -from google.api_core import operations_v1 -from google.auth import credentials -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import ( - SpecialistPoolServiceClient, -) -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import pagers -from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import transports -from google.cloud.aiplatform_v1beta1.types import operation as gca_operation -from google.cloud.aiplatform_v1beta1.types import specialist_pool -from google.cloud.aiplatform_v1beta1.types import specialist_pool as gca_specialist_pool -from google.cloud.aiplatform_v1beta1.types import specialist_pool_service -from google.longrunning import operations_pb2 -from google.oauth2 import service_account -from google.protobuf import field_mask_pb2 as field_mask # type: ignore - - -def test_specialist_pool_service_client_from_service_account_file(): - creds = credentials.AnonymousCredentials() - with mock.patch.object( - service_account.Credentials, "from_service_account_file" - ) as factory: - factory.return_value = creds - client = SpecialistPoolServiceClient.from_service_account_file( - "dummy/file/path.json" - ) - assert client._transport._credentials == creds - - client = SpecialistPoolServiceClient.from_service_account_json( - "dummy/file/path.json" - ) - assert client._transport._credentials == creds - - assert client._transport._host == "aiplatform.googleapis.com:443" - - -def test_specialist_pool_service_client_client_options(): - # Check the default options have their expected values. - assert ( - SpecialistPoolServiceClient.DEFAULT_OPTIONS.api_endpoint - == "aiplatform.googleapis.com" - ) - - # Check that options can be customized. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.SpecialistPoolServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = SpecialistPoolServiceClient(client_options=options) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_specialist_pool_service_client_client_options_from_dict(): - with mock.patch( - "google.cloud.aiplatform_v1beta1.services.specialist_pool_service.SpecialistPoolServiceClient.get_transport_class" - ) as gtc: - transport = gtc.return_value = mock.MagicMock() - client = SpecialistPoolServiceClient( - client_options={"api_endpoint": "squid.clam.whelk"} - ) - transport.assert_called_once_with(credentials=None, host="squid.clam.whelk") - - -def test_create_specialist_pool(transport: str = "grpc"): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.CreateSpecialistPoolRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. 
- call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.create_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_create_specialist_pool_flattened(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.create_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.create_specialist_pool( - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].parent == "parent_value" - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) - - -def test_create_specialist_pool_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.create_specialist_pool( - specialist_pool_service.CreateSpecialistPoolRequest(), - parent="parent_value", - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - ) - - -def test_get_specialist_pool(transport: str = "grpc"): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.GetSpecialistPoolRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool.SpecialistPool( - name="name_value", - display_name="display_name_value", - specialist_managers_count=2662, - specialist_manager_emails=["specialist_manager_emails_value"], - pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], - ) - - response = client.get_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
- assert isinstance(response, specialist_pool.SpecialistPool) - assert response.name == "name_value" - assert response.display_name == "display_name_value" - assert response.specialist_managers_count == 2662 - assert response.specialist_manager_emails == ["specialist_manager_emails_value"] - assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] - - -def test_get_specialist_pool_field_headers(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = specialist_pool_service.GetSpecialistPoolRequest(name="name/value",) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_specialist_pool), "__call__" - ) as call: - call.return_value = specialist_pool.SpecialistPool() - client.get_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] - - -def test_get_specialist_pool_flattened(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.get_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool.SpecialistPool() - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.get_specialist_pool(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_get_specialist_pool_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_specialist_pool( - specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", - ) - - -def test_list_specialist_pools(transport: str = "grpc"): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.ListSpecialistPoolsRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.list_specialist_pools), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( - next_page_token="next_page_token_value", - ) - - response = client.list_specialist_pools(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. 
-    assert isinstance(response, pagers.ListSpecialistPoolsPager)
-    assert response.next_page_token == "next_page_token_value"
-
-
-def test_list_specialist_pools_field_headers():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Any value that is part of the HTTP/1.1 URI should be sent as
-    # a field header. Set these to a non-empty value.
-    request = specialist_pool_service.ListSpecialistPoolsRequest(parent="parent/value",)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_specialist_pools), "__call__"
-    ) as call:
-        call.return_value = specialist_pool_service.ListSpecialistPoolsResponse()
-        client.list_specialist_pools(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
-
-
-def test_list_specialist_pools_flattened():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_specialist_pools), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = specialist_pool_service.ListSpecialistPoolsResponse()
-
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = client.list_specialist_pools(parent="parent_value",)
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        assert args[0].parent == "parent_value"
-
-
-def test_list_specialist_pools_flattened_error():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_specialist_pools(
-            specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value",
-        )
-
-
-def test_list_specialist_pools_pager():
-    client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_specialist_pools), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                ],
-                next_page_token="abc",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[], next_page_token="def",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[specialist_pool.SpecialistPool(),],
-                next_page_token="ghi",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                ],
-            ),
-            RuntimeError,
-        )
-        results = [i for i in client.list_specialist_pools(request={},)]
-        assert len(results) == 6
-        assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results)
-
-
-def test_list_specialist_pools_pages():
-    client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials(),)
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.list_specialist_pools), "__call__"
-    ) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                ],
-                next_page_token="abc",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[], next_page_token="def",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[specialist_pool.SpecialistPool(),],
-                next_page_token="ghi",
-            ),
-            specialist_pool_service.ListSpecialistPoolsResponse(
-                specialist_pools=[
-                    specialist_pool.SpecialistPool(),
-                    specialist_pool.SpecialistPool(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = list(client.list_specialist_pools(request={}).pages)
-        for page, token in zip(pages, ["abc", "def", "ghi", ""]):
-            assert page.raw_page.next_page_token == token
-
-
-def test_delete_specialist_pool(transport: str = "grpc"):
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
-    request = specialist_pool_service.DeleteSpecialistPoolRequest()
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.delete_specialist_pool), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = operations_pb2.Operation(name="operations/spam")
-
-        response = client.delete_specialist_pool(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-
-        assert args[0] == request
-
-    # Establish that the response is the type that we expect.
-    assert isinstance(response, future.Future)
-
-
-def test_delete_specialist_pool_flattened():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-        type(client._transport.delete_specialist_pool), "__call__"
-    ) as call:
-        # Designate an appropriate return value for the call.
- call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.delete_specialist_pool(name="name_value",) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].name == "name_value" - - -def test_delete_specialist_pool_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.delete_specialist_pool( - specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", - ) - - -def test_update_specialist_pool(transport: str = "grpc"): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = specialist_pool_service.UpdateSpecialistPoolRequest() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.update_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/spam") - - response = client.update_specialist_pool(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - - assert args[0] == request - - # Establish that the response is the type that we expect. - assert isinstance(response, future.Future) - - -def test_update_specialist_pool_flattened(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client._transport.update_specialist_pool), "__call__" - ) as call: - # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name="operations/op") - - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = client.update_specialist_pool( - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( - name="name_value" - ) - assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) - - -def test_update_specialist_pool_flattened_error(): - client = SpecialistPoolServiceClient( - credentials=credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.update_specialist_pool( - specialist_pool_service.UpdateSpecialistPoolRequest(), - specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), - update_mask=field_mask.FieldMask(paths=["paths_value"]), - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. 
-    transport = transports.SpecialistPoolServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    with pytest.raises(ValueError):
-        client = SpecialistPoolServiceClient(
-            credentials=credentials.AnonymousCredentials(), transport=transport,
-        )
-
-
-def test_transport_instance():
-    # A client may be instantiated with a custom transport instance.
-    transport = transports.SpecialistPoolServiceGrpcTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    client = SpecialistPoolServiceClient(transport=transport)
-    assert client._transport is transport
-
-
-def test_transport_grpc_default():
-    # A client should use the gRPC transport by default.
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-    )
-    assert isinstance(client._transport, transports.SpecialistPoolServiceGrpcTransport,)
-
-
-def test_specialist_pool_service_base_transport():
-    # Instantiate the base transport.
-    transport = transports.SpecialistPoolServiceTransport(
-        credentials=credentials.AnonymousCredentials(),
-    )
-
-    # Every method on the transport should just blindly
-    # raise NotImplementedError.
-    methods = (
-        "create_specialist_pool",
-        "get_specialist_pool",
-        "list_specialist_pools",
-        "delete_specialist_pool",
-        "update_specialist_pool",
-    )
-    for method in methods:
-        with pytest.raises(NotImplementedError):
-            getattr(transport, method)(request=object())
-
-    # Additionally, the LRO client (a property) should
-    # also raise NotImplementedError
-    with pytest.raises(NotImplementedError):
-        transport.operations_client
-
-
-def test_specialist_pool_service_auth_adc():
-    # If no credentials are provided, we should use ADC credentials.
-    with mock.patch.object(auth, "default") as adc:
-        adc.return_value = (credentials.AnonymousCredentials(), None)
-        SpecialistPoolServiceClient()
-        adc.assert_called_once_with(
-            scopes=("https://www.googleapis.com/auth/cloud-platform",)
-        )
-
-
-def test_specialist_pool_service_host_no_port():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:443"
-
-
-def test_specialist_pool_service_host_with_port():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(
-            api_endpoint="aiplatform.googleapis.com:8000"
-        ),
-        transport="grpc",
-    )
-    assert client._transport._host == "aiplatform.googleapis.com:8000"
-
-
-def test_specialist_pool_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
-    transport = transports.SpecialistPoolServiceGrpcTransport(channel=channel,)
-    assert transport.grpc_channel is channel
-
-
-def test_specialist_pool_service_grpc_lro_client():
-    client = SpecialistPoolServiceClient(
-        credentials=credentials.AnonymousCredentials(), transport="grpc",
-    )
-    transport = client._transport
-
-    # Ensure that we have an api-core operations client.
-    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
-
-    # Ensure that subsequent calls to the property send the exact same object.
-    assert transport.operations_client is transport.operations_client
-
-
-def test_specialist_pool_path():
-    project = "squid"
-    location = "clam"
-    specialist_pool = "whelk"
-
-    expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format(
-        project=project, location=location, specialist_pool=specialist_pool,
-    )
-    actual = SpecialistPoolServiceClient.specialist_pool_path(
-        project, location, specialist_pool
-    )
-    assert expected == actual
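All of the deleted pager tests above share one pattern: the transport method's __call__ is patched so that successive invocations return successive response pages, and the side_effect tuple ends with a RuntimeError sentinel that fires only if the pager requests more pages than the test supplied. Below is a minimal, self-contained sketch of that pattern using only the standard library; the FakePage class and list_items mock are hypothetical stand-ins for a List*Response message and a transport method, not part of the aiplatform API.

from unittest import mock


class FakePage:
    """Hypothetical stand-in for a List*Response message."""

    def __init__(self, items, next_page_token=""):
        self.items = items
        self.next_page_token = next_page_token


# Successive calls return successive pages; the trailing RuntimeError is
# raised only if the consumer asks for a page beyond the last one supplied.
list_items = mock.Mock(
    side_effect=(
        FakePage(["a", "b", "c"], next_page_token="abc"),
        FakePage([], next_page_token="def"),
        FakePage(["d"]),  # empty token marks the final page
        RuntimeError,
    )
)

results, token = [], None
while token != "":
    page = list_items(request={"page_token": token or ""})
    results.extend(page.items)
    token = page.next_page_token

assert results == ["a", "b", "c", "d"]
assert list_items.call_count == 3  # the RuntimeError sentinel was never reached

This is also why each deleted test terminates its side_effect tuple with RuntimeError: a pager that mishandled the empty final token would surface immediately as a raised exception rather than as a silent extra RPC.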